Commit d20b4f4

migrate checkpoint on load
wrap
1 parent 94c5299 commit d20b4f4

File tree

3 files changed: 15 additions & 5 deletions

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 6 additions & 1 deletion
@@ -26,6 +26,8 @@
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.migration.base import pl_legacy_patch
+from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

 TBroadcast = TypeVar("T")

@@ -216,7 +218,10 @@ def restore_model_state_from_ckpt_path(
             bool: Wether to load optimizer / lr_schedulers states from checkpoint

         """
-        ckpt = pl_load(ckpt_path, map_location=map_location)
+        with pl_legacy_patch():
+            ckpt = pl_load(ckpt_path, map_location=map_location)
+        ckpt = migrate_checkpoint(ckpt)
+
         # restore datamodule states
         if self.lightning_module.trainer.datamodule is not None:
             self.lightning_module.trainer.datamodule.on_load_checkpoint(ckpt)
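
This hunk sets up the pattern the commit applies at every load site: open the checkpoint under pl_legacy_patch so that pickled references to since-removed symbols still resolve, then pass the loaded dict through migrate_checkpoint. A minimal standalone sketch of that pattern follows; the checkpoint path is a placeholder, and treating migrate_checkpoint's return value as the migrated dict is simply read off its usage in this hunk.

    import torch

    from pytorch_lightning.utilities.migration.base import pl_legacy_patch
    from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

    # Placeholder: a checkpoint written by an older PyTorch Lightning release.
    CKPT_PATH = "old_release.ckpt"

    with pl_legacy_patch():
        # While the patch is active, legacy attributes are re-exposed so that
        # unpickling objects that still reference them does not fail.
        checkpoint = torch.load(CKPT_PATH, map_location="cpu")

    # Rewrite the loaded dict to the current checkpoint format before any state
    # is restored from it.
    checkpoint = migrate_checkpoint(checkpoint)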

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 7 additions & 3 deletions
@@ -32,6 +32,8 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.migration.base import pl_legacy_patch
+from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

 if _APEX_AVAILABLE:
     from apex import amp
@@ -89,9 +91,11 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
             rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
             return False

-        checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
-            checkpoint_path, map_location=lambda storage, loc: storage
-        )
+        with pl_legacy_patch():
+            checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
+                checkpoint_path, map_location=lambda storage, loc: storage
+            )
+        migrate_checkpoint(checkpoint)

         model = self.trainer.lightning_module
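
For context, the connector change above is what runs when training resumes from a file, so an old checkpoint now gets patched and migrated before any state is restored. The sketch below shows one way to exercise that path; the module, data, and "legacy.ckpt" path are placeholders and not part of this commit, and resume_from_checkpoint is the Trainer argument this release line uses for resuming.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class TinyModule(pl.LightningModule):
        """Throwaway module, only here to make the example runnable."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    train_data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)

    # "legacy.ckpt" stands in for a checkpoint produced by an older release.
    # Restoring it goes through CheckpointConnector.restore and therefore through
    # pl_legacy_patch and migrate_checkpoint as added above.
    trainer = pl.Trainer(max_epochs=1, resume_from_checkpoint="legacy.ckpt")
    trainer.fit(TinyModule(), train_data)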

pytorch_lightning/utilities/migration/base.py

Lines changed: 2 additions & 1 deletion
@@ -26,4 +26,5 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_value, exc_traceback):
-        delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
+        if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
+            delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
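
With the connector now wrapping a call that itself enters pl_legacy_patch inside restore_model_state_from_ckpt_path, the patch can be nested, and the inner __exit__ may already have deleted the attribute by the time the outer one runs; the hasattr guard keeps that second teardown from raising AttributeError. For orientation, here is a sketch of a context manager in this shape. Only the guarded __exit__ body comes from this diff; the __enter__ side (recreating the legacy _gpus_arg_default symbol) is an illustrative assumption, not the actual contents of base.py.

    import pytorch_lightning.utilities.argparse


    class pl_legacy_patch:
        """Temporarily expose legacy attributes that old checkpoints reference during unpickling."""

        def __enter__(self):
            # Assumed for illustration: recreate the removed helper so pickled
            # references to it resolve while the checkpoint is being loaded.
            setattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default", lambda x: x)
            return self

        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Teardown guarded as in this commit: with nested use of the patch,
            # the attribute may already be gone, so check before deleting.
            if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
                delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")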
