 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -196,7 +197,7 @@ def restore_training_state(self) -> None:
             return
 
         # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        self.restore_precision_plugin_state()
 
         # restore loops and their progress
         self.restore_loops()
@@ -206,6 +207,21 @@ def restore_training_state(self) -> None:
         # restore optimizers and schedulers state
         self.restore_optimizers_and_schedulers()
 
+    def restore_precision_plugin_state(self) -> None:
+        """Restore the precision plugin state from the pre-loaded checkpoint."""
+        prec_plugin = self.trainer.precision_plugin
+        prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint:
+            prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__])
+
+        # old checkpoints compatibility
+        if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"])
+        if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(
+            prec_plugin, NativeMixedPrecisionPlugin
+        ):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
+
     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
         if not self._loaded_checkpoint:
@@ -318,9 +334,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 'callbacks': "callback specific state"[] # if not weights_only
                 'optimizer_states': "PT optim's state_dict"[] # if not weights_only
                 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
-                'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp
-                'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp
                 'state_dict': Model's state_dict (e.g. network weights)
+                precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only
                 CHECKPOINT_HYPER_PARAMS_NAME:
                 CHECKPOINT_HYPER_PARAMS_KEY:
                 CHECKPOINT_HYPER_PARAMS_TYPE:
@@ -357,7 +372,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 lr_schedulers.append(config.scheduler.state_dict())
             checkpoint["lr_schedulers"] = lr_schedulers
 
-        self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
+        # precision plugin
+        prec_plugin = self.trainer.precision_plugin
+        prec_plugin_state_dict = prec_plugin.state_dict()
+        if prec_plugin_state_dict:
+            checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
+        prec_plugin.on_save_checkpoint(checkpoint)
 
         # dump hyper-parameters
         if model.hparams:
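For context on the resulting checkpoint layout (this example is not part of the diff): after this change, the precision plugin state is stored under the plugin's qualified class name, while restore_precision_plugin_state() still accepts the legacy keys from older checkpoints. Below is a minimal sketch, assuming training ran with native AMP (e.g. Trainer(gpus=1, precision=16)) so the active plugin is NativeMixedPrecisionPlugin and its state_dict() is the non-empty GradScaler state; the checkpoint path is hypothetical.

import torch

# Hypothetical path to a checkpoint saved after this change with native AMP enabled.
ckpt = torch.load("lightning_logs/version_0/checkpoints/last.ckpt", map_location="cpu")

# New layout: precision state is keyed by the plugin's qualified class name.
scaler_state = ckpt.get("NativeMixedPrecisionPlugin")  # GradScaler state_dict (contains "scale", etc.)

# Old checkpoints keep working: restore_precision_plugin_state() falls back to the
# legacy keys ("native_amp_scaling_state" / "amp_scaling_state") when they are present.
legacy_scaler_state = ckpt.get("native_amp_scaling_state")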