
Commit 43a89eb

bug fix: restore_optimizers correctly handles non-mapping values in optimizer.state.values() (#11757)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 9ed44de · commit 43a89eb

File tree

2 files changed: +10 -4 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -573,6 +573,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `RichProgressBar` now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
 
 
+- Fixed `restore_optimizers` for mapping states ([#11757](https://github.com/PyTorchLightning/pytorch-lightning/pull/11757))
+
+
 - Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))
 
 
pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 7 additions & 4 deletions
@@ -283,10 +283,13 @@ def restore_optimizers(self) -> None:
             # move optimizer to GPU 1 weight at a time
             # avoids OOM
             if self.trainer.root_gpu is not None:
-                for state in optimizer.state.values():
-                    for k, v in state.items():
-                        if isinstance(v, torch.Tensor):
-                            state[k] = v.cuda(self.trainer.root_gpu)
+                for param, state in optimizer.state.items():
+                    if isinstance(state, dict):
+                        for k, v in state.items():
+                            if isinstance(v, torch.Tensor):
+                                state[k] = v.cuda(self.trainer.root_gpu)
+                    elif isinstance(state, torch.Tensor):
+                        optimizer.state[param] = state.cuda(self.trainer.root_gpu)
 
     def restore_lr_schedulers(self) -> None:
         """Restores the learning rate scheduler states from the pre-loaded checkpoint."""
