diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 28695785c367c..44c4096b9b204 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -302,6 +302,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the input validation for the accelerator Trainer argument when passed as a string ([#13417](https://github.com/PyTorchLightning/pytorch-lightning/pull/13417)) +- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467)) + ## [1.6.4] - 2022-06-01 diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 04e9d070a6d8e..36a594b45ae6f 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -273,6 +273,7 @@ def teardown(self) -> None: def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() + state_dict["_batches_that_stepped"] = self._batches_that_stepped if ( self.trainer is not None @@ -292,6 +293,7 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict") + self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _run_validation(self) -> None: # reload dataloaders diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index 1e67fcc0ed8db..f9630095502d1 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -47,7 +47,7 @@ def test_loops_state_dict_structure(): expected = { "fit_loop": { "state_dict": {}, - "epoch_loop.state_dict": {}, + "epoch_loop.state_dict": {"_batches_that_stepped": 0}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 77f45928dd907..4f167c08e8a05 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -259,6 +259,7 @@ def on_train_start(self) -> None: trainer.fit(TestModel(), ckpt_path=ckpt_path) assert trainer.current_epoch == max_epochs assert trainer.global_step == max_epochs * train_batches + assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches def test_fit_twice(tmpdir):