Commit b286088

rohitgr7 authored and carmocca committed

Restore log step during restart (#13467)

Co-authored-by: Carlos Mocholí <[email protected]>

1 parent 9e6997c commit b286088

File tree: 4 files changed, +5 −1 lines changed


CHANGELOG.md
Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))
 - Fixed bug with Python version check that prevented use with development versions of Python ([#13420](https://github.com/PyTorchLightning/pytorch-lightning/pull/13420))
 - The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396))
+- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))


 ## [1.6.4] - 2022-06-01

pytorch_lightning/loops/epoch/training_epoch_loop.py
Lines changed: 2 additions & 0 deletions

@@ -281,6 +281,7 @@ def teardown(self) -> None:

     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
+        state_dict["_batches_that_stepped"] = self._batches_that_stepped

         if (
             self.trainer is not None
@@ -300,6 +301,7 @@ def on_save_checkpoint(self) -> Dict:
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

     def _run_validation(self) -> None:
         # reload dataloaders
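The change above follows a simple save/restore round-trip: the step counter is written into the loop's checkpoint dict on save and read back on load, with a default of 0 so checkpoints written before this commit still load. A minimal sketch of that pattern, using a hypothetical `MiniEpochLoop` class (not Lightning's actual loop):

```python
# Minimal sketch of the checkpoint round-trip pattern this commit adds.
# MiniEpochLoop is a stand-in, not Lightning's TrainingEpochLoop.
from typing import Dict


class MiniEpochLoop:
    def __init__(self) -> None:
        # global count of optimizer steps taken so far
        self._batches_that_stepped = 0

    def on_save_checkpoint(self) -> Dict:
        # persist the counter so a restarted run resumes logging correctly
        return {"_batches_that_stepped": self._batches_that_stepped}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # .get(..., 0) keeps older checkpoints (without the key) loadable
        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)


loop = MiniEpochLoop()
loop._batches_that_stepped = 42
ckpt = loop.on_save_checkpoint()

restored = MiniEpochLoop()
restored.on_load_checkpoint(ckpt)
print(restored._batches_that_stepped)  # 42

legacy = MiniEpochLoop()
legacy.on_load_checkpoint({})  # old checkpoint: key absent, falls back to 0
print(legacy._batches_that_stepped)  # 0
```

The `.get(..., 0)` fallback is what makes the fix backward compatible: resuming from a pre-fix checkpoint simply restarts the log-step counter at 0 instead of raising a `KeyError`.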

tests/loops/test_loop_state_dict.py
Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def test_loops_state_dict_structure():
     expected = {
         "fit_loop": {
             "state_dict": {},
-            "epoch_loop.state_dict": {},
+            "epoch_loop.state_dict": {"_batches_that_stepped": 0},
             "epoch_loop.batch_progress": {
                 "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
                 "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},

tests/models/test_restore.py
Lines changed: 1 addition & 0 deletions

@@ -259,6 +259,7 @@ def on_train_start(self) -> None:
     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
     assert trainer.global_step == max_epochs * train_batches
+    assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


 def test_fit_twice(tmpdir):
