Proposed refactor
Standardize all stateful components on state_dict/load_state_dict.
Background
PyTorch convention uses state_dict/load_state_dict for gathering and loading object state. In Lightning, some components follow this convention, while others do not (see the Appendix: current state gathering section below).
Motivation
Each component should contribute saving/loading its own state through the same APIs. There are two ways a component can participate in checkpointing:
1. independently contributing/loading its own local component state (aligning with the PyTorch primitives state_dict/load_state_dict)
2. operating on and depending on global checkpoint state (the Lightning CheckpointHooks: on_save_checkpoint/on_load_checkpoint)
This issue is focused on aligning all components on 1. Following this convention allows consistency across Lightning components and consistency with PyTorch conventions: any stateful component can simply implement its own state_dict/load_state_dict methods to contribute its state, as in the sketch below.
For now, 2 is only adjusted as needed (Callbacks).
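For illustration, a minimal sketch of the convention (the component class and its counter attribute are made up for this example):

from typing import Any, Dict

class ExampleStatefulComponent:
    """Hypothetical component, shown only to illustrate the state_dict/load_state_dict convention."""

    def __init__(self) -> None:
        self.counter = 0

    def state_dict(self) -> Dict[str, Any]:
        # everything needed to restore this component after resuming
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]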
Pitch
Lightning already has such a Stateful protocol, _SupportsStateDict, in auto_restart.py. We can move it out to the more central core/hooks.py file:
from typing import Any, Dict
# Protocol/runtime_checkable come from `typing` on Python 3.8+, from `typing_extensions` otherwise
from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """This class is used to detect if an object is stateful using `isinstance(obj, Stateful)`."""

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...
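To show how the protocol would be consumed, a minimal sketch (the collect_component_states helper is hypothetical, not part of the proposal):

def collect_component_states(components) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: gather local state from every object satisfying the protocol."""
    states = {}
    for obj in components:
        # works at runtime because the protocol is decorated with @runtime_checkable
        if isinstance(obj, Stateful):
            states[obj.__class__.__name__] = obj.state_dict()
    return states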
Additional context
Part of #7740
Appendix: current state gathering:
Specifically, current components contribute state in the following different ways:
1. Some components contribute saving/loading their own state with only on_save/load_checkpoint [DataModule, PrecisionPlugin, Callbacks]
   a. [DataModule, PrecisionPlugin] use on_save/load_checkpoint from CheckpointHooks: https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/core/hooks.py#L765-L807
   b. [Callbacks] contribute saving/loading their own state with different on_save/load_checkpoint hook signatures and functionality: https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/callbacks/base.py#L293-L323
   c. though [Loops] fall under 3. below, noting here that Loops also have their own on_save/load_checkpoint methods with different signatures: https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/loops/base.py#L253-L262
2. Some components contribute saving/loading their own state with only state_dict/load_state_dict calls [Optimizers, LR schedulers]
3. Some components have both [LightningModule, Loops]
Appendix: component, checkpoint_connector changes
Save aligning on state_dict
dump_checkpoint
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L310-L393
becomes:
def dump_checkpoint(self, weights_only: bool = False) -> dict:
    ...
    # dump callbacks
    # checkpoint["callbacks"] = self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
    # becomes
    checkpoint["callbacks"] = self.trainer._call_callbacks_state_dict()
    ...
    # precision plugin
    # self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
    # becomes
    prec_plugin = self.trainer.precision_plugin
    checkpoint[prec_plugin.__class__.__name__] = prec_plugin.state_dict()
    ...
    # give the model a chance to dump a few things
    # model.on_save_checkpoint(checkpoint)
    # if self.trainer.datamodule is not None:
    #     self.trainer.datamodule.on_save_checkpoint(checkpoint)
    # becomes
    # datamodule state
    dm = self.trainer.datamodule
    if dm is not None:
        checkpoint[dm.__class__.__name__] = dm.state_dict()
    # on_save_checkpoint calls (the hook-based path is kept for now)
    model.on_save_checkpoint(checkpoint)
    if dm is not None:
        dm.on_save_checkpoint(checkpoint)
    prec_plugin.on_save_checkpoint(checkpoint)
    for callback in self.trainer.callbacks:
        callback.on_save_checkpoint(self.trainer, model, checkpoint)
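For intuition, the checkpoint produced above carries one entry per component, keyed by class name (the plugin and datamodule names below are illustrative only):

# Illustrative shape of the resulting checkpoint dictionary
checkpoint = {
    "state_dict": {},                  # model weights, as before
    "callbacks": {},                   # trainer._call_callbacks_state_dict()
    "NativeMixedPrecisionPlugin": {},  # prec_plugin.state_dict()
    "MyDataModule": {},                # dm.state_dict()
}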
Load aligning on load_state_dict
see component Load sections below
Callbacks:
Base Class
becomes
BC: on_save_checkpoint returns None instead of a dict
BC: the on_load_checkpoint checkpoint argument takes the entire checkpoint dict instead of the callback_state
def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

def on_save_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
    """Called by Lightning when saving a checkpoint to give you a chance to store or customize anything
    else you might want to save.

    Args:
        trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
        pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
        checkpoint: the checkpoint dictionary that will be saved.
    """
    pass

def on_load_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
    """Called by Lightning when loading a checkpoint to give you a chance to reload or customize anything
    else you may have saved in on_save_checkpoint.

    Args:
        trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
        pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
        checkpoint: the entire loaded checkpoint dictionary.
    """
    pass
Save
def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
    ...
    state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
becomes
def _call_callbacks_state_dict(self) -> Dict[str, dict]:
    ...
    state = callback.state_dict()
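A fuller sketch of how this helper could look (keying each callback's state by its state_key is an assumption here, chosen to mirror how callback states are stored today):

def _call_callbacks_state_dict(self) -> Dict[str, dict]:
    """Collect each callback's local state for the "callbacks" checkpoint entry."""
    callback_state_dicts = {}
    for callback in self.callbacks:
        state_dict = callback.state_dict()
        if state_dict:
            callback_state_dicts[callback.state_key] = state_dict  # state_key keying is assumed
    return callback_state_dicts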
Load
def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    ...
    callback.on_load_checkpoint(self, self.lightning_module, state)
becomes
def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
    ...
    callback.load_state_dict(state)
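Correspondingly, a sketch of the load-side helper under the same keying assumption:

def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
    """Hand each callback its own entry from the "callbacks" section of the checkpoint."""
    callback_states = checkpoint.get("callbacks")
    if not callback_states:
        return
    for callback in self.callbacks:
        state = callback_states.get(callback.state_key)  # state_key keying is assumed
        if state:
            callback.load_state_dict(state)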
restore_callbacks
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L205
becomes
for callback in self.trainer.callbacks:
    callback.on_load_checkpoint(self.trainer, self.trainer.lightning_module, self._loaded_checkpoint)
self.trainer._call_callbacks_load_state_dict(self._loaded_checkpoint)
Callback classes
Update timer, pruning, model_checkpoint, finetuning, early_stopping to use state_dict/load_state_dict instead of on_save/load_checkpoint.
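As an illustration, the migration for a callback such as Timer could look roughly like this (the _time_elapsed attribute is simplified for this example and does not match the real implementation exactly):

# Before (hook-based; the returned dict is stored as this callback's state)
def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> Dict[str, Any]:
    return {"time_elapsed": self._time_elapsed}

def on_load_checkpoint(self, trainer, pl_module, callback_state: Dict[str, Any]) -> None:
    self._time_elapsed = callback_state["time_elapsed"]

# After (PyTorch-style local state)
def state_dict(self) -> Dict[str, Any]:
    return {"time_elapsed": self._time_elapsed}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    self._time_elapsed = state_dict["time_elapsed"]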
Precision Plugin:
Base Class
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/precision_plugin.py
add dummy state_dict/load_state_dict:
def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass
Load
restore_training_state
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L189-L190
becomes
# restore precision plugin (scaler etc.)
prec_plugin = self.trainer.precision_plugin
prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
if prec_plugin.__class__.__name__ in self._loaded_checkpoint:
    prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__name__])
Precision Plugin classes
Update apex_amp, native_amp to use state_dict/load_state_dict instead of on_save/load_checkpoint.
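For the native AMP plugin, for example, the new methods could wrap the GradScaler state roughly as follows (a sketch assuming the plugin keeps its scaler under self.scaler):

def state_dict(self) -> Dict[str, Any]:
    # torch.cuda.amp.GradScaler already follows the state_dict convention
    if self.scaler is not None:
        return self.scaler.state_dict()
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    if self.scaler is not None and state_dict:
        self.scaler.load_state_dict(state_dict)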
Datamodule:
Base Class
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/datamodule.py
add dummy state_dict/load_state_dict:
def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass
Load
restore_datamodule
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L151
becomes
datamodule.on_load_checkpoint(self._loaded_checkpoint)
if datamodule.__class__.__name__ in self._loaded_checkpoint:
    datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__name__])
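On the user side, a datamodule could then persist its own state like this (MyDataModule and its current_fold attribute are made up for this example):

from typing import Any, Dict

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    """Hypothetical datamodule that tracks which cross-validation fold it is on."""

    def __init__(self) -> None:
        super().__init__()
        self.current_fold = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.current_fold = state_dict["current_fold"]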
cc @justusschock @awaelchli @akihironitta @rohitgr7 @ananthsub @ninginthecloud