Skip to content

Commit 5ab9d53

Browse files
kingjunorohitgr7carmocca
authored
Remove the deprecated on_{train,val,test,predict}_dataloader hooks (#13033)
Co-authored-by: rohitgr7 <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent d243617 commit 5ab9d53

File tree

12 files changed

+12
-183
lines changed

12 files changed

+12
-183
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
115115
- Removed the deprecated `checkpoint_callback` argument from the `Trainer` constructor ([#13027](https://github.com/PyTorchLightning/pytorch-lightning/pull/13027))
116116

117117

118+
- Removed the deprecated `on_{train,val,test,predict}_dataloader` hooks from the `LightningModule` and `LightningDataModule` ([#13033](https://github.com/PyTorchLightning/pytorch-lightning/pull/13033))
119+
120+
118121
- Removed the deprecated `TestTubeLogger` ([#12859](https://github.com/PyTorchLightning/pytorch-lightning/pull/12859))
119122

120123

docs/source/common/lightning_module.rst

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,30 +1609,6 @@ predict_dataloader
16091609
.. automethod:: pytorch_lightning.core.lightning.LightningModule.predict_dataloader
16101610
:noindex:
16111611

1612-
on_train_dataloader
1613-
~~~~~~~~~~~~~~~~~~~
1614-
1615-
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_dataloader
1616-
:noindex:
1617-
1618-
on_val_dataloader
1619-
~~~~~~~~~~~~~~~~~
1620-
1621-
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_val_dataloader
1622-
:noindex:
1623-
1624-
on_test_dataloader
1625-
~~~~~~~~~~~~~~~~~~
1626-
1627-
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_dataloader
1628-
:noindex:
1629-
1630-
on_predict_dataloader
1631-
~~~~~~~~~~~~~~~~~~~~~
1632-
1633-
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_dataloader
1634-
:noindex:
1635-
16361612
transfer_batch_to_device
16371613
~~~~~~~~~~~~~~~~~~~~~~~~
16381614

docs/source/data/datamodule.rst

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -360,30 +360,6 @@ state_dict
360360
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.state_dict
361361
:noindex:
362362

363-
on_train_dataloader
364-
===================
365-
366-
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_train_dataloader
367-
:noindex:
368-
369-
on_val_dataloader
370-
=================
371-
372-
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_val_dataloader
373-
:noindex:
374-
375-
on_test_dataloader
376-
==================
377-
378-
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_test_dataloader
379-
:noindex:
380-
381-
on_predict_dataloader
382-
=====================
383-
384-
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_predict_dataloader
385-
:noindex:
386-
387363
teardown
388364
========
389365

pytorch_lightning/core/hooks.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -628,38 +628,6 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
628628
"`predict_dataloader` must be implemented to be used with the Lightning Trainer"
629629
)
630630

631-
def on_train_dataloader(self) -> None:
632-
"""Called before requesting the train dataloader.
633-
634-
.. deprecated:: v1.5
635-
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
636-
Please use :meth:`train_dataloader()` directly.
637-
"""
638-
639-
def on_val_dataloader(self) -> None:
640-
"""Called before requesting the val dataloader.
641-
642-
.. deprecated:: v1.5
643-
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
644-
Please use :meth:`val_dataloader()` directly.
645-
"""
646-
647-
def on_test_dataloader(self) -> None:
648-
"""Called before requesting the test dataloader.
649-
650-
.. deprecated:: v1.5
651-
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
652-
Please use :meth:`test_dataloader()` directly.
653-
"""
654-
655-
def on_predict_dataloader(self) -> None:
656-
"""Called before requesting the predict dataloader.
657-
658-
.. deprecated:: v1.5
659-
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
660-
Please use :meth:`predict_dataloader()` directly.
661-
"""
662-
663631
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
664632
"""Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom
665633
data structure.

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
9494
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
9595
)
9696

97-
# ----------------------------------------------
98-
# verify model does not have on_train_dataloader
99-
# ----------------------------------------------
100-
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
101-
if has_on_train_dataloader:
102-
rank_zero_deprecation(
103-
"Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
104-
" Please use `train_dataloader()` directly."
105-
)
106-
10797
trainer.overridden_optimizer_step = is_overridden("optimizer_step", model)
10898
trainer.overridden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
10999
automatic_optimization = model.automatic_optimization
@@ -129,16 +119,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
129119
if has_val_step and not has_val_loader:
130120
rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
131121

132-
# ----------------------------------------------
133-
# verify model does not have on_val_dataloader
134-
# ----------------------------------------------
135-
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
136-
if has_on_val_dataloader:
137-
rank_zero_deprecation(
138-
"Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
139-
" Please use `val_dataloader()` directly."
140-
)
141-
142122

143123
def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
144124
r"""
@@ -158,20 +138,9 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
158138
loader_name = f"{stage}_dataloader"
159139
step_name = "validation_step" if stage == "val" else f"{stage}_step"
160140
trainer_method = "validate" if stage == "val" else stage
161-
on_eval_hook = f"on_{loader_name}"
162141

163142
has_loader = getattr(trainer._data_connector, f"_{stage}_dataloader_source").is_defined()
164143
has_step = is_overridden(step_name, model)
165-
has_on_eval_dataloader = is_overridden(on_eval_hook, model)
166-
167-
# ----------------------------------------------
168-
# verify model does not have on_eval_dataloader
169-
# ----------------------------------------------
170-
if has_on_eval_dataloader:
171-
rank_zero_deprecation(
172-
f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
173-
f" be removed in v1.7.0. Please use `{loader_name}()` directly."
174-
)
175144

176145
# -----------------------------------
177146
# verify model has an eval_dataloader

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def _reset_eval_dataloader(
348348
assert mode.evaluating or mode == RunningStage.PREDICTING
349349

350350
# always get the loaders first so we can count how many there are
351-
dataloaders = self._request_dataloader(mode, model=model)
351+
dataloaders = self._request_dataloader(mode)
352352

353353
if self.trainer.overfit_batches > 0:
354354
dataloaders = self._resolve_overfit_batches(dataloaders, mode)
@@ -423,18 +423,14 @@ def _reset_eval_dataloader(
423423

424424
return loader_num_batches, dataloaders
425425

426-
def _request_dataloader(
427-
self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
428-
) -> Union[DataLoader, List[DataLoader]]:
426+
def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
429427
"""Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.
430428
431429
Returns:
432430
The requested dataloader
433431
"""
434432
source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")
435433

436-
hook = f"{stage.dataloader_prefix}_dataloader"
437-
self.trainer._call_lightning_module_hook("on_" + hook, pl_module=model)
438434
with _replace_dataloader_init_method():
439435
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
440436
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning

pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,8 @@ class _LogOptions(TypedDict):
162162
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
163163
),
164164
"configure_optimizers": None,
165-
"on_train_dataloader": None,
166165
"train_dataloader": None,
167-
"on_val_dataloader": None,
168166
"val_dataloader": None,
169-
"on_test_dataloader": None,
170167
"test_dataloader": None,
171168
"prepare_data": None,
172169
"configure_callbacks": None,

pytorch_lightning/trainer/data_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,4 @@ def request_dataloader(
5959
rank_zero_deprecation(
6060
"`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8."
6161
)
62-
return self._data_connector._request_dataloader(stage, model)
62+
return self._data_connector._request_dataloader(stage)

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,13 +1805,13 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
18051805
model: The ``LightningModule`` if calling this outside of the trainer scope.
18061806
"""
18071807
source = self._data_connector._train_dataloader_source
1808-
pl_module = self.lightning_module or model
1808+
pl_module = model or self.lightning_module
18091809
has_step = is_overridden("training_step", pl_module)
18101810
enable_training = self.limit_train_batches > 0
18111811
if not (source.is_defined() and has_step and enable_training):
18121812
return
18131813

1814-
self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
1814+
self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING)
18151815

18161816
if self.overfit_batches > 0:
18171817
self.train_dataloader = self._data_connector._resolve_overfit_batches(

tests/deprecated_api/test_remove_1-7.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -38,52 +38,6 @@
3838
from tests.plugins.environments.test_lsf_environment import _make_rankfile
3939

4040

41-
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
42-
class CustomBoringModel(BoringModel):
43-
def on_train_dataloader(self):
44-
print("on_train_dataloader")
45-
46-
def on_val_dataloader(self):
47-
print("on_val_dataloader")
48-
49-
def on_test_dataloader(self):
50-
print("on_test_dataloader")
51-
52-
def on_predict_dataloader(self):
53-
print("on_predict_dataloader")
54-
55-
def _run(model, task="fit"):
56-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
57-
getattr(trainer, task)(model)
58-
59-
model = CustomBoringModel()
60-
61-
with pytest.deprecated_call(
62-
match="Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
63-
):
64-
_run(model, "fit")
65-
66-
with pytest.deprecated_call(
67-
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
68-
):
69-
_run(model, "fit")
70-
71-
with pytest.deprecated_call(
72-
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
73-
):
74-
_run(model, "validate")
75-
76-
with pytest.deprecated_call(
77-
match="Method `on_test_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
78-
):
79-
_run(model, "test")
80-
81-
with pytest.deprecated_call(
82-
match="Method `on_predict_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
83-
):
84-
_run(model, "predict")
85-
86-
8741
def test_v1_7_0_on_interrupt(tmpdir):
8842
class HandleInterruptCallback(Callback):
8943
def on_keyboard_interrupt(self, trainer, pl_module):

0 commit comments

Comments
 (0)