Skip to content
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))

- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)

-

### Removed

Expand Down
47 changes: 42 additions & 5 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn


class ModelHooks:
Expand Down Expand Up @@ -681,16 +682,52 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
"""

def on_train_dataloader(self) -> None:
"""Called before requesting the train dataloader."""
"""Called before requesting the train dataloader.

.. deprecated:: v1.5
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`train_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `train_dataloader()` directly."
)

def on_val_dataloader(self) -> None:
"""Called before requesting the val dataloader."""
"""Called before requesting the val dataloader.

.. deprecated:: v1.5
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`val_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)

def on_test_dataloader(self) -> None:
"""Called before requesting the test dataloader."""
"""Called before requesting the test dataloader.

.. deprecated:: v1.5
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`test_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `test_dataloader()` directly."
)

def on_predict_dataloader(self) -> None:
"""Called before requesting the predict dataloader."""
"""Called before requesting the predict dataloader.

.. deprecated:: v1.5
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
Please use :meth:`predict_dataloader()` directly.
"""
rank_zero_deprecation(
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `predict_dataloader()` directly."
)

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
"""
Expand Down
50 changes: 49 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
Expand Down Expand Up @@ -75,6 +75,25 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
)

# ----------------------------------------------
# verify model does not have
# - on_train_dataloader
# - on_val_dataloader
# ----------------------------------------------
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
if has_on_train_dataloader:
rank_zero_deprecation(
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `train_dataloader()` directly."
)

has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)

trainer = self.trainer

trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
Expand Down Expand Up @@ -102,10 +121,39 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s
if has_step and not has_loader:
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")

# ----------------------------------------------
# verify model does not have
# - on_val_dataloader
# - on_test_dataloader
# ----------------------------------------------
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
if has_on_val_dataloader:
rank_zero_deprecation(
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `val_dataloader()` directly."
)

has_on_test_dataloader = is_overridden("on_test_dataloader", model)
if has_on_test_dataloader:
rank_zero_deprecation(
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `test_dataloader()` directly."
)

def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
has_predict_dataloader = is_overridden("predict_dataloader", model)
if not has_predict_dataloader:
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
# ----------------------------------------------
# verify model does not have
# - on_predict_dataloader
# ----------------------------------------------
has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
if has_on_predict_dataloader:
rank_zero_deprecation(
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
" Please use `predict_dataloader()` directly."
)

def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None:
"""Raise Misconfiguration exception since these hooks are not supported in DP mode"""
Expand Down
21 changes: 21 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
_ = Trainer(prepare_data_per_node=False)


def test_v1_7_0_deprecated_on_train_dataloader(tmpdir):

model = BoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_train_dataloader()
with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_val_dataloader()
with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_test_dataloader()
with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
model.on_predict_dataloader()


@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
def test_v1_7_0_test_tube_logger(_, tmpdir):
with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
Expand Down
27 changes: 27 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,3 +1868,30 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
trainer.test(model)
with pytest.raises(Exception, match=r"Error during predict"), patch("pytorch_lightning.Trainer._on_exception"):
trainer.predict(model, model.val_dataloader(), return_predictions=False)


def test_overridden_on_dataloaders(tmpdir):
model = BoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)

with pytest.deprecated_call(
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.validate(model)

with pytest.deprecated_call(
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.test(model)

with pytest.deprecated_call(
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.predict(model)