Skip to content

Commit 94f7d23

Browse files
awaelchlicarmocca
andauthored
Introduce checkpoint migration (#15237)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 6aa6423 commit 94f7d23

File tree

15 files changed

+399
-150
lines changed

15 files changed

+399
-150
lines changed

src/lightning_lite/utilities/device_dtype_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Optional, Union
15+
from typing import Any, List, Optional, Union
1616

1717
import torch
1818
from torch.nn import Module
1919
from typing_extensions import Self
2020

2121

2222
class _DeviceDtypeModuleMixin(Module):
23-
__jit_unused_properties__ = ["device", "dtype"]
23+
__jit_unused_properties__: List[str] = ["device", "dtype"]
2424

2525
def __init__(self) -> None:
2626
super().__init__()

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
### Added
1212

13-
-
13+
- Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
1414

1515
-
1616

@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919

2020
### Changed
2121

22-
-
22+
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
2323

2424
-
2525

@@ -57,7 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5757

5858
## [1.8.0] - 2022-11-01
5959

60-
6160
### Added
6261

6362
- Added support for requeueing slurm array jobs ([#15040](https://github.com/Lightning-AI/lightning/pull/15040))

src/pytorch_lightning/core/mixins/hparams_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
import inspect
1616
import types
1717
from argparse import Namespace
18-
from typing import Any, MutableMapping, Optional, Sequence, Union
18+
from typing import Any, List, MutableMapping, Optional, Sequence, Union
1919

2020
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
2121
from pytorch_lightning.utilities.parsing import AttributeDict, save_hyperparameters
2222

2323

2424
class HyperparametersMixin:
2525

26-
__jit_unused_properties__ = ["hparams", "hparams_initial"]
26+
__jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]
2727

2828
def __init__(self) -> None:
2929
super().__init__()

src/pytorch_lightning/core/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class LightningModule(
7575
):
7676
# Below is for property support of JIT
7777
# since none of these are important when using JIT, we are going to ignore them.
78-
__jit_unused_properties__ = (
78+
__jit_unused_properties__: List[str] = (
7979
[
8080
"example_input_array",
8181
"on_gpu",

src/pytorch_lightning/core/saving.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from argparse import Namespace
2121
from copy import deepcopy
2222
from enum import Enum
23+
from pathlib import Path
2324
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
2425
from warnings import warn
2526

@@ -32,6 +33,7 @@
3233
from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
3334
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
3435
from pytorch_lightning.utilities.migration import pl_legacy_patch
36+
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
3537
from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
3638
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
3739

@@ -156,6 +158,11 @@ def _load_from_checkpoint(
156158
with pl_legacy_patch():
157159
checkpoint = pl_load(checkpoint_path, map_location=map_location)
158160

161+
# convert legacy checkpoints to the new format
162+
checkpoint = _pl_migrate_checkpoint(
163+
checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
164+
)
165+
159166
if hparams_file is not None:
160167
extension = str(hparams_file).split(".")[-1]
161168
if extension.lower() == "csv":
@@ -168,6 +175,7 @@ def _load_from_checkpoint(
168175
# overwrite hparams by the given file
169176
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
170177

178+
# TODO: make this a migration:
171179
# for past checkpoint need to add the new key
172180
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
173181
# override the hparams with values that were passed in
@@ -198,6 +206,7 @@ def _load_state(
198206
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
199207

200208
if issubclass(cls, pl.LightningModule):
209+
# TODO: make this a migration:
201210
# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
202211
for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
203212
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3434
from pytorch_lightning.utilities.imports import _fault_tolerant_training
3535
from pytorch_lightning.utilities.migration import pl_legacy_patch
36+
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
3637
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
37-
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
3838

3939
if _OMEGACONF_AVAILABLE:
4040
from omegaconf import Container
@@ -81,19 +81,9 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
8181
return
8282

8383
rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
84-
self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
85-
86-
def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
8784
with pl_legacy_patch():
8885
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
89-
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
90-
raise ValueError(
91-
"The checkpoint you're attempting to load follows an"
92-
" outdated schema. You can upgrade to the current schema by running"
93-
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
94-
" where `model.ckpt` is your checkpoint file."
95-
)
96-
return loaded_checkpoint
86+
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
9787

9888
def _set_ckpt_path(
9989
self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool

src/pytorch_lightning/utilities/migration.py

Lines changed: 0 additions & 57 deletions
This file was deleted.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401
16+
from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
15+
16+
When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
17+
see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`.
18+
19+
For the Lightning developer: How to add a new migration?
20+
21+
1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include
22+
version information as well as the specific commit or PR where the breaking change happened.
23+
2. Add the function to the `_migration_index()` below. The key in the index is the version of Lightning in which the
24+
change happened. Any checkpoint with a version greater or equal to that version will apply the given function.
25+
Multiple migrations per version get executed in the provided list order.
26+
3. You can test the migration on a checkpoint (backup your files first) by running:
27+
28+
cp model.ckpt model.ckpt.backup
29+
python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt
30+
"""
31+
32+
from typing import Any, Callable, Dict, List
33+
34+
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
35+
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
36+
37+
_CHECKPOINT = Dict[str, Any]
38+
39+
40+
def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
41+
"""Migration functions returned here will get executed in the order they are listed."""
42+
return {
43+
"0.10.0": [_migrate_model_checkpoint_early_stopping],
44+
}
45+
46+
47+
def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
48+
"""The checkpoint and early stopping keys were renamed.
49+
50+
Version: 0.10.0
51+
Commit: a5d1176
52+
"""
53+
keys_mapping = {
54+
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
55+
"checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
56+
"checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
57+
"early_stop_callback_wait": (EarlyStopping, "wait_count"),
58+
"early_stop_callback_patience": (EarlyStopping, "patience"),
59+
}
60+
checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
61+
62+
for key, new_path in keys_mapping.items():
63+
if key in checkpoint:
64+
value = checkpoint[key]
65+
callback_type, callback_key = new_path
66+
checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
67+
checkpoint["callbacks"][callback_type][callback_key] = value
68+
del checkpoint[key]
69+
return checkpoint

0 commit comments

Comments
 (0)