Skip to content

Commit adec445

Browse files
author
SeanNaren
committed
Address code review
1 parent d5dd739 commit adec445

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,11 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
420420
"""
421421
return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
422422

423-
def state_dict(self) -> Dict[str, Union[Any, Tensor]]:
423+
def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
424424
"""
425425
Returns state of model. Allows for syncing/collating model state from processes in custom plugins.
426426
"""
427-
return self.training_type_plugin.state_dict()
427+
return self.training_type_plugin.lightning_module_state_dict()
428428

429429
def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
430430
return self.training_type_plugin.on_save(checkpoint)

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) ->
242242
"""
243243
return current_global_step + 1
244244

245-
def state_dict(self) -> Dict[str, Union[Any, Tensor]]:
245+
def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
246246
"""Returns model state."""
247247
model = self.lightning_module
248248
return model.state_dict()

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
273273
'epoch': current_epoch,
274274
'global_step': global_step,
275275
'pytorch-lightning_version': pytorch_lightning.__version__,
276-
'state_dict': self.trainer.accelerator.state_dict(),
276+
'state_dict': self.trainer.accelerator.lightning_module_state_dict(),
277277
}
278278

279279
if not weights_only:

0 commit comments

Comments
 (0)