From 32d7c2bdba254d7b6fb0bff5c7f381ab6a699d6c Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Mon, 7 Feb 2022 16:49:17 +0100 Subject: [PATCH 01/36] init commit --- pytorch_lightning/callbacks/base.py | 16 ++++++++++++++-- pytorch_lightning/callbacks/lambda_function.py | 2 -- pytorch_lightning/callbacks/model_checkpoint.py | 11 +++++------ pytorch_lightning/callbacks/model_summary.py | 2 +- pytorch_lightning/core/hooks.py | 14 ++++++-------- pytorch_lightning/trainer/trainer.py | 9 --------- 6 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 993073c54a438..228f16c220e98 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -231,10 +231,22 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - """Called when the train ends.""" def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when the pretrain routine begins.""" + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_fit_start`` or ``setup`` instead. + + Called when the pretrain routine begins. + """ def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when the pretrain routine ends.""" + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_fit_start`` or ``setup`` instead. + + Called when the pretrain routine ends. 
+ """ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when the validation loop begins.""" diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 1813e7d19090f..f38a1fc78048c 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -68,8 +68,6 @@ def __init__( on_batch_end: Optional[Callable] = None, on_train_start: Optional[Callable] = None, on_train_end: Optional[Callable] = None, - on_pretrain_routine_start: Optional[Callable] = None, - on_pretrain_routine_end: Optional[Callable] = None, on_validation_start: Optional[Callable] = None, on_validation_end: Optional[Callable] = None, on_test_start: Optional[Callable] = None, diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 75b1adb10c39b..5a18b850fb4f2 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -248,6 +248,11 @@ def state_key(self) -> str: ) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + # When pretrain routine starts we build the ckpt dir on the fly. + self.__resolve_ckpt_dir(trainer) + if trainer.is_global_zero: + self.__warn_if_dir_not_empty(self.dirpath) + # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states, # because the attributes are part of the state_key which needs to be fully defined before reloading. 
if self._save_on_train_epoch_end is None: @@ -255,12 +260,6 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O # validation, then we run after validation instead of on train epoch end self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """When pretrain routine starts we build the ckpt dir on the fly.""" - self.__resolve_ckpt_dir(trainer) - if trainer.is_global_zero: - self.__warn_if_dir_not_empty(self.dirpath) - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._last_time_checked = time.monotonic() diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index d921c11943acc..78739cb4714a8 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not self._max_depth: return None diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index c5aafdecd146f..3b89915aed9bd 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -64,19 +64,17 @@ def on_predict_end(self) -> None: def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). - - fit - - pretrain_routine start - - pretrain_routine end - - training_start + .. deprecated:: v1.6 + :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + Please use :meth:`on_fit_start` or :meth:`setup` directly. 
""" def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). - - fit - - pretrain_routine start - - pretrain_routine end - - training_start + .. deprecated:: v1.6 + :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + Please use :meth:`on_fit_start` or :meth:`setup` directly. """ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f0b56e35e1bf1..fa3c29f84ff07 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1281,15 +1281,6 @@ def _pre_training_routine(self): # register signals self._signal_connector.register_signal_handlers() - # -------------------------- - # Pre-train - # -------------------------- - self._call_callback_hooks("on_pretrain_routine_start") - self._call_lightning_module_hook("on_pretrain_routine_start") - - self._call_callback_hooks("on_pretrain_routine_end") - self._call_lightning_module_hook("on_pretrain_routine_end") - def _run_train(self) -> None: self._pre_training_routine() From 5d33ba265464d753522aa56d3386d0c9f1e7b884 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Mon, 7 Feb 2022 17:02:50 +0100 Subject: [PATCH 02/36] feedback based changes --- pytorch_lightning/callbacks/lambda_function.py | 2 ++ pytorch_lightning/callbacks/model_summary.py | 2 +- pytorch_lightning/core/hooks.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 9 +++++++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index f38a1fc78048c..1813e7d19090f 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -68,6 +68,8 @@ def __init__( on_batch_end: Optional[Callable] = None, on_train_start: Optional[Callable] = None, on_train_end: 
Optional[Callable] = None, + on_pretrain_routine_start: Optional[Callable] = None, + on_pretrain_routine_end: Optional[Callable] = None, on_validation_start: Optional[Callable] = None, on_validation_end: Optional[Callable] = None, on_test_start: Optional[Callable] = None, diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 78739cb4714a8..921c6bb1ae93a 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: if not self._max_depth: return None diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 3b89915aed9bd..f3fba7d5b686e 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -65,7 +65,7 @@ def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). .. deprecated:: v1.6 - :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + :meth:`on_pretrain_routine_start` is deprecated and will be removed in v1.8.0. Please use :meth:`on_fit_start` or :meth:`setup` directly. """ @@ -73,7 +73,7 @@ def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). .. deprecated:: v1.6 - :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + :meth:`on_pretrain_routine_end` is deprecated and will be removed in v1.8.0. Please use :meth:`on_fit_start` or :meth:`setup` directly. 
""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fa3c29f84ff07..f0b56e35e1bf1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1281,6 +1281,15 @@ def _pre_training_routine(self): # register signals self._signal_connector.register_signal_handlers() + # -------------------------- + # Pre-train + # -------------------------- + self._call_callback_hooks("on_pretrain_routine_start") + self._call_lightning_module_hook("on_pretrain_routine_start") + + self._call_callback_hooks("on_pretrain_routine_end") + self._call_lightning_module_hook("on_pretrain_routine_end") + def _run_train(self) -> None: self._pre_training_routine() From 1142c924000dc9114432604ca2df7153322c4a7c Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 13:50:06 +0100 Subject: [PATCH 03/36] adress PR comments --- pytorch_lightning/callbacks/model_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 921c6bb1ae93a..7901ae6afd7db 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: if not self._max_depth: return None From d3d743e6e052aaafd315d835c3c496ba5c49266d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Feb 2022 12:51:27 +0000 Subject: [PATCH 04/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
pytorch_lightning/callbacks/model_summary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 7901ae6afd7db..d6272e28ccdd2 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,9 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def on_train_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None + ) -> None: if not self._max_depth: return None From 9014b113c1488accefd123290216bf6fdb7502ce Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 14:47:28 +0100 Subject: [PATCH 05/36] trainer changes --- pytorch_lightning/trainer/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f0b56e35e1bf1..41200e2b3afd9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1596,6 +1596,12 @@ def _call_callback_hooks( with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"): fn(self, self.lightning_module, *args, **kwargs) + # TODO: remove if block in v1.8 + if hook_name == "on_pretrain_routine_start": + self.on_train_start(*args, **kwargs) + if hook_name == "on_pretrain_routine_end": + self.on_train_end(*args, **kwargs) + if pl_module: # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name From f0282dbd6139cc562047ac09d73b39f41a85b3e5 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 14:55:11 +0100 Subject: [PATCH 06/36] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
bb46b7039d87c..c5f5c4791cadf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -343,6 +343,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning` +- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hook ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794)) ### Removed From efa6ef6e337d71db6860cac894017103b3a622de Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:05:23 +0100 Subject: [PATCH 07/36] validations --- pytorch_lightning/callbacks/base.py | 16 ++++++++++++++++ .../trainer/configuration_validator.py | 10 ++++++++++ 2 files changed, 26 insertions(+) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 228f16c220e98..d932f0d05b8ce 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -89,6 +89,22 @@ def on_init_end(self, trainer: "pl.Trainer") -> None: Called when the trainer initialization ends, model has not yet been set. """ + def on_pretrain_routine_start(self, trainer: "pl.Trainer") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. + + Called when the pretrain routine begins. + """ + + def on_pretrain_routine_end(self, trainer: "pl.Trainer") -> None: + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. + + Called when the pretrain routine ends. 
+ """ + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when fit begins.""" diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index cd8754b938e24..c4b5dbe2f44c5 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -348,3 +348,13 @@ def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None: "The `on_configure_sharded_model` callback hook was deprecated in" " v1.6 and will be removed in v1.8. Use `setup()` instead." ) + +def _check_on_pretrain_routine_start_end(trainer: "pl.Trainer") -> None: + hooks = (["on_pretrain_routine_start", "on_train_start"], ["on_pretrain_routine_end", "on_train_end"]) + + for hook, alternative_hook in hooks: + for callback in trainer.callbacks: + if is_overridden(method_name=hook, instance=callback): + rank_zero_deprecation( + f"The `Callback.{hook}` hook was deprecated in v1.6 and" + f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." From ad1751bd16df21db658b6752c4ac392ed329feeb Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:21:54 +0100 Subject: [PATCH 08/36] alternative hooks update --- pytorch_lightning/callbacks/base.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index d932f0d05b8ce..56c7187a8b27c 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -89,22 +89,6 @@ def on_init_end(self, trainer: "pl.Trainer") -> None: Called when the trainer initialization ends, model has not yet been set. """ - def on_pretrain_routine_start(self, trainer: "pl.Trainer") -> None: - r""" - .. deprecated:: v1.6 - This callback hook was deprecated in v1.6 and will be removed in v1.8. - - Called when the pretrain routine begins. 
- """ - - def on_pretrain_routine_end(self, trainer: "pl.Trainer") -> None: - r""" - .. deprecated:: v1.6 - This callback hook was deprecated in v1.6 and will be removed in v1.8. - - Called when the pretrain routine ends. - """ - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when fit begins.""" @@ -250,7 +234,7 @@ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn r""" .. deprecated:: v1.6 This callback hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on_fit_start`` or ``setup`` instead. + ``on_train_start``. Called when the pretrain routine begins. """ @@ -259,7 +243,7 @@ def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin r""" .. deprecated:: v1.6 This callback hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on_fit_start`` or ``setup`` instead. + ``on_train_start``. Called when the pretrain routine ends. """ From c1c2ebabea640c5aa47a5d9100e82f39f4e16158 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:41:44 +0100 Subject: [PATCH 09/36] revert changes --- docs/source/common/lightning_module.rst | 15 --------------- pytorch_lightning/trainer/trainer.py | 6 ------ 2 files changed, 21 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index e416d329ec4ce..85e9c26a057b1 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1208,9 +1208,6 @@ for more information. setup("fit") configure_optimizers() - on_pretrain_routine_start() - on_pretrain_routine_end() - # the sanity check runs here on_train_start() @@ -1378,18 +1375,6 @@ on_validation_end .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end :noindex: -on_pretrain_routine_start -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_start - :noindex: - -on_pretrain_routine_end -~~~~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_end - :noindex: - on_test_batch_start ~~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 41200e2b3afd9..f0b56e35e1bf1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1596,12 +1596,6 @@ def _call_callback_hooks( with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"): fn(self, self.lightning_module, *args, **kwargs) - # TODO: remove if block in v1.8 - if hook_name == "on_pretrain_routine_start": - self.on_train_start(*args, **kwargs) - if hook_name == "on_pretrain_routine_end": - self.on_train_end(*args, **kwargs) - if pl_module: # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name From 3e480b7a31133932a356c8b4f7ee0a3424c20148 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:44:03 +0100 Subject: [PATCH 10/36] revert hooks --- pytorch_lightning/core/hooks.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index f3fba7d5b686e..ba77c727e71d2 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -64,17 +64,19 @@ def on_predict_end(self) -> None: def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). - .. deprecated:: v1.6 - :meth:`on_pretrain_routine_start` is deprecated and will be removed in v1.8.0. - Please use :meth:`on_fit_start` or :meth:`setup` directly. 
+ - fit + - pretrain_routine start + - pretrain_routine end + - training_start """ def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). - .. deprecated:: v1.6 - :meth:`on_pretrain_routine_end` is deprecated and will be removed in v1.8.0. - Please use :meth:`on_fit_start` or :meth:`setup` directly. + - fit + - pretrain_routine start + - pretrain_routine end + - training_start """ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: @@ -180,10 +182,20 @@ def on_predict_model_eval(self) -> None: self.trainer.model.eval() def on_epoch_start(self) -> None: - """Called when either of train/val/test epoch begins.""" + r""" + .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on__epoch_start`` instead. + + Called when either of train/val/test epoch begins. + """ def on_epoch_end(self) -> None: - """Called when either of train/val/test epoch ends.""" + r""" + .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on__epoch_end`` instead. + + Called when either of train/val/test epoch ends. + """ def on_train_epoch_start(self) -> None: """Called in the training loop at the very beginning of the epoch.""" @@ -802,4 +814,4 @@ def on_save_checkpoint(self, checkpoint): including amp scaling. There is no need for you to store anything about training. 
- """ + """ \ No newline at end of file From f6c2f050b52b4d02f56e77b03e8e7aec440e9e17 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:46:09 +0100 Subject: [PATCH 11/36] revert hook --- pytorch_lightning/core/hooks.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index ba77c727e71d2..9495185264914 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -63,7 +63,6 @@ def on_predict_end(self) -> None: def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). - - fit - pretrain_routine start - pretrain_routine end @@ -72,7 +71,6 @@ def on_pretrain_routine_start(self) -> None: def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). - - fit - pretrain_routine start - pretrain_routine end @@ -182,20 +180,10 @@ def on_predict_model_eval(self) -> None: self.trainer.model.eval() def on_epoch_start(self) -> None: - r""" - .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on__epoch_start`` instead. - - Called when either of train/val/test epoch begins. - """ + """Called when either of train/val/test epoch begins.""" def on_epoch_end(self) -> None: - r""" - .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on__epoch_end`` instead. - - Called when either of train/val/test epoch ends. - """ + """Called when either of train/val/test epoch ends.""" def on_train_epoch_start(self) -> None: """Called in the training loop at the very beginning of the epoch.""" @@ -814,4 +802,4 @@ def on_save_checkpoint(self, checkpoint): including amp scaling. There is no need for you to store anything about training. 
- """ \ No newline at end of file + """ From 8779cc522b523269e70ad7cdef8a9f941636847f Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:50:10 +0100 Subject: [PATCH 12/36] revert changes for hooks --- docs/source/common/lightning_module.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 85e9c26a057b1..e416d329ec4ce 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1208,6 +1208,9 @@ for more information. setup("fit") configure_optimizers() + on_pretrain_routine_start() + on_pretrain_routine_end() + # the sanity check runs here on_train_start() @@ -1375,6 +1378,18 @@ on_validation_end .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end :noindex: +on_pretrain_routine_start +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_start + :noindex: + +on_pretrain_routine_end +~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_end + :noindex: + on_test_batch_start ~~~~~~~~~~~~~~~~~~~ From e7f3ad5c1a2f5aa44b6d6b4266df3c4b29a6a033 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 15:59:46 +0100 Subject: [PATCH 13/36] remove from logging --- tests/trainer/logging_/test_logger_connector.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 8d6b3551e8579..c6df53550aa20 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -51,8 +51,6 @@ def test_fx_validator(tmpdir): "on_keyboard_interrupt", "on_exception", "on_load_checkpoint", - "on_pretrain_routine_end", - "on_pretrain_routine_start", "on_sanity_check_end", "on_sanity_check_start", "on_save_checkpoint", @@ -94,8 +92,6 @@ def test_fx_validator(tmpdir): "on_keyboard_interrupt", "on_exception", "on_load_checkpoint", - "on_pretrain_routine_end", - "on_pretrain_routine_start", "on_sanity_check_end", "on_sanity_check_start", "on_predict_batch_end", @@ -217,8 +213,6 @@ def test_fx_validator_integration(tmpdir): "on_configure_sharded_model": "You can't", "configure_optimizers": "You can't", "on_fit_start": "You can't", - "on_pretrain_routine_start": "You can't", - "on_pretrain_routine_end": "You can't", "on_train_dataloader": "You can't", "train_dataloader": "You can't", "on_val_dataloader": "You can't", From ca747522e122df51a77a9e3637c9133fc45775cd Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 17:05:26 +0100 Subject: [PATCH 14/36] remove comments --- pytorch_lightning/callbacks/model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5a18b850fb4f2..91d59c2ee968e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -248,7 +248,6 @@ def state_key(self) -> str: ) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: - # When pretrain routine starts we build the ckpt dir on the fly. self.__resolve_ckpt_dir(trainer) if trainer.is_global_zero: self.__warn_if_dir_not_empty(self.dirpath) From 72c5f2aa08196944fee3007ec116588e6423c501 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Mon, 7 Feb 2022 16:49:17 +0100 Subject: [PATCH 15/36] init commit --- pytorch_lightning/callbacks/base.py | 16 ++++++++++++++-- pytorch_lightning/callbacks/lambda_function.py | 2 -- pytorch_lightning/callbacks/model_checkpoint.py | 11 +++++------ pytorch_lightning/callbacks/model_summary.py | 2 +- pytorch_lightning/core/hooks.py | 14 ++++++-------- pytorch_lightning/trainer/trainer.py | 9 --------- 6 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a24fef72e5b36..ba8e14c9140ed 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -248,10 +248,22 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - """Called when the train ends.""" def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when the pretrain routine begins.""" + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_fit_start`` or ``setup`` instead. + + Called when the pretrain routine begins. + """ def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when the pretrain routine ends.""" + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_fit_start`` or ``setup`` instead. + + Called when the pretrain routine ends. 
+ """ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when the validation loop begins.""" diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index 1813e7d19090f..f38a1fc78048c 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -68,8 +68,6 @@ def __init__( on_batch_end: Optional[Callable] = None, on_train_start: Optional[Callable] = None, on_train_end: Optional[Callable] = None, - on_pretrain_routine_start: Optional[Callable] = None, - on_pretrain_routine_end: Optional[Callable] = None, on_validation_start: Optional[Callable] = None, on_validation_end: Optional[Callable] = None, on_test_start: Optional[Callable] = None, diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 278094dc7bff0..8bb6ea0910bef 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -248,6 +248,11 @@ def state_key(self) -> str: ) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + # When pretrain routine starts we build the ckpt dir on the fly. + self.__resolve_ckpt_dir(trainer) + if trainer.is_global_zero: + self.__warn_if_dir_not_empty(self.dirpath) + # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states, # because the attributes are part of the state_key which needs to be fully defined before reloading. 
if self._save_on_train_epoch_end is None: @@ -255,12 +260,6 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O # validation, then we run after validation instead of on train epoch end self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """When pretrain routine starts we build the ckpt dir on the fly.""" - self.__resolve_ckpt_dir(trainer) - if trainer.is_global_zero: - self.__warn_if_dir_not_empty(self.dirpath) - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._last_time_checked = time.monotonic() diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index d921c11943acc..78739cb4714a8 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not self._max_depth: return None diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b5a638d7100ae..862b7e448052a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -64,19 +64,17 @@ def on_predict_end(self) -> None: def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). - - fit - - pretrain_routine start - - pretrain_routine end - - training_start + .. deprecated:: v1.6 + :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + Please use :meth:`on_fit_start` or :meth:`setup` directly. 
""" def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). - - fit - - pretrain_routine start - - pretrain_routine end - - training_start + .. deprecated:: v1.6 + :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + Please use :meth:`on_fit_start` or :meth:`setup` directly. """ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f73e16604dcba..d0b5141184f98 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1280,15 +1280,6 @@ def _pre_training_routine(self): # register signals self._signal_connector.register_signal_handlers() - # -------------------------- - # Pre-train - # -------------------------- - self._call_callback_hooks("on_pretrain_routine_start") - self._call_lightning_module_hook("on_pretrain_routine_start") - - self._call_callback_hooks("on_pretrain_routine_end") - self._call_lightning_module_hook("on_pretrain_routine_end") - def _run_train(self) -> None: self._pre_training_routine() From 9d8ae1425386994c0ded8e8e042863bd4b142c7b Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Mon, 7 Feb 2022 17:02:50 +0100 Subject: [PATCH 16/36] feedback based changes --- pytorch_lightning/callbacks/lambda_function.py | 2 ++ pytorch_lightning/callbacks/model_summary.py | 2 +- pytorch_lightning/core/hooks.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 9 +++++++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index f38a1fc78048c..1813e7d19090f 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -68,6 +68,8 @@ def __init__( on_batch_end: Optional[Callable] = None, on_train_start: Optional[Callable] = None, on_train_end: 
Optional[Callable] = None, + on_pretrain_routine_start: Optional[Callable] = None, + on_pretrain_routine_end: Optional[Callable] = None, on_validation_start: Optional[Callable] = None, on_validation_end: Optional[Callable] = None, on_test_start: Optional[Callable] = None, diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 78739cb4714a8..921c6bb1ae93a 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: if not self._max_depth: return None diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 862b7e448052a..e7976b2fac6c6 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -65,7 +65,7 @@ def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). .. deprecated:: v1.6 - :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + :meth:`on_pretrain_routine_start` is deprecated and will be removed in v1.8.0. Please use :meth:`on_fit_start` or :meth:`setup` directly. """ @@ -73,7 +73,7 @@ def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). .. deprecated:: v1.6 - :meth:`on_val_dataloader` is deprecated and will be removed in v1.8.0. + :meth:`on_pretrain_routine_end` is deprecated and will be removed in v1.8.0. Please use :meth:`on_fit_start` or :meth:`setup` directly. 
""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d0b5141184f98..f73e16604dcba 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1280,6 +1280,15 @@ def _pre_training_routine(self): # register signals self._signal_connector.register_signal_handlers() + # -------------------------- + # Pre-train + # -------------------------- + self._call_callback_hooks("on_pretrain_routine_start") + self._call_lightning_module_hook("on_pretrain_routine_start") + + self._call_callback_hooks("on_pretrain_routine_end") + self._call_lightning_module_hook("on_pretrain_routine_end") + def _run_train(self) -> None: self._pre_training_routine() From 53887ee158cb726d78ac0a992220b9e65dc0c2e4 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 13:50:06 +0100 Subject: [PATCH 17/36] adress PR comments --- pytorch_lightning/callbacks/model_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 921c6bb1ae93a..7901ae6afd7db 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: if not self._max_depth: return None From fe5efa24adf46d8239976a1a414d2a5ff9c88741 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Feb 2022 12:51:27 +0000 Subject: [PATCH 18/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
pytorch_lightning/callbacks/model_summary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 7901ae6afd7db..d6272e28ccdd2 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,7 +49,9 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def on_train_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None + ) -> None: if not self._max_depth: return None From e2172fca56be64fb48d0f51507d94c580228e55f Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 9 Feb 2022 14:47:28 +0100 Subject: [PATCH 19/36] trainer changes --- pytorch_lightning/trainer/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f73e16604dcba..438dcb9e3df08 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1595,6 +1595,12 @@ def _call_callback_hooks( with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"): fn(self, self.lightning_module, *args, **kwargs) + # TODO: remove if block in v1.8 + if hook_name == "on_pretrain_routine_start": + self.on_train_start(*args, **kwargs) + if hook_name == "on_pretrain_routine_end": + self.on_train_end(*args, **kwargs) + if pl_module: # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name From 437710b5cbfdc15ecdc81af7bffb68d8d4be2ce9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Feb 2022 11:29:19 +0000 Subject: [PATCH 20/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- pytorch_lightning/trainer/configuration_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 77a3eb26bfefd..2c4ca3ce97990 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -360,6 +360,7 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." ) + def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None: for callback in trainer.callbacks: if is_overridden(method_name="on_configure_sharded_model", instance=callback): @@ -367,4 +368,3 @@ def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None: "The `on_configure_sharded_model` callback hook was deprecated in" " v1.6 and will be removed in v1.8. Use `setup()` instead." ) - From 53535f851215daa41684a481ae68c967a9f9ce52 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Thu, 10 Feb 2022 12:33:16 +0100 Subject: [PATCH 21/36] rebased validator --- pytorch_lightning/trainer/configuration_validator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 77a3eb26bfefd..ef93a8013a3a5 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -350,6 +350,15 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: f"The `Callback.{hook}` hook was deprecated in v1.6 and" f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." 
) + for hook, alternative_hook in ( + ["on_epoch_start", "on__epoch_start"], + ["on_epoch_end", "on__epoch_end"], + ): + if is_overridden(method_name=hook, instance=callback): + rank_zero_deprecation( + f"The `Callback.{hook}` hook was deprecated in v1.6 and" + f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." + ) for hook, alternative_hook in ( ["on_pretrain_routine_start", "on_fit_start"], ["on_pretrain_routine_end", "on_fit_start"], From e741632e04ead92441711748bde20b1521b76ab5 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Thu, 10 Feb 2022 12:36:17 +0100 Subject: [PATCH 22/36] rebase again --- pytorch_lightning/trainer/configuration_validator.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index c9cb50796323b..3c4c87b6416ab 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -368,12 +368,3 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: f"The `Callback.{hook}` hook was deprecated in v1.6 and" f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." ) - - -def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None: - for callback in trainer.callbacks: - if is_overridden(method_name="on_configure_sharded_model", instance=callback): - rank_zero_deprecation( - "The `on_configure_sharded_model` callback hook was deprecated in" - " v1.6 and will be removed in v1.8. Use `setup()` instead." 
- ) From 0311369e3a6ff3dc14e36cee434caa066d9528a2 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Tue, 15 Feb 2022 07:59:14 +0100 Subject: [PATCH 23/36] fix ci error by importing optional --- pytorch_lightning/callbacks/model_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index d6272e28ccdd2..27389aad68c7e 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -22,7 +22,7 @@ """ import logging -from typing import List, Tuple +from typing import List, Tuple, Optional import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback From d779a2590af895d5bac5eab004af0895273c1d4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Feb 2022 07:00:37 +0000 Subject: [PATCH 24/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/callbacks/model_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index 27389aad68c7e..cfa44c627905e 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -22,7 +22,7 @@ """ import logging -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback From 47b3e825d725fa302a3c7a1dd74883330b7d9768 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Tue, 15 Feb 2022 14:35:24 +0100 Subject: [PATCH 25/36] remove bc breaking changes --- pytorch_lightning/trainer/trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 438dcb9e3df08..f73e16604dcba 
100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1595,12 +1595,6 @@ def _call_callback_hooks( with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"): fn(self, self.lightning_module, *args, **kwargs) - # TODO: remove if block in v1.8 - if hook_name == "on_pretrain_routine_start": - self.on_train_start(*args, **kwargs) - if hook_name == "on_pretrain_routine_end": - self.on_train_end(*args, **kwargs) - if pl_module: # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name From 216ac9506aff1b2052dd31df3d9c2165a260e0ec Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Tue, 15 Feb 2022 18:56:02 +0100 Subject: [PATCH 26/36] changes according to suggestions --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac47d7fadf6e6..5c04ae170ba7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -385,7 +385,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning` -- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hook ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794)) +- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794)) ### Removed diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 3c4c87b6416ab..a5a422e846426 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -360,8 +360,8 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." ) for hook, alternative_hook in ( - ["on_pretrain_routine_start", "on_fit_start"], - ["on_pretrain_routine_end", "on_fit_start"], + ("on_pretrain_routine_start", "on_fit_start"), + ("on_pretrain_routine_end", "on_fit_start"), ): if is_overridden(method_name=hook, instance=callback): rank_zero_deprecation( From 1b2763ab0e8adeecaaf99eb2740ad1b792f99fd4 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 16 Feb 2022 23:02:00 +0100 Subject: [PATCH 27/36] update unit tests --- tests/models/test_hooks.py | 12 ++++-------- tests/trainer/test_data_loading.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 00ccaa3ec7c6c..c1b6e53543191 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -511,10 +511,8 @@ def training_step(self, batch, batch_idx): dict(name="configure_optimizers"), dict(name="Callback.on_fit_start", args=(trainer, model)), dict(name="on_fit_start"), - dict(name="Callback.on_pretrain_routine_start", args=(trainer, 
model)), - dict(name="on_pretrain_routine_start"), - dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)), - dict(name="on_pretrain_routine_end"), + dict(name="Callback.on_fit_start", args=(trainer, model)), + dict(name="on_fit_start"), dict(name="Callback.on_sanity_check_start", args=(trainer, model)), dict(name="on_val_dataloader"), dict(name="val_dataloader"), @@ -632,10 +630,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="configure_optimizers"), dict(name="Callback.on_fit_start", args=(trainer, model)), dict(name="on_fit_start"), - dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)), - dict(name="on_pretrain_routine_start"), - dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)), - dict(name="on_pretrain_routine_end"), + dict(name="Callback.on_fit_start", args=(trainer, model)), + dict(name="on_fit_start"), dict(name="train", args=(True,)), dict(name="on_train_dataloader"), dict(name="train_dataloader"), diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 96d943aa00d79..85ef15fe27d8d 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -104,7 +104,7 @@ def __init__(self, num_workers): def train_dataloader(self): return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers) - def on_pretrain_routine_start(self): + def on_fit_start(self): self._resout = StringIO() self.ctx = redirect_stderr(self._resout) self.ctx.__enter__() From a1637939734a11c8d003bf2270e3b99a4a19b700 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 16 Feb 2022 23:04:34 +0100 Subject: [PATCH 28/36] unit tests updated --- tests/deprecated_api/test_remove_1-8.py | 3 --- tests/models/test_hooks.py | 4 ---- 2 files changed, 7 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 1f7a92d0745c9..8c07a4c637889 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ 
b/tests/deprecated_api/test_remove_1-8.py @@ -209,8 +209,6 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): "on_epoch_end", "on_train_start", "on_train_end", - "on_pretrain_routine_start", - "on_pretrain_routine_end", "on_batch_start", "on_batch_end", "on_validation_start", @@ -246,7 +244,6 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): logger=False, ) model = BoringModel() - # need to attach model to trainer for testing of `on_pretrain_routine_start` trainer.fit(model) for method_name in methods_with_self: fn = getattr(trainer, method_name, None) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index c1b6e53543191..1823107001d71 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -511,8 +511,6 @@ def training_step(self, batch, batch_idx): dict(name="configure_optimizers"), dict(name="Callback.on_fit_start", args=(trainer, model)), dict(name="on_fit_start"), - dict(name="Callback.on_fit_start", args=(trainer, model)), - dict(name="on_fit_start"), dict(name="Callback.on_sanity_check_start", args=(trainer, model)), dict(name="on_val_dataloader"), dict(name="val_dataloader"), @@ -630,8 +628,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="configure_optimizers"), dict(name="Callback.on_fit_start", args=(trainer, model)), dict(name="on_fit_start"), - dict(name="Callback.on_fit_start", args=(trainer, model)), - dict(name="on_fit_start"), dict(name="train", args=(True,)), dict(name="on_train_dataloader"), dict(name="train_dataloader"), From 73e6c68054a31cfdbd300e0af01eadcffd97cd90 Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 16 Feb 2022 23:07:13 +0100 Subject: [PATCH 29/36] update test restore --- tests/models/test_restore.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 8e1006ce73147..225eba0a375c7 100644 --- a/tests/models/test_restore.py +++ 
b/tests/models/test_restore.py @@ -254,7 +254,7 @@ def test_correct_step_and_epoch(tmpdir): assert trainer.global_step == 0 class TestModel(BoringModel): - def on_pretrain_routine_end(self) -> None: + def on_fit_start(self) -> None: assert self.trainer.current_epoch == first_max_epochs # TODO(@carmocca): should not need `+1` assert self.trainer.global_step == first_max_epochs * train_batches + 1 @@ -302,7 +302,7 @@ def test_try_resume_from_non_existing_checkpoint(tmpdir): class CaptureCallbacksBeforeTraining(Callback): callbacks = [] - def on_pretrain_routine_end(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module): self.callbacks = deepcopy(trainer.callbacks) @@ -610,10 +610,10 @@ def test_dp_resume(tmpdir): class CustomModel(CustomClassificationModelDP): def __init__(self): super().__init__() - self.on_pretrain_routine_end_called = False + self.on_fit_start_called = False # set the epoch start hook so we can predict before the model does the full training - def on_pretrain_routine_end(self): + def on_fit_start(self): assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0 # if model and state loaded correctly, predictions will be good even though we @@ -622,14 +622,14 @@ def on_pretrain_routine_end(self): dataloader = dm.train_dataloader() tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader) - self.on_pretrain_routine_end_called = True + self.on_fit_start_called = True # new model model = CustomModel() # fit new model which should load hpc weights new_trainer.fit(model, datamodule=dm) - assert model.on_pretrain_routine_end_called + assert model.on_fit_start_called # test freeze on gpu model.freeze() From b7f52520a7769a9f269a52aca61f059f7a49335c Mon Sep 17 00:00:00 2001 From: krishnakalyan3 Date: Wed, 16 Feb 2022 23:26:30 +0100 Subject: [PATCH 30/36] remove to fix unit test --- tests/models/test_restore.py | 7 ------- 1 file changed, 7 deletions(-) diff --git 
a/tests/models/test_restore.py b/tests/models/test_restore.py index 225eba0a375c7..f3fe97d50fb59 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -253,13 +253,6 @@ def test_correct_step_and_epoch(tmpdir): assert trainer.current_epoch == 0 assert trainer.global_step == 0 - class TestModel(BoringModel): - def on_fit_start(self) -> None: - assert self.trainer.current_epoch == first_max_epochs - # TODO(@carmocca): should not need `+1` - assert self.trainer.global_step == first_max_epochs * train_batches + 1 - - trainer.fit(TestModel(), ckpt_path=ckpt_path) # TODO(@carmocca): should not need `-1` assert trainer.current_epoch == max_epochs - 1 # TODO(@carmocca): should not need `+1` From fe14dec74d8f1201ebcc675dcaa23475ae697702 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 24 Feb 2022 19:45:00 +0530 Subject: [PATCH 31/36] fix the deprecations --- CHANGELOG.md | 4 +++- docs/source/extensions/callbacks.rst | 12 ---------- .../callbacks/model_checkpoint.py | 2 +- pytorch_lightning/callbacks/model_summary.py | 4 +--- pytorch_lightning/core/hooks.py | 24 +++++++++---------- .../trainer/configuration_validator.py | 9 +++---- tests/deprecated_api/test_remove_1-8.py | 5 +++- tests/models/test_hooks.py | 8 +++++++ tests/models/test_restore.py | 7 ++++-- .../trainer/logging_/test_logger_connector.py | 6 +++++ 10 files changed, 43 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d523bb9838abb..1fe3985cd7d92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -402,7 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning` -- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794)) + +- Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794)) + - Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832)) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index ab4894de77ca2..23c06bb1790d1 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -375,18 +375,6 @@ on_train_end .. automethod:: pytorch_lightning.callbacks.Callback.on_train_end :noindex: -on_pretrain_routine_start -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start - :noindex: - -on_pretrain_routine_end -~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end - :noindex: - on_validation_start ~~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 67724c34bf04a..1e96c86105b02 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -249,7 +249,7 @@ def state_key(self) -> str: def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: self.__resolve_ckpt_dir(trainer) - if trainer.is_global_zero: + if trainer.is_global_zero and stage == "fit": self.__warn_if_dir_not_empty(self.dirpath) # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states, diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index cfa44c627905e..d374d85ccde46 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -49,9 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def on_train_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None - ) -> None: + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: if not self._max_depth: return None diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index e7976b2fac6c6..0524a86600b60 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -64,17 +64,19 @@ def on_predict_end(self) -> None: def on_pretrain_routine_start(self) -> None: """Called at the beginning of the pretrain routine (between fit and train start). - .. deprecated:: v1.6 - :meth:`on_pretrain_routine_start` is deprecated and will be removed in v1.8.0. 
- Please use :meth:`on_fit_start` or :meth:`setup` directly. + - fit + - pretrain_routine start + - pretrain_routine end + - training_start """ def on_pretrain_routine_end(self) -> None: """Called at the end of the pretrain routine (between fit and train start). - .. deprecated:: v1.6 - :meth:`on_pretrain_routine_end` is deprecated and will be removed in v1.8.0. - Please use :meth:`on_fit_start` or :meth:`setup` directly. + - fit + - pretrain_routine start + - pretrain_routine end + - training_start """ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: @@ -180,17 +182,15 @@ def on_predict_model_eval(self) -> None: self.trainer.model.eval() def on_epoch_start(self) -> None: - r""" - .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on__epoch_start`` instead. + """.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on__epoch_start`` instead. Called when either of train/val/test epoch begins. """ def on_epoch_end(self) -> None: - r""" - .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on__epoch_end`` instead. + """.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on__epoch_end`` instead. Called when either of train/val/test epoch ends. """ diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 66de5170332e1..7529f01948563 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -357,12 +357,9 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: f"The `Callback.{hook}` hook was deprecated in v1.6 and" f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." 
) - for hook, alternative_hook in ( - ("on_pretrain_routine_start", "on_fit_start"), - ("on_pretrain_routine_end", "on_fit_start"), - ): + for hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"): if is_overridden(method_name=hook, instance=callback): rank_zero_deprecation( - f"The `Callback.{hook}` hook was deprecated in v1.6 and" - f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead." + f"The `Callback.{hook}` hook has been deprecated in v1.6 and" + f" will be removed in v1.8. Please use `Callback.on_fit_start` instead." ) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 3c6e62f91107e..990f22d92a131 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -210,6 +210,8 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): "on_epoch_end", "on_train_start", "on_train_end", + "on_pretrain_routine_start", + "on_pretrain_routine_end", "on_batch_start", "on_batch_end", "on_validation_start", @@ -245,7 +247,8 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): logger=False, ) model = BoringModel() - trainer.fit(model) + # need to attach model to trainer for testing of `on_pretrain_routine_start` + trainer.strategy.connect(model) for method_name in methods_with_self: fn = getattr(trainer, method_name, None) with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index a96bd87567083..f08b75cfa454f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -511,6 +511,10 @@ def training_step(self, batch, batch_idx): dict(name="configure_optimizers"), dict(name="Callback.on_fit_start", args=(trainer, model)), dict(name="on_fit_start"), + dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)), + dict(name="on_pretrain_routine_start"), + dict(name="Callback.on_pretrain_routine_end", args=(trainer, 
model)), + dict(name="on_pretrain_routine_end"), dict(name="Callback.on_sanity_check_start", args=(trainer, model)), dict(name="on_val_dataloader"), dict(name="val_dataloader"), @@ -627,6 +631,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="configure_optimizers"), dict(name="Callback.on_fit_start", args=(trainer, model)), dict(name="on_fit_start"), + dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)), + dict(name="on_pretrain_routine_start"), + dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)), + dict(name="on_pretrain_routine_end"), dict(name="train", args=(True,)), dict(name="on_train_dataloader"), dict(name="train_dataloader"), diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 7b90247044f94..53838691a2efb 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -325,12 +325,15 @@ def get_trainer_args(): # initial training trainer = Trainer(**get_trainer_args()) - trainer.fit(model, datamodule=dm) + with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"): + trainer.fit(model, datamodule=dm) + callbacks_before_resume = deepcopy(trainer.callbacks) # resumed training trainer = Trainer(**get_trainer_args()) - trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt")) + with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"): + trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt")) assert len(callbacks_before_resume) == len(callback_capture.callbacks) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index a168068e81239..edebcc0a9e900 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -51,6 +51,8 @@ def test_fx_validator(tmpdir): "on_keyboard_interrupt", "on_exception", "on_load_checkpoint", + 
"on_pretrain_routine_end", + "on_pretrain_routine_start", "on_sanity_check_end", "on_sanity_check_start", "on_save_checkpoint", @@ -92,6 +94,8 @@ def test_fx_validator(tmpdir): "on_keyboard_interrupt", "on_exception", "on_load_checkpoint", + "on_pretrain_routine_end", + "on_pretrain_routine_start", "on_sanity_check_end", "on_sanity_check_start", "on_predict_batch_end", @@ -213,6 +217,8 @@ def test_fx_validator_integration(tmpdir): "on_configure_sharded_model": "You can't", "configure_optimizers": "You can't", "on_fit_start": "You can't", + "on_pretrain_routine_start": "You can't", + "on_pretrain_routine_end": "You can't", "on_train_dataloader": "You can't", "train_dataloader": "You can't", "on_val_dataloader": "You can't", From 86574e2b845320857e20622df0b4ca1d3ac34a91 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 24 Feb 2022 19:49:33 +0530 Subject: [PATCH 32/36] fix the deprecations --- pytorch_lightning/callbacks/model_summary.py | 4 ++-- pytorch_lightning/core/hooks.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py index d374d85ccde46..e659ddd057ace 100644 --- a/pytorch_lightning/callbacks/model_summary.py +++ b/pytorch_lightning/callbacks/model_summary.py @@ -22,7 +22,7 @@ """ import logging -from typing import List, Optional, Tuple +from typing import List, Tuple import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback @@ -49,7 +49,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1) -> None: self._max_depth: int = max_depth - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not self._max_depth: return None diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 0524a86600b60..d321b151a27dc 
100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -182,17 +182,19 @@ def on_predict_model_eval(self) -> None: self.trainer.model.eval() def on_epoch_start(self) -> None: - """.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on_<train/val/test>_epoch_start`` instead. + """Called when either of train/val/test epoch begins. - Called when either of train/val/test epoch begins. + .. deprecated:: v1.6 + :meth:`on_epoch_start` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on_<train/val/test>_epoch_start`` instead. """ def on_epoch_end(self) -> None: - """.. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on_<train/val/test>_epoch_end`` instead. + """Called when either of train/val/test epoch ends. - Called when either of train/val/test epoch ends. + .. deprecated:: v1.6 + :meth:`on_epoch_end` has been deprecated in v1.6 and will be removed in v1.8. + Use ``on_<train/val/test>_epoch_end`` instead. """ def on_train_epoch_start(self) -> None: From 88dc0ab951161e0f0f4139d78f2cba9148bf513f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 24 Feb 2022 19:51:23 +0530 Subject: [PATCH 33/36] fix the deprecations --- tests/trainer/test_data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index a0c5193834013..be063aab0bf95 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -104,7 +104,7 @@ def __init__(self, num_workers): def train_dataloader(self): return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers) - def on_fit_start(self): + def on_pretrain_routine_start(self): self._resout = StringIO() self.ctx = redirect_stderr(self._resout) self.ctx.__enter__() From 92586bcd850a34b061f166a448dc585f19368c5f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 24 Feb 2022 20:00:24 +0530 Subject: [PATCH 34/36] add deprecation test --- 
tests/deprecated_api/test_remove_1-8.py | 36 +++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 990f22d92a131..8fe871538ad11 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -532,3 +532,39 @@ def agg_and_log_metrics(self, metrics, step): Trainer(logger=[logger, logger3]) # Should have no deprecation warning Trainer(logger=[logger2, logger3]) + + +def test_v1_8_0_callback_on_pretrain_routune(tmpdir): + class TestCallback(Callback): + def on_pretrain_routine_start(self, trainer, pl_module): + print("on_pretrain_routune_start called.") + + model = BoringModel() + + trainer = Trainer( + callbacks=[TestCallback()], + fast_dev_run=True, + enable_progress_bar=False, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `Callback.on_pretrain_routine_start` hook has been deprecated in v1.6" " and will be removed in v1.8" + ): + trainer.fit(model) + + class TestCallback(Callback): + def on_pretrain_routine_end(self, trainer, pl_module): + print("on_pretrain_routune_end called.") + + model = BoringModel() + + trainer = Trainer( + callbacks=[TestCallback()], + fast_dev_run=True, + enable_progress_bar=False, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8" + ): + trainer.fit(model) From 75da9fadc88c73ba87d344955d3180ced6f378e4 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 24 Feb 2022 22:40:33 +0530 Subject: [PATCH 35/36] Apply suggestions from code review Co-authored-by: Jirka Borovec Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> --- pytorch_lightning/callbacks/base.py | 4 ++-- tests/deprecated_api/test_remove_1-8.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py 
b/pytorch_lightning/callbacks/base.py index f338f6fc59605..41a893c0d8158 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -259,8 +259,8 @@ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: r""" .. deprecated:: v1.6 - This callback hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on_fit_start`` instead. + + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead. Called when the pretrain routine ends. """ diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index f5bc4cfbd5278..d51c5638f04c3 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -574,10 +574,10 @@ def agg_and_log_metrics(self, metrics, step): Trainer(logger=[logger2, logger3]) -def test_v1_8_0_callback_on_pretrain_routune(tmpdir): +def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir): class TestCallback(Callback): def on_pretrain_routine_start(self, trainer, pl_module): - print("on_pretrain_routune_start called.") + print("on_pretrain_routine_start called.") model = BoringModel() @@ -594,7 +594,7 @@ def on_pretrain_routine_start(self, trainer, pl_module): class TestCallback(Callback): def on_pretrain_routine_end(self, trainer, pl_module): - print("on_pretrain_routune_end called.") + print("on_pretrain_routine_end called.") model = BoringModel() From 41f548913f6e58781388e6ac53a0508d3741278d Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 24 Feb 2022 22:40:56 +0530 Subject: [PATCH 36/36] Update pytorch_lightning/callbacks/base.py Co-authored-by: Jirka Borovec --- pytorch_lightning/callbacks/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 
41a893c0d8158..123b100ee4f26 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -250,8 +250,8 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: r""" .. deprecated:: v1.6 - This callback hook was deprecated in v1.6 and will be removed in v1.8. Use - ``on_fit_start`` instead. + + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead. Called when the pretrain routine begins. """