
Commit 277b0b8

migration
1 parent 3da62ff commit 277b0b8

File tree

5 files changed: +169, -72 lines

src/pytorch_lightning/core/saving.py
src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
src/pytorch_lightning/utilities/migration.py
src/pytorch_lightning/utilities/upgrade_checkpoint.py
tests/tests_pytorch/utilities/test_upgrade_checkpoint.py


src/pytorch_lightning/core/saving.py

Lines changed: 6 additions & 1 deletion
@@ -31,7 +31,7 @@
 from lightning_lite.utilities.cloud_io import load as pl_load
 from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
-from pytorch_lightning.utilities.migration import pl_legacy_patch
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
 from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn

@@ -156,6 +156,9 @@ def _load_from_checkpoint(
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)

+    # convert legacy checkpoints to the new format
+    checkpoint = migrate_checkpoint(checkpoint)
+
     if hparams_file is not None:
         extension = str(hparams_file).split(".")[-1]
         if extension.lower() == "csv":

@@ -168,6 +171,7 @@ def _load_from_checkpoint(
         # overwrite hparams by the given file
         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

+    # TODO: make this a migration:
     # for past checkpoint need to add the new key
     checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
     # override the hparams with values that were passed in

@@ -197,6 +201,7 @@ def _load_state(
     if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

         if issubclass(cls, pl.LightningModule):
+            # TODO: make this a migration:
             # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
             for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                 cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
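In effect, `_load_from_checkpoint` now loads the file under the legacy patch and then normalizes the checkpoint dictionary in memory. A minimal sketch of the same pattern used outside the Trainer, assuming the `migrate_checkpoint` import shown above and a hypothetical `legacy.ckpt` file:

import torch

from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

with pl_legacy_patch():
    # keep removed legacy symbols importable while unpickling the old file
    checkpoint = torch.load("legacy.ckpt", map_location="cpu")

# rewrite deprecated keys in place and stamp the current Lightning version
checkpoint = migrate_checkpoint(checkpoint)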

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 26 deletions
@@ -32,9 +32,8 @@
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.migration import pl_legacy_patch
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container

@@ -86,13 +85,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
     def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         with pl_legacy_patch():
             loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
-        if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
-            raise ValueError(
-                "The checkpoint you're attempting to load follows an"
-                " outdated schema. You can upgrade to the current schema by running"
-                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
-                " where `model.ckpt` is your checkpoint file."
-            )
+        loaded_checkpoint = migrate_checkpoint(loaded_checkpoint)
         return loaded_checkpoint

     def _set_ckpt_path(

@@ -348,23 +341,6 @@ def restore_loops(self) -> None:
             return

         fit_loop = self.trainer.fit_loop
-        pl_module = self.trainer.lightning_module
-        assert pl_module is not None
-
-        # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        batch_loop = fit_loop.epoch_loop.batch_loop
-        if pl_module.automatic_optimization:
-            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
-                "global_step"
-            ]
-        else:
-            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
-
-        # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
-
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
         if state_dict is not None:
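With the connector no longer raising on outdated schemas or patching `global_step` and `epoch` by hand, those values are expected to arrive inside the checkpoint's `loops` state after migration. A rough illustration (not part of the commit) of what `restore_loops` now sees, assuming a hypothetical pre-1.6 checkpoint that stored only the bare counters:

from pytorch_lightning.utilities.migration import migrate_checkpoint

legacy = {"pytorch-lightning_version": "1.5.0", "global_step": 1000, "epoch": 3}
migrated = migrate_checkpoint(legacy)

fit_loop = migrated["loops"]["fit_loop"]
optim_progress = fit_loop["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
# the old `global_step` now lives in the progress tracking state ...
assert optim_progress["optimizer"]["step"]["total"]["completed"] == 1000
# ... and so does the old `epoch`, so no special-casing is needed in `restore_loops`
assert fit_loop["epoch_progress"]["current"]["completed"] == 3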

src/pytorch_lightning/utilities/migration.py

Lines changed: 150 additions & 9 deletions
@@ -11,24 +11,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
+"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
+
+When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
+see :func:`migrate_checkpoint`.
+"""

 import sys
-import threading
+from distutils.version import LooseVersion
 from types import ModuleType, TracebackType
+from typing import Any, Dict, Optional, Type

+import pytorch_lightning as pl
 import pytorch_lightning.utilities.argparse

-# Create a global lock to ensure no race condition with deleting sys modules
-_lock = threading.Lock()
+_CHECKPOINT = Dict[str, Any]


 class pl_legacy_patch:
     """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
     unpickling old checkpoints. The following patches apply.

     1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
-       version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898
+       version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
     2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
        but still needs to be available for import for legacy checkpoints.

@@ -38,20 +43,156 @@ class pl_legacy_patch:
         torch.load("path/to/legacy/checkpoint.ckpt")
     """

-    def __enter__(self) -> None:
-        _lock.acquire()
+    def __enter__(self) -> "pl_legacy_patch":
         # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
         legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
         sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module

         # `_gpus_arg_default` used to be imported from these locations
         legacy_argparse_module._gpus_arg_default = lambda x: x
         pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
+        return self

     def __exit__(
-        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
     ) -> None:
         if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
             delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
         del sys.modules["pytorch_lightning.utilities.argparse_utils"]
-        _lock.release()
+
+
+def get_version(checkpoint: _CHECKPOINT) -> str:
+    """Get the version of a Lightning checkpoint."""
+    return checkpoint["pytorch-lightning_version"]
+
+
+def set_version(checkpoint: _CHECKPOINT, version: str) -> None:
+    """Set the version of a Lightning checkpoint."""
+    checkpoint["pytorch-lightning_version"] = version
+
+
+def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
+    """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
+    return LooseVersion(get_version(checkpoint)) < LooseVersion(target)
+
+
+def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Applies all migrations below in order."""
+    if should_upgrade(checkpoint, "0.10.0"):
+        _migrate_model_checkpoint_early_stopping(checkpoint)
+    if should_upgrade(checkpoint, "1.6.0"):
+        _migrate_loop_global_step_to_progress_tracking(checkpoint)
+        _migrate_loop_current_epoch_to_progress_tracking(checkpoint)
+
+    set_version(checkpoint, pl.__version__)
+
+    # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert
+    # checkpoints permanently
+    return checkpoint
+
+
+def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """The checkpoint and early stopping keys were renamed.
+
+    Version: 0.10.0
+    Commit:
+    """
+    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+    keys_mapping = {
+        "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
+        "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
+        "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
+        "early_stop_callback_wait": (EarlyStopping, "wait_count"),
+        "early_stop_callback_patience": (EarlyStopping, "patience"),
+    }
+    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
+
+    for key, new_path in keys_mapping.items():
+        if key in checkpoint:
+            value = checkpoint[key]
+            callback_type, callback_key = new_path
+            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
+            checkpoint["callbacks"][callback_type][callback_key] = value
+            del checkpoint[key]
+    return checkpoint
+
+
+def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
+    It will be overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit:
+    """
+    global_step = checkpoint["global_step"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    # for automatic optimization
+    optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
+    optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
+    # for manual optimization
+    optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
+    optim_step_progress["total"]["completed"] = global_step
+    return checkpoint
+
+
+def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
+    It will be overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit:
+    """
+    epoch = checkpoint["epoch"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
+
+
+_FIT_LOOP_INITIAL_STATE_1_6_0 = {
+    "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.batch_loop.manual_loop.state_dict": {},
+    "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+        "optimizer": {
+            "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+            "zero_grad": {
+                "current": {"completed": 0, "ready": 0, "started": 0},
+                "total": {"completed": 0, "ready": 0, "started": 0},
+            },
+        },
+        "optimizer_position": 0,
+    },
+    "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+    "epoch_loop.batch_loop.state_dict": {},
+    "epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+    "epoch_loop.state_dict": {"_batches_that_stepped": 0},
+    "epoch_loop.val_loop.dataloader_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.state_dict": {},
+    "epoch_loop.val_loop.state_dict": {},
+    "epoch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "state_dict": {},
+}
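As a quick sanity check of the 0.10.0 migration defined above, a small self-contained example (the values are made up for illustration) of the key renaming it performs. The version and counter keys are included because `migrate_checkpoint` also runs the 1.6.0 loop migrations on anything older:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.migration import migrate_checkpoint

old = {
    "pytorch-lightning_version": "0.9.0",
    "global_step": 10,
    "epoch": 1,
    "checkpoint_callback_best_model_path": "path/to/best",  # legacy flat key
    "early_stop_callback_patience": 4,  # legacy flat key
}
new = migrate_checkpoint(old)

# legacy flat keys are moved under `callbacks`, keyed by the callback class
assert new["callbacks"][ModelCheckpoint]["best_model_path"] == "path/to/best"
assert new["callbacks"][EarlyStopping]["patience"] == 4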

src/pytorch_lightning/utilities/upgrade_checkpoint.py

Lines changed: 5 additions & 28 deletions
@@ -17,36 +17,11 @@

 import torch

-from lightning_lite.utilities.types import _PATH
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.migration import pl_legacy_patch
-
-KEYS_MAPPING = {
-    "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
-    "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
-    "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
-    "early_stop_callback_wait": (EarlyStopping, "wait_count"),
-    "early_stop_callback_patience": (EarlyStopping, "patience"),
-}
+from pytorch_lightning.utilities.migration.base import pl_legacy_patch
+from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

 log = logging.getLogger(__name__)

-
-def upgrade_checkpoint(filepath: _PATH) -> None:
-    checkpoint = torch.load(filepath)
-    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
-
-    for key, new_path in KEYS_MAPPING.items():
-        if key in checkpoint:
-            value = checkpoint[key]
-            callback_type, callback_key = new_path
-            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
-            checkpoint["callbacks"][callback_type][callback_key] = value
-            del checkpoint[key]
-
-    torch.save(checkpoint, filepath)
-
-
 if __name__ == "__main__":

@@ -61,4 +36,6 @@ def upgrade_checkpoint(filepath: _PATH) -> None:
     log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
     copyfile(args.file, args.file + ".bak")
     with pl_legacy_patch():
-        upgrade_checkpoint(args.file)
+        checkpoint = torch.load(args.file)
+        migrate_checkpoint(checkpoint)
+        torch.save(checkpoint, args.file)
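The script keeps its command-line entry point: as the error message removed from the checkpoint connector suggested, a checkpoint file can still be upgraded permanently with `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`. The `__main__` block writes a `model.ckpt.bak` backup before overwriting the file with the migrated state.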

tests/tests_pytorch/utilities/test_upgrade_checkpoint.py

Lines changed: 6 additions & 8 deletions
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-
 import pytest
-import torch

+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
+from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version


 @pytest.mark.parametrize(

@@ -42,8 +40,8 @@
     ],
 )
 def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
-    filepath = os.path.join(tmpdir, "model.ckpt")
-    torch.save(old_checkpoint, filepath)
-    upgrade_checkpoint(filepath)
-    updated_checkpoint = torch.load(filepath)
+    set_version(old_checkpoint, "0.9.0")
+    set_version(new_checkpoint, pl.__version__)
+    updated_checkpoint = migrate_checkpoint(old_checkpoint)
     assert updated_checkpoint == new_checkpoint
+    assert get_version(updated_checkpoint) == pl.__version__
