From 6bf05302cc8f0a1ba628d83fa873dc11cdf98f17 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 16:35:46 -0800
Subject: [PATCH 01/10] Deprecate `LightningModule.on_pretrain_routine_{start/end}`

---
 CHANGELOG.md                             |  3 ++
 docs/source/common/lightning_module.rst  |  3 --
 pytorch_lightning/core/hooks.py          |  8 +++++
 .../trainer/configuration_validator.py   | 12 +++++++
 tests/deprecated_api/test_remove_1-8.py  | 31 +++++++++++++++++++
 5 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2cd91612a6aea..7b80d609b1324 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -442,6 +442,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794))
 
 
+- Deprecated `LightningModule.on_pretrain_routine_start` and `LightningModule.on_pretrain_routine_end` hooks in favor of `on_fit_start` ([#]())
+
+
 - Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))
 
 
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 136e3e98164ab..cb0f625b4f5e3 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1225,9 +1225,6 @@ for more information.
     setup("fit")
     configure_optimizers()
 
-    on_pretrain_routine_start()
-    on_pretrain_routine_end()
-
     # the sanity check runs here
     on_train_start()
 
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index d321b151a27dc..1f7f5a82a9b86 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -68,6 +68,10 @@ def on_pretrain_routine_start(self) -> None:
         - pretrain_routine start
         - pretrain_routine end
         - training_start
+
+        .. deprecated:: v1.6
+            :meth:`on_pretrain_routine_start` has been deprecated in v1.6 and will be removed in v1.8.
+            Use ``on_fit_start`` instead.
         """
 
     def on_pretrain_routine_end(self) -> None:
@@ -77,6 +81,10 @@ def on_pretrain_routine_end(self) -> None:
         - pretrain_routine start
         - pretrain_routine end
         - training_start
+
+        .. deprecated:: v1.6
+            :meth:`on_pretrain_routine_end` has been deprecated in v1.6 and will be removed in v1.8.
+            Use ``on_fit_start`` instead.
         """
 
     def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]:
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 2e98ab76b3ac5..1285596301383 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -60,6 +60,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_on_epoch_start_end(model)
     # TODO: Delete CheckpointHooks off PrecisionPlugin in v1.8
     _check_precision_plugin_checkpoint_hooks(trainer)
+    # TODO: Delete on_pretrain_routine_start/end hooks in v1.8
+    _check_on_pretrain_routine(model)
 
 
 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -306,6 +308,16 @@ def _check_on_epoch_start_end(model: "pl.LightningModule") -> None:
         )
 
 
+def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:
+    hooks = (("on_pretrain_routine_start", "on_fit_start"), ("on_pretrain_routine_end", "on_fit_start"))
+    for hook, alternative_hook in hooks:
+        if is_overridden(hook, model):
+            rank_zero_deprecation(
+                f"The `LightningModule.{hook}` hook was deprecated in v1.6 and"
+                f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead."
+            )
+
+
 def _check_dl_idx_in_on_train_batch_hooks(model: "pl.LightningModule") -> None:
     for hook in ("on_train_batch_start", "on_train_batch_end"):
         if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index 06d901bb48cff..bf24f81b8edc4 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -451,6 +451,37 @@ def on_epoch_end(self, *args, **kwargs):
         trainer.fit(model)
 
 
+def test_v1_8_0_remove_on_pretrain_routine_start_end_lightning_module(tmpdir):
+    class CustomModel(BoringModel):
+        def on_pretrain_routine_start(self, *args, **kwargs):
+            print("foo")
+
+    model = CustomModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+    )
+    with pytest.deprecated_call(
+        match="The `LightningModule.on_pretrain_routine_start` hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+
+    class CustomModel(BoringModel):
+        def on_pretrain_routine_end(self, *args, **kwargs):
+            print("foo")
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+    )
+
+    model = CustomModel()
+    with pytest.deprecated_call(
+        match="The `LightningModule.on_pretrain_routine_end` hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+
+
 def test_v1_8_0_rank_zero_imports():
     import warnings
 

From 0f891284b0308eddaba75a6580811bdcda8deba0 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 16:37:07 -0800
Subject: [PATCH 02/10] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7b80d609b1324..225fced44ec4c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -442,7 +442,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794))
 
 
-- Deprecated `LightningModule.on_pretrain_routine_start` and `LightningModule.on_pretrain_routine_end` hooks in favor of `on_fit_start` ([#]())
+- Deprecated `LightningModule.on_pretrain_routine_start` and `LightningModule.on_pretrain_routine_end` hooks in favor of `on_fit_start` ([#12122](https://github.com/PyTorchLightning/pytorch-lightning/pull/12122))
 
 
 - Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))

From b037339125d620fc8e399ebfcb7678028dfdc675 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 16:44:46 -0800
Subject: [PATCH 03/10] updates

---
 docs/source/common/lightning_module.rst | 12 ------------
 tests/trainer/test_data_loading.py      |  2 +-
 2 files changed, 1 insertion(+), 13 deletions(-)

diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index cb0f625b4f5e3..dc10f235ceb39 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1388,18 +1388,6 @@ on_validation_end
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end
     :noindex:
 
-on_pretrain_routine_start
-~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_start
-    :noindex:
-
-on_pretrain_routine_end
-~~~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_end
-    :noindex:
-
 on_test_batch_start
 ~~~~~~~~~~~~~~~~~~~
 
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index edd63057620ef..df0a668e4e8c0 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -104,7 +104,7 @@ def __init__(self, num_workers):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
 
-    def on_pretrain_routine_start(self):
+    def on_fit_start(self):
         self._resout = StringIO()
         self.ctx = redirect_stderr(self._resout)
         self.ctx.__enter__()

From 90af533cd13477184085bb0be5fbb3a36716fff9 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 16:45:47 -0800
Subject: [PATCH 04/10] Update test_data_loading.py

---
 tests/trainer/test_data_loading.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index df0a668e4e8c0..506a037ddf631 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -104,7 +104,7 @@ def __init__(self, num_workers):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
 
-    def on_fit_start(self):
+    def setup(self, stage=None):
         self._resout = StringIO()
         self.ctx = redirect_stderr(self._resout)
         self.ctx.__enter__()

From 67ee939e16fb70fed3abdfaaf34ff4b171211bff Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 17:00:40 -0800
Subject: [PATCH 05/10] Update test_restore.py

---
 tests/models/test_restore.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 53838691a2efb..b1963a03c6036 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -253,7 +253,7 @@ def test_correct_step_and_epoch(tmpdir):
     assert trainer.global_step == 0
 
     class TestModel(BoringModel):
-        def on_pretrain_routine_end(self) -> None:
+        def on_fit_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             # TODO(@carmocca): should not need `+1`
             assert self.trainer.global_step == first_max_epochs * train_batches + 1

From aa94db120441d72a992d48036f967173b80d7f65 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 17:16:05 -0800
Subject: [PATCH 06/10] Update test_restore.py

---
 tests/models/test_restore.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index b1963a03c6036..303e91ade19aa 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -253,7 +253,7 @@ def test_correct_step_and_epoch(tmpdir):
     assert trainer.global_step == 0
 
     class TestModel(BoringModel):
-        def on_fit_start(self) -> None:
+        def on_train_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             # TODO(@carmocca): should not need `+1`
             assert self.trainer.global_step == first_max_epochs * train_batches + 1

From 81433a373b60a5cca10ccbfedbe74ecc68e1e8d8 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 25 Feb 2022 18:13:13 -0800
Subject: [PATCH 07/10] Update test_restore.py

---
 tests/models/test_restore.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 303e91ade19aa..df9c9acecf7e7 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -610,10 +610,10 @@ def test_dp_resume(tmpdir):
     class CustomModel(CustomClassificationModelDP):
         def __init__(self):
             super().__init__()
-            self.on_pretrain_routine_end_called = False
+            self.on_train_start_called = False
 
         # set the epoch start hook so we can predict before the model does the full training
-        def on_pretrain_routine_end(self):
+        def on_train_start(self):
             assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
 
             # if model and state loaded correctly, predictions will be good even though we
@@ -622,14 +622,14 @@ def on_pretrain_routine_end(self):
 
             dataloader = dm.train_dataloader()
             tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
-            self.on_pretrain_routine_end_called = True
+            self.on_train_start_called = True
 
     # new model
     model = CustomModel()
 
     # fit new model which should load hpc weights
     new_trainer.fit(model, datamodule=dm)
-    assert model.on_pretrain_routine_end_called
+    assert model.on_train_start_called
 
     # test freeze on gpu
     model.freeze()

From 3045f7b6d861fbbefadf7cb0edd83e6466afcc56 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 28 Feb 2022 23:59:20 -0800
Subject: [PATCH 08/10] Update test_data_loading.py

---
 tests/trainer/test_data_loading.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index 506a037ddf631..df0a668e4e8c0 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -104,7 +104,7 @@ def __init__(self, num_workers):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
 
-    def setup(self, stage=None):
+    def on_fit_start(self):
        self._resout = StringIO()
         self.ctx = redirect_stderr(self._resout)
         self.ctx.__enter__()

From ce88d4e98356858c74cf465371d4869db20dbb28 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 4 Mar 2022 21:15:31 -0800
Subject: [PATCH 09/10] Update test_restore.py

---
 tests/models/test_restore.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index df9c9acecf7e7..f02c44f71c07c 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -28,7 +28,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
@@ -618,10 +618,10 @@ def on_train_start(self):
 
             # if model and state loaded correctly, predictions will be good even though we
             # haven't trained with the new loaded model
-            new_trainer.state.stage = RunningStage.VALIDATING
+            # new_trainer.state.stage = RunningStage.VALIDATING
-
-            dataloader = dm.train_dataloader()
-            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
+
+            # dataloader = dm.train_dataloader()
+            # tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
             self.on_train_start_called = True
 
     # new model

From 921a87080c3d5a3f2cb85f602013a8e9f8f05c64 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 4 Mar 2022 21:38:43 -0800
Subject: [PATCH 10/10] Update test_restore.py

---
 tests/models/test_restore.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index f02c44f71c07c..c04e36bbc09bd 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -612,24 +612,16 @@ def __init__(self):
             super().__init__()
             self.on_train_start_called = False
 
-        # set the epoch start hook so we can predict before the model does the full training
-        def on_train_start(self):
+        def on_validation_start(self):
             assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
-
-            # if model and state loaded correctly, predictions will be good even though we
-            # haven't trained with the new loaded model
-            # new_trainer.state.stage = RunningStage.VALIDATING
-
-            # dataloader = dm.train_dataloader()
-            # tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
-            self.on_train_start_called = True
+            dataloader = dm.val_dataloader()
+            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
 
     # new model
     model = CustomModel()
 
-    # fit new model which should load hpc weights
-    new_trainer.fit(model, datamodule=dm)
-    assert model.on_train_start_called
+    # validate new model which should load hpc weights
+    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)
 
     # test freeze on gpu
     model.freeze()
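
Migration note for downstream users of the deprecated hooks: the change implied by this series is mechanical — logic that previously ran in `on_pretrain_routine_start` / `on_pretrain_routine_end` moves into `on_fit_start` (or into `setup("fit")`, as patches 03/04/08 toggle between for the data-loading test). The following is only a minimal sketch of that migration, not part of the patch series; `MigratedModel` and `_prepare_for_fit` are hypothetical names standing in for user code.

import torch
import pytorch_lightning as pl


class MigratedModel(pl.LightningModule):
    """Sketch: moving setup logic off the deprecated pretrain-routine hooks."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    # Before (deprecated in v1.6, removed in v1.8):
    # def on_pretrain_routine_start(self) -> None:
    #     self._prepare_for_fit()

    # After: `on_fit_start` fires once at the start of `trainer.fit()`,
    # which is where the pretrain-routine hooks used to run.
    def on_fit_start(self) -> None:
        self._prepare_for_fit()

    def _prepare_for_fit(self) -> None:
        # hypothetical user setup code previously done in the pretrain routine
        self.register_buffer("fit_started", torch.ones(1))

`setup("fit")` is the other landing spot used in the series; per the pseudocode touched by patch 01's docs change, it runs before `configure_optimizers()`, which matters when the moved logic must happen before optimizers or dataloaders are set up.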