Skip to content

Commit dc3391b

Browse files
kaushikb11rohitgr7
andauthored
Remove deprecation warnings being called for on_{task}_dataloader (#9279)
* Avoid deprecation warnings being called when hooks are not implemented * Update tests & changelog * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]>
1 parent 912fd31 commit dc3391b

File tree

4 files changed

+34
-50
lines changed

4 files changed

+34
-50
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
268268
- Removed deprecated properties `DeepSpeedPlugin.cpu_offload*` in favor of `offload_optimizer`, `offload_parameters` and `pin_memory` ([#9244](https://github.com/PyTorchLightning/pytorch-lightning/pull/9244))
269269

270270

271+
- Removed deprecation warnings being called for `on_{task}_dataloader` ([#9279](https://github.com/PyTorchLightning/pytorch-lightning/pull/9279))
272+
273+
271274
### Fixed
272275

273276
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (

pytorch_lightning/core/hooks.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from pytorch_lightning.utilities import move_data_to_device
2222
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
23-
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
2423

2524

2625
class ModelHooks:
@@ -691,10 +690,6 @@ def on_train_dataloader(self) -> None:
691690
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
692691
Please use :meth:`train_dataloader()` directly.
693692
"""
694-
rank_zero_deprecation(
695-
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
696-
" Please use `train_dataloader()` directly."
697-
)
698693

699694
def on_val_dataloader(self) -> None:
700695
"""Called before requesting the val dataloader.
@@ -703,10 +698,6 @@ def on_val_dataloader(self) -> None:
703698
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
704699
Please use :meth:`val_dataloader()` directly.
705700
"""
706-
rank_zero_deprecation(
707-
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
708-
" Please use `val_dataloader()` directly."
709-
)
710701

711702
def on_test_dataloader(self) -> None:
712703
"""Called before requesting the test dataloader.
@@ -715,10 +706,6 @@ def on_test_dataloader(self) -> None:
715706
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
716707
Please use :meth:`test_dataloader()` directly.
717708
"""
718-
rank_zero_deprecation(
719-
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
720-
" Please use `test_dataloader()` directly."
721-
)
722709

723710
def on_predict_dataloader(self) -> None:
724711
"""Called before requesting the predict dataloader.
@@ -727,10 +714,6 @@ def on_predict_dataloader(self) -> None:
727714
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
728715
Please use :meth:`predict_dataloader()` directly.
729716
"""
730-
rank_zero_deprecation(
731-
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
732-
" Please use `predict_dataloader()` directly."
733-
)
734717

735718
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
736719
"""

tests/deprecated_api/test_remove_1-7.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,25 +91,50 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
9191
_ = Trainer(prepare_data_per_node=False)
9292

9393

94-
def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):
94+
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
95+
class CustomBoringModel(BoringModel):
96+
def on_train_dataloader(self):
97+
print("on_train_dataloader")
98+
99+
def on_val_dataloader(self):
100+
print("on_val_dataloader")
101+
102+
def on_test_dataloader(self):
103+
print("on_test_dataloader")
104+
105+
def on_predict_dataloader(self):
106+
print("on_predict_dataloader")
107+
108+
def _run(model, task="fit"):
109+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
110+
getattr(trainer, task)(model)
111+
112+
model = CustomBoringModel()
95113

96-
model = BoringModel()
97114
with pytest.deprecated_call(
98115
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
99116
):
100-
model.on_train_dataloader()
117+
_run(model, "fit")
118+
101119
with pytest.deprecated_call(
102120
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
103121
):
104-
model.on_val_dataloader()
122+
_run(model, "fit")
123+
124+
with pytest.deprecated_call(
125+
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
126+
):
127+
_run(model, "validate")
128+
105129
with pytest.deprecated_call(
106130
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
107131
):
108-
model.on_test_dataloader()
132+
_run(model, "test")
133+
109134
with pytest.deprecated_call(
110135
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
111136
):
112-
model.on_predict_dataloader()
137+
_run(model, "predict")
113138

114139

115140
@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")

tests/trainer/test_trainer.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1953,30 +1953,3 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
19531953
) as exception_hook:
19541954
trainer.predict(model, model.val_dataloader(), return_predictions=False)
19551955
exception_hook.assert_called()
1956-
1957-
1958-
def test_overridden_on_dataloaders(tmpdir):
1959-
model = BoringModel()
1960-
with pytest.deprecated_call(
1961-
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1962-
):
1963-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1964-
trainer.fit(model)
1965-
1966-
with pytest.deprecated_call(
1967-
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1968-
):
1969-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1970-
trainer.validate(model)
1971-
1972-
with pytest.deprecated_call(
1973-
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1974-
):
1975-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1976-
trainer.test(model)
1977-
1978-
with pytest.deprecated_call(
1979-
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
1980-
):
1981-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
1982-
trainer.predict(model)

0 commit comments

Comments
 (0)