diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cd91612a6aea..225fced44ec4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -442,6 +442,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794)) +- Deprecated `LightningModule.on_pretrain_routine_start` and `LightningModule.on_pretrain_routine_end` hooks in favor of `on_fit_start` ([#12122](https://github.com/PyTorchLightning/pytorch-lightning/pull/12122)) + + - Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 136e3e98164ab..dc10f235ceb39 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1225,9 +1225,6 @@ for more information. setup("fit") configure_optimizers() - on_pretrain_routine_start() - on_pretrain_routine_end() - # the sanity check runs here on_train_start() @@ -1391,18 +1388,6 @@ on_validation_end .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end :noindex: -on_pretrain_routine_start -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_start - :noindex: - -on_pretrain_routine_end -~~~~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_end - :noindex: - on_test_batch_start ~~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index d321b151a27dc..1f7f5a82a9b86 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -68,6 +68,10 @@ def on_pretrain_routine_start(self) -> None: - pretrain_routine start - pretrain_routine end - training_start + + .. deprecated:: v1.6 + :meth:`on_pretrain_routine_start` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on_fit_start`` instead. """ def on_pretrain_routine_end(self) -> None: @@ -77,6 +81,10 @@ def on_pretrain_routine_end(self) -> None: - pretrain_routine start - pretrain_routine end - training_start + + .. deprecated:: v1.6 + :meth:`on_pretrain_routine_end` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on_fit_start`` instead. """ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 2e98ab76b3ac5..1285596301383 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -60,6 +60,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: _check_on_epoch_start_end(model) # TODO: Delete CheckpointHooks off PrecisionPlugin in v1.8 _check_precision_plugin_checkpoint_hooks(trainer) + # TODO: Delete on_pretrain_routine_start/end hooks in v1.8 + _check_on_pretrain_routine(model) def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: @@ -306,6 +308,16 @@ def _check_on_epoch_start_end(model: "pl.LightningModule") -> None: ) +def _check_on_pretrain_routine(model: "pl.LightningModule") -> None: + hooks = (("on_pretrain_routine_start", "on_fit_start"), ("on_pretrain_routine_end", "on_fit_start")) + for hook, alternative_hook in hooks: + if is_overridden(hook, model): + rank_zero_deprecation( + f"The `LightningModule.{hook}` hook was deprecated in v1.6 and" + f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead." + ) + + def _check_dl_idx_in_on_train_batch_hooks(model: "pl.LightningModule") -> None: for hook in ("on_train_batch_start", "on_train_batch_end"): if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True): diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 06d901bb48cff..bf24f81b8edc4 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -451,6 +451,37 @@ def on_epoch_end(self, *args, **kwargs): trainer.fit(model) +def test_v1_8_0_remove_on_pretrain_routine_start_end_lightning_module(tmpdir): + class CustomModel(BoringModel): + def on_pretrain_routine_start(self, *args, **kwargs): + print("foo") + + model = CustomModel() + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `LightningModule.on_pretrain_routine_start` hook was deprecated in v1.6 and will be removed in v1.8" + ): + trainer.fit(model) + + class CustomModel(BoringModel): + def on_pretrain_routine_end(self, *args, **kwargs): + print("foo") + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + ) + + model = CustomModel() + with pytest.deprecated_call( + match="The `LightningModule.on_pretrain_routine_end` hook was deprecated in v1.6 and will be removed in v1.8" + ): + trainer.fit(model) + + def test_v1_8_0_rank_zero_imports(): import warnings diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 53838691a2efb..c04e36bbc09bd 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -28,7 +28,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.states import RunningStage, TrainerFn +from pytorch_lightning.trainer.states import TrainerFn from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf @@ -253,7 +253,7 @@ def test_correct_step_and_epoch(tmpdir): assert trainer.global_step == 0 class TestModel(BoringModel): - def on_pretrain_routine_end(self) -> None: + def on_train_start(self) -> None: assert self.trainer.current_epoch == first_max_epochs # TODO(@carmocca): should not need `+1` assert self.trainer.global_step == first_max_epochs * train_batches + 1 @@ -610,26 +610,18 @@ def test_dp_resume(tmpdir): class CustomModel(CustomClassificationModelDP): def __init__(self): super().__init__() - self.on_pretrain_routine_end_called = False + self.on_train_start_called = False - # set the epoch start hook so we can predict before the model does the full training - def on_pretrain_routine_end(self): + def on_validation_start(self): assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0 - - # if model and state loaded correctly, predictions will be good even though we - # haven't trained with the new loaded model - new_trainer.state.stage = RunningStage.VALIDATING - - dataloader = dm.train_dataloader() + dataloader = dm.val_dataloader() tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader) - self.on_pretrain_routine_end_called = True # new model model = CustomModel() - # fit new model which should load hpc weights - new_trainer.fit(model, datamodule=dm) - assert model.on_pretrain_routine_end_called + # validate new model which should load hpc weights + new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path) # test freeze on gpu model.freeze() diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index edd63057620ef..df0a668e4e8c0 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -104,7 +104,7 @@ def __init__(self, num_workers): def train_dataloader(self): return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers) - def on_pretrain_routine_start(self): + def on_fit_start(self): self._resout = StringIO() self.ctx = redirect_stderr(self._resout) self.ctx.__enter__()