Skip to content

Commit 29474af

Browse files
update to address reviewers' comments
1 parent 0b5254b commit 29474af

File tree

7 files changed

+39
-19
lines changed

7 files changed

+39
-19
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
157157

158158
- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
159159

160-
- Deprecated `on_{train/val/test/predict}_trainer()` from `DataHooks` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
160+
- Deprecated `on_{train/val/test/predict}_trainer()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
161161

162162
-
163163

pytorch_lightning/core/hooks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def on_train_dataloader(self) -> None:
689689
"""
690690
rank_zero_deprecation(
691691
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
692-
" Please use `train_dataloader()` directly. "
692+
" Please use `train_dataloader()` directly."
693693
)
694694

695695
def on_val_dataloader(self) -> None:
@@ -701,7 +701,7 @@ def on_val_dataloader(self) -> None:
701701
"""
702702
rank_zero_deprecation(
703703
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
704-
" Please use ``val_dataloader()`` directly. "
704+
" Please use `val_dataloader()` directly."
705705
)
706706

707707
def on_test_dataloader(self) -> None:
@@ -713,7 +713,7 @@ def on_test_dataloader(self) -> None:
713713
"""
714714
rank_zero_deprecation(
715715
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
716-
" Please use ``test_dataloader()`` directly. "
716+
" Please use `test_dataloader()` directly."
717717
)
718718

719719
def on_predict_dataloader(self) -> None:
@@ -725,7 +725,7 @@ def on_predict_dataloader(self) -> None:
725725
"""
726726
rank_zero_deprecation(
727727
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
728-
" Please use ``predict_dataloader()`` directly. "
728+
" Please use `predict_dataloader()` directly."
729729
)
730730

731731
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s
120120
rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
121121
if has_step and not has_loader:
122122
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
123+
123124
# ----------------------------------------------
124125
# verify model does not have
125126
# - on_val_dataloader

pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,11 @@ class FxValidator:
7878
validation_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
7979
test_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
8080
configure_optimizers=None,
81+
on_train_dataloader=None,
8182
train_dataloader=None,
83+
on_val_dataloader=None,
8284
val_dataloader=None,
85+
on_test_dataloader=None,
8386
test_dataloader=None,
8487
prepare_data=None,
8588
configure_callbacks=None,

tests/models/test_hooks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ def training_step(self, batch, batch_idx):
505505
dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
506506
dict(name="on_pretrain_routine_end"),
507507
dict(name="Callback.on_sanity_check_start", args=(trainer, model)),
508+
dict(name="on_val_dataloader"),
508509
dict(name="val_dataloader"),
509510
dict(name="train", args=(False,)),
510511
dict(name="on_validation_model_eval"),
@@ -519,6 +520,7 @@ def training_step(self, batch, batch_idx):
519520
dict(name="Callback.on_sanity_check_end", args=(trainer, model)),
520521
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
521522
dict(name="train", args=(True,)),
523+
dict(name="on_train_dataloader"),
522524
dict(name="train_dataloader"),
523525
dict(name="Callback.on_train_start", args=(trainer, model)),
524526
dict(name="on_train_start"),
@@ -631,8 +633,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
631633
dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
632634
dict(name="on_pretrain_routine_end"),
633635
dict(name="train", args=(True,)),
636+
dict(name="on_train_dataloader"),
634637
dict(name="train_dataloader"),
635638
# even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
639+
dict(name="on_val_dataloader"),
636640
dict(name="val_dataloader"),
637641
dict(name="Callback.on_train_start", args=(trainer, model)),
638642
dict(name="on_train_start"),
@@ -704,6 +708,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
704708
dict(name="setup", kwargs=dict(stage=verb)),
705709
dict(name="configure_sharded_model"),
706710
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
711+
dict(name=f"on_{dataloader}_dataloader"),
707712
dict(name=f"{dataloader}_dataloader"),
708713
*(hooks if batches else []),
709714
dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage=verb)),
@@ -735,6 +740,7 @@ def test_trainer_model_hook_system_predict(tmpdir):
735740
dict(name="setup", kwargs=dict(stage="predict")),
736741
dict(name="configure_sharded_model"),
737742
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
743+
dict(name="on_predict_dataloader"),
738744
dict(name="predict_dataloader"),
739745
dict(name="train", args=(False,)),
740746
dict(name="on_predict_model_eval"),

tests/trainer/logging_/test_logger_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def test_fx_validator_integration(tmpdir):
218218
"on_fit_start": "You can't",
219219
"on_pretrain_routine_start": "You can't",
220220
"on_pretrain_routine_end": "You can't",
221+
"on_train_dataloader": "You can't",
221222
"train_dataloader": "You can't",
223+
"on_val_dataloader": "You can't",
222224
"val_dataloader": "You can't",
223225
"on_validation_end": "You can't",
224226
"on_train_end": "You can't",
@@ -252,6 +254,7 @@ def test_fx_validator_integration(tmpdir):
252254
{
253255
# `lightning_module` ref is now present from the `fit` call
254256
"on_before_accelerator_backend_setup": "You can't",
257+
"on_test_dataloader": "You can't",
255258
"test_dataloader": "You can't",
256259
"on_test_model_eval": "You can't",
257260
"on_test_end": "You can't",
@@ -262,6 +265,7 @@ def test_fx_validator_integration(tmpdir):
262265
not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported})
263266
not_supported.update(
264267
{
268+
"on_test_dataloader": "You can't",
265269
"predict_dataloader": "ResultCollection` is not registered yet",
266270
"on_predict_model_eval": "ResultCollection` is not registered yet",
267271
"on_predict_start": "ResultCollection` is not registered yet",

