
Commit 6afe6c3

suboptimal
1 parent 24a2cc8 commit 6afe6c3

File tree

9 files changed: +103 -53 lines changed

pytorch_lightning/core/saving.py

Lines changed: 6 additions & 0 deletions

@@ -30,6 +30,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint
 from pytorch_lightning.utilities.parsing import parse_class_init_keys

 log = logging.getLogger(__name__)
@@ -134,6 +135,9 @@ def load_from_checkpoint(
         else:
             checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

+        # convert legacy checkpoints to the new format
+        checkpoint = migrate_checkpoint(checkpoint)
+
         if hparams_file is not None:
             extension = hparams_file.split('.')[-1]
             if extension.lower() == 'csv':
@@ -148,6 +152,7 @@
             # overwrite hparams by the given file
             checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

+        # TODO: make this a migration:
         # for past checkpoint need to add the new key
         if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
             checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
@@ -171,6 +176,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
         if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

             # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
+            # TODO: make this a migration:
             for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                 cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
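With `migrate_checkpoint` wired into `load_from_checkpoint`, checkpoints written by older releases are converted to the current schema right after they are read from disk, before hyperparameters and state are restored. A minimal sketch of the intended effect, where the legacy file name `old.ckpt` is an assumption for illustration, not part of this commit:

    import torch
    from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

    # "old.ckpt" is an assumed file saved by an earlier Lightning release;
    # load_from_checkpoint now performs this conversion internally after pl_load.
    checkpoint = torch.load("old.ckpt", map_location="cpu")
    checkpoint = migrate_checkpoint(checkpoint)  # legacy keys rewritten to the new schema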

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 0 additions & 9 deletions

@@ -32,7 +32,6 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if _APEX_AVAILABLE:
     from apex import amp
@@ -134,14 +133,6 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
                 ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
             )

-        if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
-            raise ValueError(
-                "The checkpoint you're attempting to load follows an"
-                " outdated schema. You can upgrade to the current schema by running"
-                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
-                " where `model.ckpt` is your checkpoint file."
-            )
-
         # restore amp scaling
         if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
             self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

pytorch_lightning/utilities/argparse.py

Lines changed: 5 additions & 5 deletions

@@ -286,11 +286,11 @@ def _gpus_allowed_type(x) -> Union[int, str]:
     return int(x)


-def _gpus_arg_default(x) -> Union[int, str]:  # pragma: no-cover
-    # unused, but here for backward compatibility with old checkpoints that need to be able to
-    # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
-    # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
-    pass
+# def _gpus_arg_default(x) -> Union[int, str]:  # pragma: no-cover
+#     # unused, but here for backward compatibility with old checkpoints that need to be able to
+#     # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
+#     # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
+#     pass


 def _int_or_float_type(x) -> Union[int, float]:

pytorch_lightning/utilities/migration/__init__.py

Whitespace-only changes.

pytorch_lightning/utilities/migration/base.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+from distutils.version import LooseVersion
+
+import pytorch_lightning.utilities.argparse
+
+
+def get_version(checkpoint: dict) -> str:
+    return checkpoint["pytorch-lightning_version"]
+
+
+def set_version(checkpoint: dict, version: str):
+    checkpoint["pytorch-lightning_version"] = version
+
+
+def should_upgrade(checkpoint: dict, target: str) -> bool:
+    return LooseVersion(get_version(checkpoint)) < LooseVersion(target)
+
+
+class pl_legacy_patch:
+    """
+    Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be
+    included for unpickling old checkpoints.
+    """
+
+    def __enter__(self):
+        setattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default", lambda x: x)
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
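`pl_legacy_patch` exists because old checkpoints may have pickled references to symbols that no longer exist, such as `_gpus_arg_default`, which is commented out in `pytorch_lightning/utilities/argparse.py` above; the context manager registers a stand-in only for the duration of the load and removes it again on exit. A short usage sketch, where `old.ckpt` is an assumed legacy file:

    import torch
    from pytorch_lightning.utilities.migration.base import pl_legacy_patch, should_upgrade

    with pl_legacy_patch():
        # the stand-in for _gpus_arg_default exists only inside this block
        checkpoint = torch.load("old.ckpt", map_location="cpu")  # assumed path

    if should_upgrade(checkpoint, "1.3.0"):
        print("checkpoint was written before 1.3.0 and would be migrated")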
pytorch_lightning/utilities/migration/migrations.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.migration.base import set_version, should_upgrade
+
+
+# v0.10.0
+def migrate_model_checkpoint_early_stopping(checkpoint: dict) -> dict:
+    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+    keys_mapping = {
+        "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
+        "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
+        "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
+        "early_stop_callback_wait": (EarlyStopping, "wait_count"),
+        "early_stop_callback_patience": (EarlyStopping, "patience"),
+    }
+    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
+
+    for key, new_path in keys_mapping.items():
+        if key in checkpoint:
+            value = checkpoint[key]
+            callback_type, callback_key = new_path
+            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
+            checkpoint["callbacks"][callback_type][callback_key] = value
+            del checkpoint[key]
+    return checkpoint
+
+
+# v1.3.1
+def migrate_callback_state_identifiers(checkpoint):
+    if "callbacks" not in checkpoint:
+        return
+    callbacks = checkpoint["callbacks"]
+    checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items())
+    return checkpoint
+
+
+def migrate_checkpoint(checkpoint: dict):
+    """ Applies all the above migrations in order. """
+    if should_upgrade(checkpoint, "0.10.0"):
+        migrate_model_checkpoint_early_stopping(checkpoint)
+    if should_upgrade(checkpoint, "1.3.0"):
+        migrate_callback_state_identifiers(checkpoint)
+        set_version(checkpoint, "1.3.0")
+    set_version(checkpoint, pl.__version__)
+    return checkpoint
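Taken together, `migrate_checkpoint` first lifts the flat pre-0.10.0 callback keys into the nested `callbacks` dict and then replaces callback class objects with their class names as state identifiers. A small illustration with a made-up legacy dict (values are for demonstration only), mirroring the test cases further below:

    import pytorch_lightning as pl
    from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

    legacy = {
        "pytorch-lightning_version": "0.9.0",          # version gates which migrations run
        "epoch": 1,
        "global_step": 23,
        "checkpoint_callback_best_model_score": 0.34,  # old flat key
    }
    migrated = migrate_checkpoint(legacy)
    assert migrated["callbacks"] == {"ModelCheckpoint": {"best_model_score": 0.34}}
    assert migrated["pytorch-lightning_version"] == pl.__version__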

pytorch_lightning/utilities/types.py

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@

 import torch
 from torchmetrics import Metric
+
 """
 Convention:
  - Do not include any `_TYPE` suffix

pytorch_lightning/utilities/upgrade_checkpoint.py

Lines changed: 6 additions & 26 deletions

@@ -17,34 +17,11 @@

 import torch

-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-
-KEYS_MAPPING = {
-    "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
-    "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
-    "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
-    "early_stop_callback_wait": (EarlyStopping, "wait_count"),
-    "early_stop_callback_patience": (EarlyStopping, "patience"),
-}
+from pytorch_lightning.utilities.migration.base import pl_legacy_patch
+from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

 log = logging.getLogger(__name__)

-
-def upgrade_checkpoint(filepath):
-    checkpoint = torch.load(filepath)
-    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
-
-    for key, new_path in KEYS_MAPPING.items():
-        if key in checkpoint:
-            value = checkpoint[key]
-            callback_type, callback_key = new_path
-            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
-            checkpoint["callbacks"][callback_type][callback_key] = value
-            del checkpoint[key]
-
-    torch.save(checkpoint, filepath)
-
-
 if __name__ == "__main__":

     parser = argparse.ArgumentParser(
@@ -57,4 +34,7 @@ def upgrade_checkpoint(filepath):

     log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
     copyfile(args.file, args.file + ".bak")
-    upgrade_checkpoint(args.file)
+    with pl_legacy_patch():
+        checkpoint = torch.load(args.file)
+    migrate_checkpoint(checkpoint)
+    torch.save(checkpoint, args.file)
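For checkpoints already on disk, the command-line entry point now reuses the shared migration code: it backs the file up to `.bak`, loads it under `pl_legacy_patch`, runs `migrate_checkpoint`, and saves the result in place. The invocation is the one the removed `ValueError` in `checkpoint_connector.py` used to point at, with the file name being illustrative:

    python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt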

tests/utilities/test_upgrade_checkpoint.py

Lines changed: 11 additions & 13 deletions

@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-
 import pytest
-import torch
+import pytorch_lightning as pl

-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
+from pytorch_lightning.utilities.migration.base import set_version, get_version
+from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint


 @pytest.mark.parametrize(
@@ -33,7 +31,7 @@
                 "epoch": 1,
                 "global_step": 23,
                 "callbacks": {
-                    ModelCheckpoint: {
+                    "ModelCheckpoint": {
                         "best_model_score": 0.34
                     }
                 }
@@ -49,7 +47,7 @@
                 "epoch": 1,
                 "global_step": 23,
                 "callbacks": {
-                    ModelCheckpoint: {
+                    "ModelCheckpoint": {
                         "best_model_score": 0.99
                     }
                 }
@@ -65,7 +63,7 @@
                 "epoch": 1,
                 "global_step": 23,
                 "callbacks": {
-                    ModelCheckpoint: {
+                    "ModelCheckpoint": {
                         "best_model_path": 'path'
                     }
                 }
@@ -82,7 +80,7 @@
                 "epoch": 1,
                 "global_step": 23,
                 "callbacks": {
-                    EarlyStopping: {
+                    "EarlyStopping": {
                         "wait_count": 2,
                         "patience": 4
                     }
@@ -92,8 +90,8 @@
     ],
 )
 def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
-    filepath = os.path.join(tmpdir, "model.ckpt")
-    torch.save(old_checkpoint, filepath)
-    upgrade_checkpoint(filepath)
-    updated_checkpoint = torch.load(filepath)
+    set_version(old_checkpoint, "0.9.0")
+    set_version(new_checkpoint, pl.__version__)
+    updated_checkpoint = migrate_checkpoint(old_checkpoint)
     assert updated_checkpoint == new_checkpoint
+    assert get_version(updated_checkpoint) == pl.__version__
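A quick way to exercise the new migration path locally would be to run this test module on its own, assuming pytest is installed and the command is issued from the repository root:

    pytest tests/utilities/test_upgrade_checkpoint.py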
