src/pytorch_lightning/CHANGELOG.md (2 changes: 1 addition & 1 deletion)

```diff
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
 
--
+- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `trainer.current_epoch` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532))
 
 -
 
```
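In practice, the change means the loop counters observed during an evaluation entry point no longer depend on the checkpoint. A minimal sketch of the new behavior, assuming a trivial `LightningModule` (all names below are illustrative, not part of the PR):

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = TinyModel()
loader = DataLoader(RandomDataset(), batch_size=4)

trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model, loader)  # fit advances (and, on resume, restores) the counters as before
trainer.save_checkpoint("example.ckpt")

eval_trainer = pl.Trainer(logger=False)
eval_trainer.validate(model, loader, ckpt_path="example.ckpt")
# Before #15532 this was restored from the checkpoint; now it stays at its default.
assert eval_trainer.global_step == 0
```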
src/pytorch_lightning/trainer/connectors/checkpoint_connector.py (27 changes: 14 additions & 13 deletions)

```diff
@@ -341,19 +341,20 @@ def restore_loops(self) -> None:
         pl_module = self.trainer.lightning_module
         assert pl_module is not None
 
-        # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        batch_loop = fit_loop.epoch_loop.batch_loop
-        if pl_module.automatic_optimization:
-            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
-                "global_step"
-            ]
-        else:
-            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
-
-        # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
+            # it will be overwritten by the loop's state if it was also saved
+            batch_loop = fit_loop.epoch_loop.batch_loop
+            if pl_module.automatic_optimization:
+                batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                    "global_step"
+                ]
+            else:
+                batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
+
+            # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
+            # it will be overwritten by the loop's state if it was also saved
+            fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
 
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
```
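The substance of the change is the new `TrainerFn.FITTING` guard: the pre-v1.6 fallback that seeds the progress counters from the flat `"global_step"` and `"epoch"` checkpoint keys now runs only while fitting, so `validate`, `test`, and `predict` leave the counters at their defaults. Pulled out of the class for readability, the guarded block is equivalent to the hypothetical helper below (a paraphrase of the diff, not code from the PR):

```python
from pytorch_lightning.trainer.states import TrainerFn


def _restore_legacy_loop_counters(trainer, loaded_checkpoint: dict) -> None:
    # Legacy (< v1.6) checkpoints lack the progress-tracking state, so the
    # counters are seeded from the flat keys; a saved "loops" state, when
    # present, overwrites these values immediately afterwards.
    if trainer.state.fn != TrainerFn.FITTING:
        return  # validate/test/predict no longer touch global_step/current_epoch

    fit_loop = trainer.fit_loop
    batch_loop = fit_loop.epoch_loop.batch_loop
    if trainer.lightning_module.automatic_optimization:
        optim_step = batch_loop.optimizer_loop.optim_progress.optimizer.step
        optim_step.total.completed = loaded_checkpoint["global_step"]
    else:
        batch_loop.manual_loop.optim_step_progress.total.completed = loaded_checkpoint["global_step"]
    fit_loop.epoch_progress.current.completed = loaded_checkpoint["epoch"]
```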
tests/tests_pytorch/models/test_restore.py (6 changes: 4 additions & 2 deletions)

```diff
@@ -187,7 +187,7 @@ def _check_model_state_dict(self):
 
     def _test_on_val_test_predict_start(self):
         assert self.trainer.current_epoch == state_dict["epoch"]
-        assert self.trainer.global_step == state_dict["global_step"]
+        assert self.trainer.global_step == 0
         assert self._check_model_state_dict()
 
     def on_train_start(self):
@@ -626,8 +626,10 @@ def __init__(self):
         super().__init__()
         self.on_train_start_called = False
 
-    def on_validation_start(self):
+    def on_train_start(self):
         assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
+
+    def on_validation_start(self):
         dataloader = dm.val_dataloader()
         tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
 
```
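The test updates mirror the behavior change: the evaluation-side hook now expects `global_step == 0`, and the `current_epoch` assertion moves from `on_validation_start` to `on_train_start`, the only hook where the restored epoch is still visible. Code that previously relied on `trainer.validate(...)` or `trainer.test(...)` to populate the counters can still read them directly from the checkpoint file, since Lightning stores them as top-level keys. A short workaround sketch (the path is illustrative):

```python
import torch

# Lightning checkpoints are plain torch-serialized dicts with top-level
# "global_step" and "epoch" entries.
ckpt = torch.load("example.ckpt", map_location="cpu")
global_step = ckpt["global_step"]
current_epoch = ckpt["epoch"]
print(f"checkpoint was saved at step {global_step}, epoch {current_epoch}")
```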