[RFC] Standardize all stateful components on state_dict/load_state_dict #11429

@jjenniferdai

Proposed refactor

Standardize all stateful components on state_dict/load_state_dict.

Background

PyTorch convention uses state_dict/load_state_dict for gathering and loading object state. In Lightning, some components follow this convention, while other components do not (see the Appendix: current state gathering section).

Motivation

Each component should contribute saving/loading of its own state through the same APIs. There are two distinct ways a component interacts with checkpoint state:

  1. independently contributing/loading one’s own local component state (aligning with PyTorch primitives: state_dict/load_state_dict)
  2. operating and depending on global component state (Lightning CheckpointHooks: on_save/load_checkpoint)

This issue is focused on aligning all components on 1. Following this convention allows consistency across Lightning components and with PyTorch conventions: any stateful component can simply implement its own state_dict/load_state_dict methods to contribute its own state.
For now, 2. is only adjusted as needed (Callbacks).

Pitch

Lightning already has this Stateful protocol (currently named _SupportsStateDict) in auto_restart.py. We can move it out to the more central core/hooks.py file:

https://github.com/PyTorchLightning/pytorch-lightning/blob/34c62da37dc1ed9a1de7e023c690c7528ee56c60/pytorch_lightning/utilities/auto_restart.py#L638-L646

# Protocol and runtime_checkable come from `typing` on Python 3.8+
# (from `typing_extensions` on older interpreters)
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """This class is used to detect if an object is stateful using `isinstance(obj, Stateful)`."""

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...
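
Because the protocol is runtime-checkable, components opt in structurally: implementing the two methods is enough, no base class needed. A minimal illustration using the Stateful protocol defined above (the component class is hypothetical):

from typing import Any, Dict


class CounterComponent:
    """Hypothetical component; satisfies `Stateful` without inheriting from it."""

    def __init__(self) -> None:
        self.counter = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]


component = CounterComponent()
assert isinstance(component, Stateful)  # structural check via the runtime-checkable protocol
restored = CounterComponent()
restored.load_state_dict(component.state_dict())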

Additional context

Part of #7740

Appendix: current state gathering:

Specifically, current components contribute state in the following different ways (a condensed sketch of the diverging shapes follows this list):

  1. Some components contribute saving/loading their own state with only on_save/load_checkpoint [DataModule, PrecisionPlugin, Callbacks]
    a. [DataModule, PrecisionPlugin] use on_save/load_checkpoint from CheckpointHooks https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/core/hooks.py#L765-L807
    b. [Callbacks] contribute saving/loading their own state with different on_save/load_checkpoint hook signatures and functionality https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/callbacks/base.py#L293-L323
    c. though [Loops] falls under 3. below, noting here that Loops also have their own on_save/load_checkpoint methods with different signatures. https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/loops/base.py#L253-L262
  2. Some components contribute saving/loading their own state with only state_dict/load_state_dict calls [Optimizers, LR schedulers]
  3. Some components have both [LightningModule, Loops]
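
For concreteness, here is a condensed sketch of the diverging shapes described above (the class names are illustrative labels, not Lightning classes, and the signatures are abbreviated, not verbatim):

from typing import Any, Dict


class CheckpointHooksStyle:  # 1a. DataModule, PrecisionPlugin (Loops add yet another variant, see 1c)
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ...
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ...


class CallbackStyle:  # 1b. Callbacks: trainer/module args, and on_save_checkpoint *returns* the state
    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> dict: ...
    def on_load_checkpoint(self, trainer, pl_module, callback_state: Dict[str, Any]) -> None: ...


class StateDictStyle:  # 2. and 3. Optimizers, LR schedulers, LightningModule, Loops
    def state_dict(self) -> Dict[str, Any]: ...
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...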

Appendix: component, checkpoint_connector changes

Save: aligning on state_dict

dump_checkpoint
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L310-L393
becomes:

def dump_checkpoint(self, weights_only: bool = False) -> dict:
    ...
    # dump callbacks
    # checkpoint["callbacks"] = self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
    # becomes
    checkpoint["callbacks"] = self.trainer._call_callbacks_state_dict()

    ...
    # precision plugin
    # self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
    # becomes
    prec_plugin = self.trainer.precision_plugin
    checkpoint[prec_plugin.__class__.__name__] = prec_plugin.state_dict()

    ...
    # give the model a chance to dump a few things
    # model.on_save_checkpoint(checkpoint)
    # if self.trainer.datamodule is not None:
    #     self.trainer.datamodule.on_save_checkpoint(checkpoint)
    # becomes

    # datamodule state
    dm = self.trainer.datamodule
    if dm is not None:
        checkpoint[dm.__class__.__name__] = dm.state_dict()

    # on_save_checkpoint calls
    model.on_save_checkpoint(checkpoint)
    if dm is not None:
        dm.on_save_checkpoint(checkpoint)
    prec_plugin.on_save_checkpoint(checkpoint)
    for callback in self.trainer.callbacks:
        callback.on_save_checkpoint(self.trainer, model, checkpoint)
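
With these changes, each component's state lands as a top-level checkpoint entry keyed by its class name. A sketch of the resulting layout, assuming a NativeMixedPrecisionPlugin and a user-defined MyDataModule (the other keys are the usual entries dump_checkpoint already writes):

checkpoint = {
    "epoch": ...,
    "global_step": ...,
    "state_dict": ...,                    # LightningModule weights
    "optimizer_states": [...],
    "lr_schedulers": [...],
    "callbacks": {...},                   # trainer._call_callbacks_state_dict()
    "NativeMixedPrecisionPlugin": {...},  # prec_plugin.__class__.__name__ -> prec_plugin.state_dict()
    "MyDataModule": {...},                # dm.__class__.__name__ -> dm.state_dict()
}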

Load: aligning on load_state_dict

See the component-specific Load sections below.

Callbacks:

Base Class

https://github.com/PyTorchLightning/pytorch-lightning/blob/59a7ba760548baadf6dbb30864b54cb01c7225a3/pytorch_lightning/callbacks/base.py#L293-L323

becomes, with two breaking changes (BC):

BC: on_save_checkpoint returns None instead of a dict
BC: on_load_checkpoint takes the entire checkpoint dict instead of only the callback_state

def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

def on_save_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
    """Called by Lightning when saving a checkpoint to give you a chance to store or customize anything
    else you might want to save.
    Args:
        trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
        pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
        checkpoint: the checkpoint dictionary that will be saved.
    """
    pass

def on_load_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
    """Called by Lightning when loading a checkpoint to give you a chance to reload or customize anything
    else you may have saved in on_save_checkpoint.
    Args:
        trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
        pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
        checkpoint: entire loaded checkpoint dictionary
    """
    pass
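
Under the new split, a user callback keeps its own serializable state in state_dict/load_state_dict and uses on_save/load_checkpoint only when it needs the full checkpoint dict. A hypothetical example:

from typing import Any, Dict

from pytorch_lightning import Callback


class ValidationCounter(Callback):
    """Hypothetical callback that counts completed validation runs across restarts."""

    def __init__(self) -> None:
        self.validations_run = 0

    def on_validation_end(self, trainer, pl_module) -> None:
        self.validations_run += 1

    def state_dict(self) -> Dict[str, Any]:
        # stored under checkpoint["callbacks"] by the trainer
        return {"validations_run": self.validations_run}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.validations_run = state_dict["validations_run"]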

Save

https://github.com/PyTorchLightning/pytorch-lightning/blob/85304d4672a9ed24a16f7f5b2abaa34148ab86f4/pytorch_lightning/trainer/trainer.py#L1603

def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
    ...
    state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)

becomes

def _call_callbacks_state_dict(self) -> Dict[str, dict]:
    ...
    state = callback.state_dict()
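
A possible shape for the new helper (how each callback's entry is keyed, here by class name, is an implementation detail; the real code may use a more specific identifier to handle multiple instances of the same callback class):

def _call_callbacks_state_dict(self) -> Dict[str, dict]:
    """Collect the local state of every callback that has something to save."""
    callback_state_dicts = {}
    for callback in self.callbacks:
        state = callback.state_dict()
        if state:
            callback_state_dicts[callback.__class__.__name__] = state
    return callback_state_dicts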

Load

https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/trainer.py#L1652

def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    ...
    callback.on_load_checkpoint(self, self.lightning_module, state)

becomes

def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
    ...
    callback.load_state_dict(state)
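
And a matching shape for the load-side helper, assuming the same per-callback keying as on the save side:

def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
    """Restore each callback's local state from the `callbacks` entry of the checkpoint."""
    callback_states = checkpoint.get("callbacks")
    if not callback_states:
        return
    for callback in self.callbacks:
        state = callback_states.get(callback.__class__.__name__)
        if state:
            callback.load_state_dict(state)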

restore_callbacks
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L205
becomes

for callback in self.trainer.callbacks:
    callback.on_load_checkpoint(self.trainer, self.trainer.lightning_module, self._loaded_checkpoint)
self.trainer._call_callbacks_load_state_dict(self._loaded_checkpoint)

Callback classes

Update timer, pruning, model_checkpoint, finetuning, early_stopping to use state_dict/load_state_dict instead of on_save/load_checkpoint (a sketch for early_stopping follows).
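
As an example of what these migrations look like, a simplified sketch for early_stopping (only two of the tracked attributes shown):

# before: state travels through the callback-specific hooks
def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict:
    return {"wait_count": self.wait_count, "best_score": self.best_score}

def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
    self.wait_count = callback_state["wait_count"]
    self.best_score = callback_state["best_score"]

# after: plain PyTorch-style state methods
def state_dict(self) -> Dict[str, Any]:
    return {"wait_count": self.wait_count, "best_score": self.best_score}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    self.wait_count = state_dict["wait_count"]
    self.best_score = state_dict["best_score"]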

Precision Plugin:

Base Class

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/precision_plugin.py
add dummy state_dict/load_state_dict:

def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

Load

restore_training_state
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L189-L190
becomes

# restore precision plugin (scaler etc.)
prec_plugin = self.trainer.precision_plugin
prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
if prec_plugin.__class__.__name__ in self._loaded_checkpoint:
    prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__name__])

Precision Plugin classes

Update apex_amp, native_amp to use state_dict/load_state_dict instead of on_save/load_checkpoint (a sketch for native_amp follows).
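
For instance, native_amp's NativeMixedPrecisionPlugin wraps a torch.cuda.amp.GradScaler, so its version could be roughly (a sketch, not the verbatim plugin code):

def state_dict(self) -> Dict[str, Any]:
    # GradScaler already exposes PyTorch-style state methods
    return self.scaler.state_dict()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    self.scaler.load_state_dict(state_dict)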

Datamodule:

Base Class

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/datamodule.py
add dummy state_dict/load_state_dict:

def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

Load

restore_datamodule
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L151
becomes

datamodule.on_load_checkpoint(self._loaded_checkpoint)
if datamodule.__class__.__name__ in self._loaded_checkpoint:
    datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__name__])
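
On the user side, a DataModule can then persist whatever it needs across restarts through the same two methods. A hypothetical example:

from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    """Hypothetical DataModule that remembers the seed used to generate its split."""

    def __init__(self) -> None:
        super().__init__()
        self.split_seed = 42

    def state_dict(self) -> Dict[str, Any]:
        # saved under checkpoint["MyDataModule"] by dump_checkpoint above
        return {"split_seed": self.split_seed}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.split_seed = state_dict["split_seed"]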


cc @justusschock @awaelchli @akihironitta @rohitgr7 @ananthsub @ninginthecloud
