
Commit d69b33f

jjenniferdai, rohitgr7, and ananthsub authored
Introduce Stateful PrecisionPlugin (#11638)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: ananthsub <[email protected]>
1 parent 914f685 commit d69b33f

File tree

6 files changed (+80, -16 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))


+- Added `_Stateful` support for `PrecisionPlugin` ([#11638](https://github.com/PyTorchLightning/pytorch-lightning/pull/11638))
+
+
 - Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 15 additions & 3 deletions
@@ -93,9 +93,21 @@ def optimizer_step(
             return optimizer.step(**kwargs)
         return closure_result

+    def state_dict(self) -> Dict[str, Any]:
+        return amp.state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        amp.load_state_dict(state_dict)
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "amp_scaling_state" in checkpoint:
-            amp.load_state_dict(checkpoint["amp_scaling_state"])
+        """``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict``
+        instead
+        """

     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint["amp_scaling_state"] = amp.state_dict()
+        """``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead
+        """

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 18 additions & 4 deletions
@@ -108,10 +108,24 @@ def forward_context(self) -> Generator[None, None, None]:
         with self.autocast_context_manager():
             yield

+    def state_dict(self) -> Dict[str, Any]:
+        if self.scaler is not None:
+            return self.scaler.state_dict()
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        if self.scaler is not None:
+            self.scaler.load_state_dict(state_dict)
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if self.scaler is not None and "native_amp_scaling_state" in checkpoint:
-            self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
+        """``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict``
+        instead
+        """

     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if self.scaler is not None:
-            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
+        """``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead
+        """

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 18 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from functools import partial
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -242,3 +242,20 @@ def teardown(self) -> None:

        It is the right place to release memory and free other resources.
        """
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate precision plugin state_dict.
+
+        Returns:
+            A dictionary containing precision plugin state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin
+        state_dict.
+
+        Args:
+            state_dict: the precision plugin state returned by ``state_dict``.
+        """
+        pass
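
With these base-class hooks in place, a custom precision plugin only has to override `state_dict`/`load_state_dict` to be checkpointed automatically. A hedged sketch — the class and its `loss_scale` attribute are invented for illustration; Lightning saves whatever `state_dict()` returns under the plugin's `__class__.__qualname__`:

    # Hypothetical custom plugin; only the two new hooks are shown.
    from typing import Any, Dict

    from pytorch_lightning.plugins.precision import PrecisionPlugin


    class MyPrecisionPlugin(PrecisionPlugin):
        def __init__(self) -> None:
            super().__init__()
            self.loss_scale = 1.0  # invented state we want checkpointed

        def state_dict(self) -> Dict[str, Any]:
            return {"loss_scale": self.loss_scale}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.loss_scale = state_dict["loss_scale"]

A checkpoint written with this plugin would then contain a `"MyPrecisionPlugin"` entry holding that dict, and the dict is fed back through `load_state_dict` on restore.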

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 24 additions & 4 deletions
@@ -22,6 +22,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -196,7 +197,7 @@ def restore_training_state(self) -> None:
             return

         # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        self.restore_precision_plugin_state()

         # restore loops and their progress
         self.restore_loops()
@@ -206,6 +207,21 @@ def restore_training_state(self) -> None:
         # restore optimizers and schedulers state
         self.restore_optimizers_and_schedulers()

+    def restore_precision_plugin_state(self) -> None:
+        """Restore the precision plugin state from the pre-loaded checkpoint."""
+        prec_plugin = self.trainer.precision_plugin
+        prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint:
+            prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__])
+
+        # old checkpoints compatibility
+        if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"])
+        if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(
+            prec_plugin, NativeMixedPrecisionPlugin
+        ):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
+
     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
         if not self._loaded_checkpoint:
@@ -318,9 +334,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
            'callbacks': "callback specific state"[] # if not weights_only
            'optimizer_states': "PT optim's state_dict"[] # if not weights_only
            'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
-           'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp
-           'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp
            'state_dict': Model's state_dict (e.g. network weights)
+           precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only
            CHECKPOINT_HYPER_PARAMS_NAME:
            CHECKPOINT_HYPER_PARAMS_KEY:
            CHECKPOINT_HYPER_PARAMS_TYPE:
@@ -357,7 +372,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 lr_schedulers.append(config.scheduler.state_dict())
             checkpoint["lr_schedulers"] = lr_schedulers

-            self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
+            # precision plugin
+            prec_plugin = self.trainer.precision_plugin
+            prec_plugin_state_dict = prec_plugin.state_dict()
+            if prec_plugin_state_dict:
+                checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
+            prec_plugin.on_save_checkpoint(checkpoint)

         # dump hyper-parameters
         if model.hparams:
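
The restore path reads the new per-plugin key and still honours the legacy `amp_scaling_state`/`native_amp_scaling_state` keys, so pre-1.6 checkpoints keep loading. A pure-Python sketch of roughly that lookup (the checkpoint dicts and the helper are fabricated for illustration):

    # Fabricated checkpoints: only the key names mirror the real layouts.
    legacy_ckpt = {"native_amp_scaling_state": {"scale": 65536.0}}   # pre-1.6
    new_ckpt = {"NativeMixedPrecisionPlugin": {"scale": 65536.0}}    # this commit onwards

    def find_precision_state(ckpt: dict, plugin_qualname: str, legacy_key: str) -> dict:
        # Roughly what restore_precision_plugin_state() does: in the real code both
        # keys are checked independently, but a checkpoint only ever contains one.
        return ckpt.get(plugin_qualname) or ckpt.get(legacy_key) or {}

    assert find_precision_state(new_ckpt, "NativeMixedPrecisionPlugin", "native_amp_scaling_state")
    assert find_precision_state(legacy_ckpt, "NativeMixedPrecisionPlugin", "native_amp_scaling_state")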

tests/models/test_hooks.py

Lines changed: 2 additions & 4 deletions
@@ -493,10 +493,8 @@ def training_step(self, batch, batch_idx):
        "state_dict": ANY,
        "loops": ANY,
    }
-   if kwargs.get("amp_backend") == "native":
-       saved_ckpt["native_amp_scaling_state"] = ANY
-   elif kwargs.get("amp_backend") == "apex":
-       saved_ckpt["amp_scaling_state"] = ANY
+   if kwargs.get("amp_backend") == "native" or kwargs.get("amp_backend") == "apex":
+       saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
    device = torch.device("cuda:0" if "gpus" in kwargs else "cpu")
    expected = [
        dict(name="Callback.on_init_start", args=(trainer,)),