tests/trainer/test_dataloaders.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pytorch_lightning.callbacks import ModelCheckpoint
2828
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
2929
from pytorch_lightning.utilities.exceptions import MisconfigurationException
30-
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
3130
from tests.base import EvalModelTemplate
3231
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen
3332
from tests.helpers.runif import RunIf
@@ -1471,28 +1470,35 @@ def __iter__(self):
14711470
def __next__(self):
14721471
return next(self._iter)
14731472

1473+
class DataLoaderFunc:
1474+
def __init__(self, loader):
1475+
self.loader = loader
1476+
1477+
def __call__(self):
1478+
return self.loader
1479+
14741480
class TestModel(BoringModel):
14751481
def __init__(self):
14761482
super().__init__()
1477-
self.train_dataloader_called = False
1483+
self.on_train_dataloader_called = False
14781484
self.on_train_batch_start_called = False
1479-
self.val_dataloader_called = False
1485+
self.on_val_dataloader_called = False
14801486
self.on_val_batch_start_called = False
14811487

1482-
def train_dataloader(self) -> TRAIN_DATALOADERS:
1483-
loader = super().train_dataloader()
1484-
self.train_dataloader_called = True
1485-
return DataLoaderWrapper(loader)
1486-
1487-
def val_dataloader(self) -> EVAL_DATALOADERS:
1488-
loader = super().val_dataloader()
1489-
self.val_dataloader_called = True
1490-
return DataLoaderWrapper(loader)
1488+
def on_train_dataloader(self) -> None:
1489+
loader = self.train_dataloader()
1490+
self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader))
1491+
self.on_train_dataloader_called = True
14911492

14921493
def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
14931494
assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
14941495
self.on_train_batch_start_called = True
14951496

1497+
def on_val_dataloader(self) -> None:
1498+
loader = self.val_dataloader()
1499+
self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader))
1500+
self.on_val_dataloader_called = True
1501+
14961502
def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
14971503
assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper)
14981504
self.on_val_batch_start_called = True
@@ -1501,7 +1507,7 @@ def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int)
15011507
model = TestModel()
15021508
trainer.fit(model)
15031509
trainer.test(model)
1504-
assert model.train_dataloader_called
1505-
assert model.val_dataloader_called
1510+
assert model.on_train_dataloader_called
15061511
assert model.on_train_batch_start_called
1512+
assert model.on_val_dataloader_called
15071513
assert model.on_val_batch_start_called

0 commit comments

Comments
 (0)