diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 377a82428fe8c..cd5235af5ca8d 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
 
--
+- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `trainer.global_step` and `trainer.current_epoch` values from the checkpoint. From now on, only `Trainer.fit` restores these values ([#15532](https://github.com/Lightning-AI/lightning/pull/15532))
 
 -
 
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 2cb285167aeb0..bff97a7fbc3b5 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -341,19 +341,20 @@ def restore_loops(self) -> None:
         pl_module = self.trainer.lightning_module
         assert pl_module is not None
 
-        # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        batch_loop = fit_loop.epoch_loop.batch_loop
-        if pl_module.automatic_optimization:
-            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
-                "global_step"
-            ]
-        else:
-            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
-
-        # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
+            # it will be overwritten by the loop's state if it was also saved
+            batch_loop = fit_loop.epoch_loop.batch_loop
+            if pl_module.automatic_optimization:
+                batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                    "global_step"
+                ]
+            else:
+                batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
+
+            # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
+            # it will be overwritten by the loop's state if it was also saved
+            fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
 
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index 8648d9ba1a6bf..36dd508ff92d0 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -187,7 +187,7 @@ def _check_model_state_dict(self):
 
     def _test_on_val_test_predict_start(self):
         assert self.trainer.current_epoch == state_dict["epoch"]
-        assert self.trainer.global_step == state_dict["global_step"]
+        assert self.trainer.global_step == 0
         assert self._check_model_state_dict()
 
     def on_train_start(self):
@@ -626,8 +626,10 @@ def __init__(self):
         super().__init__()
         self.on_train_start_called = False
 
-    def on_validation_start(self):
+    def on_train_start(self):
         assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
+
+    def on_validation_start(self):
         dataloader = dm.val_dataloader()
         tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
 
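
Illustrative sketch (not part of the patch): the snippet below shows the behavior this change introduces, assuming pytorch_lightning >= 1.8 with its `BoringModel` demo class; the trainer arguments and checkpoint handling are placeholder assumptions, not code from this PR. After the change, `validate`/`test`/`predict` called with `ckpt_path=...` still load the model weights but leave `trainer.global_step` and `trainer.current_epoch` at their initial values, while `fit(ckpt_path=...)` keeps restoring them.

# Hedged sketch of the new behavior; BoringModel and the Trainer arguments are
# illustrative assumptions.
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path  # written by the default ModelCheckpoint

# validate(ckpt_path=...) restores the weights but no longer restores the loop counters.
val_trainer = Trainer(limit_val_batches=2)
val_trainer.validate(model, ckpt_path=ckpt_path)
assert val_trainer.global_step == 0  # stays at its initial value, as the updated test asserts

# Only fit(ckpt_path=...) still restores global_step and current_epoch.
fit_trainer = Trainer(max_epochs=2, limit_train_batches=2, limit_val_batches=2)
fit_trainer.fit(model, ckpt_path=ckpt_path)
assert fit_trainer.current_epoch > 0  # restored from the checkpoint, mirroring the on_train_start assertion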