
Commit c67b075

Use global_step while restoring logging step for old checkpoints (#13645)
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 6cbd9d7 commit c67b075

5 files changed: +43 / -4 lines changed
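In practice the change matters when a run is resumed from a checkpoint written by Lightning < 1.6.5: such checkpoints store `global_step` but no `_batches_that_stepped` entry, so the logging step used to restart at 0. A minimal sketch of the affected scenario, assuming a hypothetical pre-1.6.5 checkpoint file named `old_checkpoint.ckpt`:

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

# "old_checkpoint.ckpt" stands in for a checkpoint saved by Lightning < 1.6.5,
# i.e. one whose loop state lacks the "_batches_that_stepped" key.
model = BoringModel()
trainer = Trainer(max_steps=10, limit_val_batches=0)
trainer.fit(model, ckpt_path="old_checkpoint.ckpt")

# Before this commit the restored logging step was 0; with the fix it equals
# the checkpoint's global_step, so logged metrics continue at the right step.
print(trainer.fit_loop.epoch_loop._batches_that_stepped)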

src/pytorch_lightning/CHANGELOG.md

3 additions, 0 deletions

@@ -355,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
 
 
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
+
+
 ## [1.6.4] - 2022-06-01
 
 ### Added

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

2 additions, 1 deletion

@@ -287,7 +287,8 @@ def on_save_checkpoint(self) -> Dict:
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
-        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
+        # restore global step instead to make sure logging works correctly if checkpoints <v1.6.5 used to resume
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)
 
     def _run_validation(self) -> None:
         # reload dataloaders
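The behavioural change is just the default handed to `dict.get`. A self-contained sketch of the before/after semantics, with plain values standing in for the loop's state (the numbers are illustrative):

# Loop state loaded from a pre-1.6.5 checkpoint: no "_batches_that_stepped" key.
state_dict = {"dataloader_state_dict": {}}
global_step = 1200  # restored separately from the same checkpoint

# Old behaviour: a missing key silently became 0, resetting the logging step.
before = state_dict.get("_batches_that_stepped", 0)
# New behaviour: fall back to global_step so logging resumes where it left off.
after = state_dict.get("_batches_that_stepped", global_step)

assert before == 0
assert after == 1200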

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

11 additions, 2 deletions

@@ -264,10 +264,19 @@ def restore_loops(self) -> None:
             return
 
         fit_loop = self.trainer.fit_loop
+        pl_module = self.trainer.lightning_module
+        assert pl_module is not None
+
         # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
-        optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
-        optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
+        batch_loop = fit_loop.epoch_loop.batch_loop
+        if pl_module.automatic_optimization:
+            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                "global_step"
+            ]
+        else:
+            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
+
         # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
         fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
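For pre-v1.6 checkpoints the restored `global_step` has to land in whichever progress counter actually feeds `Trainer.global_step`, and that counter differs between automatic and manual optimization, hence the new branch. A simplified sketch of the dispatch, using a flat dict in place of Lightning's nested progress objects:

loaded_checkpoint = {"global_step": 500}
automatic_optimization = True  # read from the LightningModule in the real code

# Stand-ins for the two step counters owned by the batch loop.
progress = {
    "optimizer_loop_steps": 0,  # automatic optimization: optimizer.step() progress
    "manual_loop_steps": 0,     # manual optimization: manual optimization step progress
}

key = "optimizer_loop_steps" if automatic_optimization else "manual_loop_steps"
progress[key] = loaded_checkpoint["global_step"]

assert progress["optimizer_loop_steps"] == 500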

tests/tests_pytorch/models/test_restore.py

25 additions, 1 deletion

@@ -28,7 +28,7 @@
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf

@@ -255,13 +255,37 @@ class TestModel(BoringModel):
         def on_train_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             assert self.trainer.global_step == first_max_epochs * train_batches
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches
 
     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
     assert trainer.global_step == max_epochs * train_batches
     assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches
 
 
+@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
+def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
+    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
+    model = model_class()
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    ckpt = torch.load(ckpt_path)
+    # the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
+    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
+    torch.save(ckpt, ckpt_path)
+
+    class TestModel(model_class):
+        def on_train_start(self) -> None:
+            assert self.trainer.global_step == 1
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1
+
+    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
+    model = TestModel()
+    trainer.fit(model, ckpt_path=ckpt_path)
+    new_loop = trainer.fit_loop.epoch_loop
+    assert new_loop.global_step == new_loop._batches_that_stepped == 2
+
+
 def test_fit_twice(tmpdir):
     epochs = []
 
tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

2 additions, 0 deletions

@@ -172,6 +172,8 @@ def test_loops_restore(tmpdir):
     ckpt_path = str(tmpdir / "last.ckpt")
 
     trainer = Trainer(**trainer_args)
+    trainer.strategy.connect(model)
+
     for fn in TrainerFn:
         if fn != TrainerFn.TUNING:
             trainer_fn = getattr(trainer, f"{fn}_loop")
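The added `trainer.strategy.connect(model)` call is presumably needed because `restore_loops` now reads `self.trainer.lightning_module` (and asserts it is not `None`), so a module has to be attached to the strategy before the loops can be restored in this test.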
