From 7983c62de22df3914bef08c7f70f55c7ed1adeab Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 4 Nov 2022 15:39:45 +0100 Subject: [PATCH 1/6] Only load global step when fitting --- .../connectors/checkpoint_connector.py | 23 ++++++++++--------- tests/tests_pytorch/models/test_restore.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2cb285167aeb0..6fc0a085ae767 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -343,17 +343,18 @@ def restore_loops(self) -> None: # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. # it will be overwritten by the loop's state if it was also saved - batch_loop = fit_loop.epoch_loop.batch_loop - if pl_module.automatic_optimization: - batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ - "global_step" - ] - else: - batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"] - - # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. - # it will be overwritten by the loop's state if it was also saved - fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] + if self.trainer.state.fn == TrainerFn.FITTING: + batch_loop = fit_loop.epoch_loop.batch_loop + if pl_module.automatic_optimization: + batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ + "global_step" + ] + else: + batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"] + + # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. 
+ # it will be overwritten by the loop's state if it was also saved + fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 8648d9ba1a6bf..7795f5cdcb8b5 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -187,7 +187,7 @@ def _check_model_state_dict(self): def _test_on_val_test_predict_start(self): assert self.trainer.current_epoch == state_dict["epoch"] - assert self.trainer.global_step == state_dict["global_step"] + assert self.trainer.global_step == 0 assert self._check_model_state_dict() def on_train_start(self): From 5d364789d971c8e3635b20746d4cc376fd5e74b3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 4 Nov 2022 15:47:40 +0100 Subject: [PATCH 2/6] changelog --- src/pytorch_lightning/CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 377a82428fe8c..4aedefcf9ce0b 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -21,9 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) -- +- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532)) -- +- ### Deprecated From 586a990a46388757b38157fe2af3d07de14ed3e0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 4 Nov 2022 15:47:58 +0100 Subject: [PATCH 3/6] move comment inside if --- .../trainer/connectors/checkpoint_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6fc0a085ae767..bff97a7fbc3b5 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -341,9 +341,9 @@ def restore_loops(self) -> None: pl_module = self.trainer.lightning_module assert pl_module is not None - # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. - # it will be overwritten by the loop's state if it was also saved if self.trainer.state.fn == TrainerFn.FITTING: + # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. 
+ # it will be overwritten by the loop's state if it was also saved batch_loop = fit_loop.epoch_loop.batch_loop if pl_module.automatic_optimization: batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ From 60868fabdbb25eef531ef17891f9e67ff4150e13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Nov 2022 14:50:23 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 4aedefcf9ce0b..1e6117a456ac7 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -23,7 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532)) -- +- ### Deprecated From 6ae0e41e7a6b7d39f918d3db5d3e552857ea9b89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 4 Nov 2022 11:33:11 -0400 Subject: [PATCH 5/6] Update src/pytorch_lightning/CHANGELOG.md Co-authored-by: Rohit Gupta --- src/pytorch_lightning/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 1e6117a456ac7..cd5235af5ca8d 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) -- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532)) +- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `Trainer.current_epoch` values from the checkpoints - From now on, only `Trainer.fit` will restore these values ([#15532](https://github.com/Lightning-AI/lightning/pull/15532)) - From 31a13717e867e9fa8256e2b1470f9da18514e489 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 4 Nov 2022 17:10:12 +0100 Subject: [PATCH 6/6] update another test --- tests/tests_pytorch/models/test_restore.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 7795f5cdcb8b5..36dd508ff92d0 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -626,8 +626,10 @@ def __init__(self): super().__init__() self.on_train_start_called = False - def on_validation_start(self): + def on_train_start(self): assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0 + + def on_validation_start(self): dataloader = dm.val_dataloader() tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)