Commit aa82615

Fix amp restore place
1 parent 203060b


pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 6 additions & 6 deletions
@@ -127,12 +127,6 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         # restore the state_dict on the model
         model.load_state_dict(checkpoint['state_dict'])
 
-        # restore amp scaling
-        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
-            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
-        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
-            amp.load_state_dict(checkpoint['amp_scaling_state'])
-
     def restore_training_state(self, checkpoint):
         """
         Restore trainer state.
@@ -155,6 +149,12 @@ def restore_training_state(self, checkpoint):
             " where `model.ckpt` is your checkpoint file."
         )
 
+        # restore amp scaling
+        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
+            self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
+
         # restore callback states
         self.trainer.on_load_checkpoint(checkpoint)
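The change itself is small: the AMP-scaler restore block is moved verbatim from `restore_model_state` into `restore_training_state`, so scaler state is reloaded together with the rest of the trainer state rather than with the bare model weights. Below is a minimal, self-contained sketch of the round-trip that block performs for the native backend, using `torch.cuda.amp.GradScaler`; the `checkpoint` dict and the `'native_amp_scaling_state'` key mirror the diff above, while everything else here is illustrative only, not Lightning's API.

# Minimal sketch (not Lightning's API): round-trip of native-AMP scaler
# state through a checkpoint dict, mirroring the moved block above.
import torch

scaler = torch.cuda.amp.GradScaler()

# On save: store the scaler state next to the model weights.
checkpoint = {
    'state_dict': {},  # model.state_dict() in practice
    'native_amp_scaling_state': scaler.state_dict(),
}

# On restore: guard on the key, exactly as the diff does, so checkpoints
# written without AMP enabled still load cleanly.
if 'native_amp_scaling_state' in checkpoint:
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

Restoring the scaler in `restore_training_state` also means it is only touched when a Trainer is actually resuming a run, which is presumably the point of the move; the commit message itself says only "Fix amp restore place".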