From 89131c2ac262a4ed21be35b8654e53a205f3ef68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Apr 2021 03:06:51 +0200 Subject: [PATCH 01/45] class name as key --- pytorch_lightning/trainer/callback_hook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 606f6b2e4b52b..62bfa00ddcf1a 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -243,7 +243,7 @@ def __is_old_signature(fn: Callable) -> bool: return True return False - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: """Called when saving a model checkpoint.""" callback_states = {} for callback in self.callbacks: @@ -257,10 +257,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: else: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) if state: - callback_states[type(callback)] = state + callback_states[callback.__class__.__name__] = state return callback_states - def on_load_checkpoint(self, checkpoint): + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint.""" callback_states = checkpoint.get('callbacks') # Todo: the `callback_states` are dropped with TPUSpawn as they @@ -268,7 +268,7 @@ def on_load_checkpoint(self, checkpoint): # https://github.com/pytorch/xla/issues/2773 if callback_states is not None: for callback in self.callbacks: - state = callback_states.get(type(callback)) + state = callback_states.get(callback.__class__.__name__) if state: state = deepcopy(state) callback.on_load_checkpoint(state) From 63fb9830fbc7b4f81807035c03d1d8343df07bfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Apr 2021 03:37:22 +0200 Subject: [PATCH 02/45] string state identifier --- pytorch_lightning/callbacks/base.py | 4 ++++ pytorch_lightning/trainer/callback_hook.py | 5 +++-- .../trainer/connectors/checkpoint_connector.py | 2 +- tests/callbacks/test_early_stopping.py | 2 +- tests/trainer/connectors/test_callback_connector.py | 6 +++--- tests/trainer/logging_/test_logger_connector.py | 2 +- 6 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 768e4ebca30ee..a214904b1c31e 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -29,6 +29,10 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ + @property + def state_identifier(self) -> str: + return self.__class__.__name__ + def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None: """Called before configure sharded model""" diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 62bfa00ddcf1a..dfc4922e14277 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -257,18 +257,19 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: else: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) if state: - callback_states[callback.__class__.__name__] = state + callback_states[callback.state_identifier] = state return callback_states def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when 
loading a model checkpoint.""" callback_states = checkpoint.get('callbacks') + version = checkpoint.get('pytorch-lightning_version') # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 if callback_states is not None: for callback in self.callbacks: - state = callback_states.get(callback.__class__.__name__) + state = callback_states.get(callback.state_identifier) if state: state = deepcopy(state) callback.on_load_checkpoint(state) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4ae42e4bad6ac..887bd4064dc6c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -244,7 +244,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: structured dictionary: { 'epoch': training epoch 'global_step': training global step - 'pytorch-lightning_version': PyTorch Lightning's version + 'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint 'callbacks': "callback specific state"[] # if not weights_only 'optimizer_states': "PT optim's state_dict"[] # if not weights_only 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index cc619077ee136..7768e112d27cc 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -76,7 +76,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): # the checkpoint saves "epoch + 1" early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] assert 4 == len(early_stop_callback.saved_states) - assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state + assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss') diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..aa9faa59a3188 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -68,11 +68,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): trainer.fit(model) ckpt = torch.load(str(tmpdir / "all_states.ckpt")) - state0 = ckpt["callbacks"][type(callback0)] - state1 = ckpt["callbacks"][type(callback1)] + state0 = ckpt["callbacks"]["StatefulCallback0"] + state1 = ckpt["callbacks"]["StatefulCallback1"] assert "content0" in state0 and state0["content0"] == 0 assert "content1" in state1 and state1["content1"] == 1 - assert type(checkpoint_callback) in ckpt["callbacks"] + assert "ModelCheckpoint" in ckpt["callbacks"] def test_attach_model_callbacks(): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 923821a5e50e4..310f43f94a177 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -268,7 +268,7 @@ def test_dataloader(self): def test_call_back_validator(tmpdir): - funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) + funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_') and 
callable(getattr(Callback, f))]) callbacks_func = [ 'on_after_backward', From 7dc218a7b491879b2d745be40ba6a1f18da7cd6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Apr 2021 03:51:04 +0200 Subject: [PATCH 03/45] add legacy state loading --- pytorch_lightning/callbacks/base.py | 6 +++++- pytorch_lightning/trainer/callback_hook.py | 6 +++++- tests/checkpointing/test_legacy_checkpoints.py | 1 + 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a214904b1c31e..bf82899373230 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from pytorch_lightning.core.lightning import LightningModule @@ -33,6 +33,10 @@ class Callback(abc.ABC): def state_identifier(self) -> str: return self.__class__.__name__ + @property + def _legacy_state_identifier(self) -> Type: + return type(self) + def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None: """Called before configure sharded model""" diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index dfc4922e14277..0f7d6a6fe5066 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -14,6 +14,7 @@ from abc import ABC from copy import deepcopy +from distutils.version import LooseVersion from inspect import signature from typing import Any, Callable, Dict, List, Optional, Type @@ -269,7 +270,10 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # https://github.com/pytorch/xla/issues/2773 if callback_states is not None: for callback in self.callbacks: - state = callback_states.get(callback.state_identifier) + state = ( + callback_states.get(callback.state_identifier) + or callback_states.get(callback._legacy_state_identifier) + ) if state: state = deepcopy(state) callback.on_load_checkpoint(state) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 7d1284ee0d329..4080eb1deb788 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -60,6 +60,7 @@ "1.2.5", "1.2.6", "1.2.7", + "1.2.8", ] ) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): From 04b588b7ebf11894ebefa2b7bd8cdfc9936b6464 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Apr 2021 21:03:11 +0200 Subject: [PATCH 04/45] update test --- tests/checkpointing/test_model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f58ff768759e8..5143e4e1eb33f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -840,7 +840,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_last = torch.load(path_last) assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step")) - ch_type = type(model_checkpoint) + ch_type = "ModelCheckpoint" assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type] # it is easier to load the model objects than to iterate over the raw dict of tensors @@ -1098,7 +1098,7 @@ def training_step(self, *args): trainer.fit(TestModel()) assert model_checkpoint.current_score == 
0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] - ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts] + ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] From bb11e2872356ce8bbcf1f333cdec59cbe554d501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 8 Apr 2021 21:07:56 +0200 Subject: [PATCH 05/45] update tests --- tests/checkpointing/test_model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 5143e4e1eb33f..d2a57a6c4d2b8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -147,7 +147,7 @@ def configure_optimizers(self): assert chk['epoch'] == epoch + 1 assert chk['global_step'] == limit_train_batches * (epoch + 1) - mc_specific_data = chk['callbacks'][type(checkpoint)] + mc_specific_data = chk['callbacks']["ModelCheckpoint"] assert mc_specific_data['dirpath'] == checkpoint.dirpath assert mc_specific_data['monitor'] == monitor assert mc_specific_data['current_score'] == score @@ -251,7 +251,7 @@ def configure_optimizers(self): assert chk['epoch'] == epoch + 1 assert chk['global_step'] == per_epoch_steps * (global_ix + 1) - mc_specific_data = chk['callbacks'][type(checkpoint)] + mc_specific_data = chk['callbacks']["ModelCheckpoint"] assert mc_specific_data['dirpath'] == checkpoint.dirpath assert mc_specific_data['monitor'] == monitor assert mc_specific_data['current_score'] == score From 0259ecbdf72fe3fdbecd1524bcae62e9e328b871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 21 Apr 2021 11:20:23 +0200 Subject: [PATCH 06/45] flake8 --- pytorch_lightning/trainer/callback_hook.py | 4 +--- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 0f7d6a6fe5066..fb5c82a2330db 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -14,9 +14,8 @@ from abc import ABC from copy import deepcopy -from distutils.version import LooseVersion from inspect import signature -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule @@ -264,7 +263,6 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint.""" callback_states = checkpoint.get('callbacks') - version = checkpoint.get('pytorch-lightning_version') # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 17c1f0ea589a5..13236d32ea522 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -264,10 +264,10 @@ def _make_assertions(epoch, ix, add=''): expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num) assert chk['global_step'] == expected_global_step - mc_specific_data = chk['callbacks']["ModelCheckpoint"] 
- assert mc_specific_data['dirpath'] == checkpoint.dirpath - assert mc_specific_data['monitor'] == monitor - assert mc_specific_data['current_score'] == score + mc_specific_data = chk['callbacks']["ModelCheckpoint"] + assert mc_specific_data['dirpath'] == checkpoint.dirpath + assert mc_specific_data['monitor'] == monitor + assert mc_specific_data['current_score'] == score if not reduce_lr_on_plateau: lr_scheduler_specific_data = chk['lr_schedulers'][0] From d56e5e47d7cf74199650eebe7d409acea7c29285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 21 Apr 2021 12:33:07 +0200 Subject: [PATCH 07/45] add test --- .../checkpointing/test_legacy_checkpoints.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 4080eb1deb788..d4d053f6136ac 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -14,11 +14,18 @@ import glob import os import sys +from copy import deepcopy +from pathlib import Path import pytest +import torch +from pytorch_lightning.callbacks.base import Callback + +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning import Trainer from tests import PATH_LEGACY +from tests.helpers import BoringModel LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints') CHECKPOINT_EXTENSION = ".ckpt" @@ -87,3 +94,34 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): # assert result sys.path = orig_sys_paths + + +class StatefulCallback(Callback): + + def on_save_checkpoint(self, *args): + return {"content": 123} + + +def test_callback_state_loading_by_type(tmpdir): + """ Test that legacy checkpoints that don't use a state identifier can still be loaded. 
""" + model = BoringModel() + callback = ModelCheckpoint(dirpath=tmpdir, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback], + ) + trainer.fit(model) + # simulate old format where type(callback) was the key + new_checkpoint = torch.load(Path(tmpdir, "last.ckpt")) + old_checkpiont = deepcopy(new_checkpoint) + old_checkpiont["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]} + torch.save(old_checkpiont, Path(tmpdir, "old.ckpt")) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + callbacks=[callback], + resume_from_checkpoint=Path(tmpdir, "old.ckpt"), + ) + trainer.fit(model) From 72ba44026c411476bae3544ebcd67d2961c6ffa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Apr 2021 21:39:11 +0200 Subject: [PATCH 08/45] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/callback_hook.py | 7 ++----- tests/checkpointing/test_legacy_checkpoints.py | 6 +++--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 839d78d276efe..4f35e30f7ee5d 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -293,16 +293,13 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint.""" - callback_states = checkpoint.get('callbacks') + callback_states: Dict[str, dict] = checkpoint.get('callbacks') # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 if callback_states is not None: for callback in self.callbacks: - state = ( - callback_states.get(callback.state_identifier) - or callback_states.get(callback._legacy_state_identifier) - ) + state = callback_states.get(callback.state_identifier, callback_states.get(callback._legacy_state_identifier)) if state: state = deepcopy(state) callback.on_load_checkpoint(state) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index d4d053f6136ac..c44dc526a1314 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -114,9 +114,9 @@ def test_callback_state_loading_by_type(tmpdir): trainer.fit(model) # simulate old format where type(callback) was the key new_checkpoint = torch.load(Path(tmpdir, "last.ckpt")) - old_checkpiont = deepcopy(new_checkpoint) - old_checkpiont["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]} - torch.save(old_checkpiont, Path(tmpdir, "old.ckpt")) + old_checkpoint = deepcopy(new_checkpoint) + old_checkpoint["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]} + torch.save(old_checkpoint, Path(tmpdir, "old.ckpt")) trainer = Trainer( default_root_dir=tmpdir, From 79d85684d80443d449e937aba433e5425a736020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Apr 2021 22:21:37 +0200 Subject: [PATCH 09/45] improve test --- .../checkpointing/test_legacy_checkpoints.py | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/checkpointing/test_legacy_checkpoints.py 
b/tests/checkpointing/test_legacy_checkpoints.py index c44dc526a1314..b3e639c0ee3d5 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -14,15 +14,11 @@ import glob import os import sys -from copy import deepcopy from pathlib import Path import pytest -import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint - from pytorch_lightning import Trainer from tests import PATH_LEGACY from tests.helpers import BoringModel @@ -96,32 +92,41 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): sys.path = orig_sys_paths -class StatefulCallback(Callback): +class OldStatefulCallback(Callback): + + def __init__(self, state): + self.state = state + + @property + def state_identifier(self): + return type(self) def on_save_checkpoint(self, *args): - return {"content": 123} + return {"state": self.state} + + def on_load_checkpoint(self, callback_state): + self.state = callback_state["state"] -def test_callback_state_loading_by_type(tmpdir): - """ Test that legacy checkpoints that don't use a state identifier can still be loaded. """ +def test_resume_callback_state_saved_by_type(tmpdir): + """ Test that a legacy checkpoint that didn't use a state identifier before can still be loaded. """ model = BoringModel() - callback = ModelCheckpoint(dirpath=tmpdir, save_last=True) + callback = OldStatefulCallback(state=111) trainer = Trainer( default_root_dir=tmpdir, max_steps=1, callbacks=[callback], ) trainer.fit(model) - # simulate old format where type(callback) was the key - new_checkpoint = torch.load(Path(tmpdir, "last.ckpt")) - old_checkpoint = deepcopy(new_checkpoint) - old_checkpoint["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]} - torch.save(old_checkpoint, Path(tmpdir, "old.ckpt")) + ckpt_path = Path(trainer.checkpoint_callback.best_model_path) + assert ckpt_path.exists() + callback = OldStatefulCallback(state=222) trainer = Trainer( default_root_dir=tmpdir, max_steps=2, callbacks=[callback], - resume_from_checkpoint=Path(tmpdir, "old.ckpt"), + resume_from_checkpoint=ckpt_path, ) trainer.fit(model) + assert callback.state == 111 From d9ea8c165fc7983b6611e5c769a3fdbe70d59575 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Apr 2021 22:29:34 +0200 Subject: [PATCH 10/45] flake --- pytorch_lightning/trainer/callback_hook.py | 4 +++- tests/checkpointing/test_legacy_checkpoints.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 4f35e30f7ee5d..82bfd644ca3bc 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -299,7 +299,9 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # https://github.com/pytorch/xla/issues/2773 if callback_states is not None: for callback in self.callbacks: - state = callback_states.get(callback.state_identifier, callback_states.get(callback._legacy_state_identifier)) + state = callback_states.get( + callback.state_identifier, callback_states.get(callback._legacy_state_identifier) + ) if state: state = deepcopy(state) callback.on_load_checkpoint(state) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index b3e639c0ee3d5..40d755d76f2c6 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ 
b/tests/checkpointing/test_legacy_checkpoints.py @@ -17,9 +17,9 @@ from pathlib import Path import pytest -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.base import Callback from tests import PATH_LEGACY from tests.helpers import BoringModel From 0851f0d439d30e1cda22310ba76eeb5c22d92e7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 11:16:11 +0200 Subject: [PATCH 11/45] fix merge --- pytorch_lightning/trainer/callback_hook.py | 17 +++++------------ tests/checkpointing/test_legacy_checkpoints.py | 5 +++-- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 2ecdb5159d146..343019171f406 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Type, Union import torch @@ -282,19 +282,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint.""" - + callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 - if callback_states is not None: - for callback in self.callbacks: - state = callback_states.get( - callback.state_identifier, callback_states.get(callback._legacy_state_identifier) - ) - if state: - state = deepcopy(state) - callback.on_load_checkpoint(state) - if callback_states is None: return @@ -309,7 +300,9 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ) for callback in self.callbacks: - state = callback_states.get(type(callback)) + state = callback_states.get( + callback.state_identifier, callback_states.get(callback._legacy_state_identifier) + ) if state: state = deepcopy(state) if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 8c5adf19357dc..4bbccb5c85006 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -18,8 +18,9 @@ import pytest -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, Callback from tests import _PATH_LEGACY +from tests.helpers import BoringModel LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, 'checkpoints') CHECKPOINT_EXTENSION = ".ckpt" @@ -110,7 +111,7 @@ def state_identifier(self): def on_save_checkpoint(self, *args): return {"state": self.state} - def on_load_checkpoint(self, callback_state): + def on_load_checkpoint(self, trainer, pl_module, callback_state): self.state = callback_state["state"] From 82d5658a80d63974d9d4fe5dc1227cd0d8051c4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jul 2021 09:17:52 +0000 Subject: [PATCH 12/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/governance.rst | 2 +- tests/checkpointing/test_legacy_checkpoints.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff 
--git a/docs/source/governance.rst b/docs/source/governance.rst index 5c29f7d0da544..4114ccdb8a818 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -39,7 +39,7 @@ Board Alumni ------ -- Jeff Yang (`ydcjeff `_) +- Jeff Yang (`ydcjeff `_) - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 4bbccb5c85006..2dc499c1882d6 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -18,7 +18,7 @@ import pytest -from pytorch_lightning import Trainer, Callback +from pytorch_lightning import Callback, Trainer from tests import _PATH_LEGACY from tests.helpers import BoringModel From 334fd4a9d191db99d4cb10048fc9be82c64f27cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 11:21:18 +0200 Subject: [PATCH 13/45] use qualname --- pytorch_lightning/callbacks/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 05685f7e9a688..a0af628c0aefc 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -35,7 +35,7 @@ class Callback(abc.ABC): @property def state_identifier(self) -> str: - return self.__class__.__name__ + return self.__class__.__qualname__ @property def _legacy_state_identifier(self) -> Type: From f144fd188af921b245eea9177c43d559aec9a818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 11:44:57 +0200 Subject: [PATCH 14/45] rename state_id --- pytorch_lightning/callbacks/base.py | 4 ++-- pytorch_lightning/trainer/callback_hook.py | 4 ++-- tests/checkpointing/test_legacy_checkpoints.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a0af628c0aefc..f05ddd9217844 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -34,11 +34,11 @@ class Callback(abc.ABC): """ @property - def state_identifier(self) -> str: + def state_id(self) -> str: return self.__class__.__qualname__ @property - def _legacy_state_identifier(self) -> Type: + def _legacy_state_id(self) -> Type: return type(self) def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 343019171f406..5083bcae51263 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -277,7 +277,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: else: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) if state: - callback_states[callback.state_identifier] = state + callback_states[callback.state_id] = state return callback_states def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: @@ -301,7 +301,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: for callback in self.callbacks: state = callback_states.get( - callback.state_identifier, callback_states.get(callback._legacy_state_identifier) + callback.state_id, callback_states.get(callback._legacy_state_id) ) if state: state = deepcopy(state) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 
2dc499c1882d6..cb43a65ac5610 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -105,7 +105,7 @@ def __init__(self, state): self.state = state @property - def state_identifier(self): + def state_id(self): return type(self) def on_save_checkpoint(self, *args): From 615498670d3f9ae886c5f617af39963b4fd22df6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 11:45:33 +0200 Subject: [PATCH 15/45] fix diff --- pytorch_lightning/trainer/callback_hook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5083bcae51263..3075a3cfa8c77 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -282,10 +282,11 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint.""" - callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 + callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") + if callback_states is None: return From 17459fe94f0a2454cab6ee44555d8314acaab0fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Apr 2021 15:00:00 +0200 Subject: [PATCH 16/45] suboptimal --- pytorch_lightning/core/saving.py | 6 +++ pytorch_lightning/utilities/argparse.py | 10 ++--- .../utilities/migration/__init__.py | 0 pytorch_lightning/utilities/migration/base.py | 29 ++++++++++++ .../utilities/migration/migrations.py | 45 +++++++++++++++++++ .../utilities/upgrade_checkpoint.py | 32 +++---------- tests/utilities/test_upgrade_checkpoint.py | 24 +++++----- 7 files changed, 102 insertions(+), 44 deletions(-) create mode 100644 pytorch_lightning/utilities/migration/__init__.py create mode 100644 pytorch_lightning/utilities/migration/base.py create mode 100644 pytorch_lightning/utilities/migration/migrations.py diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 94b7062be0756..7897b677566ae 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -29,6 +29,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint from pytorch_lightning.utilities.parsing import parse_class_init_keys log = logging.getLogger(__name__) @@ -133,6 +134,9 @@ def load_from_checkpoint( else: checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + # convert legacy checkpoints to the new format + checkpoint = migrate_checkpoint(checkpoint) + if hparams_file is not None: extension = hparams_file.split('.')[-1] if extension.lower() == 'csv': @@ -147,6 +151,7 @@ def load_from_checkpoint( # overwrite hparams by the given file checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + # TODO: make this a migration: # for past checkpoint need to add the new key if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} @@ -170,6 +175,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool 
= True, **cl if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys + # TODO: make this a migration: for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index aebbcb41ac34f..44b7fa7933e64 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -289,11 +289,11 @@ def _gpus_allowed_type(x) -> Union[int, str]: return int(x) -def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover - # unused, but here for backward compatibility with old checkpoints that need to be able to - # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 - # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 - pass +# def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover +# # unused, but here for backward compatibility with old checkpoints that need to be able to +# # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 +# # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 +# pass def _int_or_float_type(x) -> Union[int, float]: diff --git a/pytorch_lightning/utilities/migration/__init__.py b/pytorch_lightning/utilities/migration/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py new file mode 100644 index 0000000000000..fa3f6bb3aaf5a --- /dev/null +++ b/pytorch_lightning/utilities/migration/base.py @@ -0,0 +1,29 @@ +from distutils.version import LooseVersion + +import pytorch_lightning.utilities.argparse + + +def get_version(checkpoint: dict) -> str: + return checkpoint["pytorch-lightning_version"] + + +def set_version(checkpoint: dict, version: str): + checkpoint["pytorch-lightning_version"] = version + + +def should_upgrade(checkpoint: dict, target: str) -> bool: + return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + + +class pl_legacy_patch: + """ + Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be + included for unpickling old checkpoints. 
+ """ + + def __enter__(self): + setattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default", lambda x: x) + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py new file mode 100644 index 0000000000000..0c85f95e60a03 --- /dev/null +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -0,0 +1,45 @@ +import pytorch_lightning as pl +from pytorch_lightning.utilities.migration.base import set_version, should_upgrade + + +# v0.10.0 +def migrate_model_checkpoint_early_stopping(checkpoint: dict) -> dict: + from pytorch_lightning.callbacks.early_stopping import EarlyStopping + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + keys_mapping = { + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), + "early_stop_callback_wait": (EarlyStopping, "wait_count"), + "early_stop_callback_patience": (EarlyStopping, "patience"), + } + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} + + for key, new_path in keys_mapping.items(): + if key in checkpoint: + value = checkpoint[key] + callback_type, callback_key = new_path + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} + checkpoint["callbacks"][callback_type][callback_key] = value + del checkpoint[key] + return checkpoint + + +# v1.3.1 +def migrate_callback_state_identifiers(checkpoint): + if "callbacks" not in checkpoint: + return + callbacks = checkpoint["callbacks"] + checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items()) + return checkpoint + + +def migrate_checkpoint(checkpoint: dict): + """ Applies all the above migrations in order. 
""" + if should_upgrade(checkpoint, "0.10.0"): + migrate_model_checkpoint_early_stopping(checkpoint) + if should_upgrade(checkpoint, "1.3.0"): + migrate_callback_state_identifiers(checkpoint) + set_version(checkpoint, "1.3.0") + set_version(checkpoint, pl.__version__) + return checkpoint diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py index 4896845f10263..6fc482bff1f7c 100644 --- a/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,34 +17,11 @@ import torch -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint - -KEYS_MAPPING = { - "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), - "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), - "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), - "early_stop_callback_wait": (EarlyStopping, "wait_count"), - "early_stop_callback_patience": (EarlyStopping, "patience"), -} +from pytorch_lightning.utilities.migration.base import pl_legacy_patch +from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint log = logging.getLogger(__name__) - -def upgrade_checkpoint(filepath): - checkpoint = torch.load(filepath) - checkpoint["callbacks"] = checkpoint.get("callbacks") or {} - - for key, new_path in KEYS_MAPPING.items(): - if key in checkpoint: - value = checkpoint[key] - callback_type, callback_key = new_path - checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} - checkpoint["callbacks"][callback_type][callback_key] = value - del checkpoint[key] - - torch.save(checkpoint, filepath) - - if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -57,4 +34,7 @@ def upgrade_checkpoint(filepath): log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") copyfile(args.file, args.file + ".bak") - upgrade_checkpoint(args.file) + with pl_legacy_patch(): + checkpoint = torch.load(args.file) + migrate_checkpoint(checkpoint) + torch.save(checkpoint, args.file) diff --git a/tests/utilities/test_upgrade_checkpoint.py b/tests/utilities/test_upgrade_checkpoint.py index 82801cb27c407..8a6e7a46f727c 100644 --- a/tests/utilities/test_upgrade_checkpoint.py +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest -import torch +import pytorch_lightning as pl -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint +from pytorch_lightning.utilities.migration.base import set_version, get_version +from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint @pytest.mark.parametrize( @@ -33,7 +31,7 @@ "epoch": 1, "global_step": 23, "callbacks": { - ModelCheckpoint: { + "ModelCheckpoint": { "best_model_score": 0.34 } } @@ -49,7 +47,7 @@ "epoch": 1, "global_step": 23, "callbacks": { - ModelCheckpoint: { + "ModelCheckpoint": { "best_model_score": 0.99 } } @@ -65,7 +63,7 @@ "epoch": 1, "global_step": 23, "callbacks": { - ModelCheckpoint: { + "ModelCheckpoint": { "best_model_path": 'path' } } @@ -82,7 +80,7 @@ "epoch": 1, "global_step": 23, "callbacks": { - EarlyStopping: { + "EarlyStopping": { "wait_count": 2, "patience": 4 } @@ -92,8 +90,8 @@ ], ) def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - filepath = os.path.join(tmpdir, "model.ckpt") - torch.save(old_checkpoint, filepath) - upgrade_checkpoint(filepath) - updated_checkpoint = torch.load(filepath) + set_version(old_checkpoint, "0.9.0") + set_version(new_checkpoint, pl.__version__) + updated_checkpoint = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == new_checkpoint + assert get_version(updated_checkpoint) == pl.__version__ From f84d047104de71a8f0546dcbbdaff0094b0a45f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Apr 2021 15:43:38 +0200 Subject: [PATCH 17/45] clean up --- pytorch_lightning/utilities/argparse.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 44b7fa7933e64..a988eca63a7e3 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -289,13 +289,6 @@ def _gpus_allowed_type(x) -> Union[int, str]: return int(x) -# def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover -# # unused, but here for backward compatibility with old checkpoints that need to be able to -# # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 -# # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 -# pass - - def _int_or_float_type(x) -> Union[int, float]: if '.' 
in str(x): return float(x) From d4dc3b3de46d4b9ec0c8e560b7211aa866c55a08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Apr 2021 23:31:54 +0200 Subject: [PATCH 18/45] migrate checkpoint on load wrap --- .../trainer/connectors/checkpoint_connector.py | 7 +++++++ pytorch_lightning/utilities/migration/base.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index e2968a1cf6e29..4bca4def1f61c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -25,6 +25,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS +from pytorch_lightning.utilities.migration.base import pl_legacy_patch +from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -104,6 +106,11 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None: # restore module states self.restore_datamodule() self.restore_model() + #with pl_legacy_patch(): + # checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path( + # checkpoint_path, map_location=lambda storage, loc: storage + # ) + #migrate_checkpoint(checkpoint) # restore callback states self.restore_callbacks() diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index fa3f6bb3aaf5a..46a5b5223efe2 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -26,4 +26,5 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, exc_traceback): - delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") + if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): + delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") From 26200690719d40d9c8005445f07c1f56a7c2d9b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 12:00:52 +0200 Subject: [PATCH 19/45] update legacy patching --- .../trainer/connectors/checkpoint_connector.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4bca4def1f61c..3ad1d15e011e5 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -71,7 +71,9 @@ def resume_start(self) -> None: raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.") rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}") - self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path) + with pl_legacy_patch(): + self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path) + migrate_checkpoint(self._loaded_checkpoint) def resume_end(self) -> None: """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. 
""" @@ -106,11 +108,6 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None: # restore module states self.restore_datamodule() self.restore_model() - #with pl_legacy_patch(): - # checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path( - # checkpoint_path, map_location=lambda storage, loc: storage - # ) - #migrate_checkpoint(checkpoint) # restore callback states self.restore_callbacks() @@ -153,7 +150,9 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> """ Restore only the model weights. """ checkpoint = self._loaded_checkpoint if checkpoint_path is not None: - checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path) + with pl_legacy_patch(): + checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path) + migrate_checkpoint(checkpoint) self.trainer.lightning_module.on_load_checkpoint(checkpoint) self.trainer.training_type_plugin.load_model_state_dict(checkpoint) From 97cba2948d06bfaae501a082f151eaae55e010f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jul 2021 10:02:26 +0000 Subject: [PATCH 20/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/callback_hook.py | 6 ++---- .../trainer/connectors/checkpoint_connector.py | 2 +- tests/utilities/test_upgrade_checkpoint.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 3075a3cfa8c77..4fae7edc2aa97 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -286,7 +286,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") - + if callback_states is None: return @@ -301,9 +301,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ) for callback in self.callbacks: - state = callback_states.get( - callback.state_id, callback_states.get(callback._legacy_state_id) - ) + state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id)) if state: state = deepcopy(state) if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3ad1d15e011e5..b3f4617680414 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -24,9 +24,9 @@ from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_enabled -from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS from pytorch_lightning.utilities.migration.base import pl_legacy_patch from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint +from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container diff --git a/tests/utilities/test_upgrade_checkpoint.py 
b/tests/utilities/test_upgrade_checkpoint.py index 8a6e7a46f727c..70ed8059d9361 100644 --- a/tests/utilities/test_upgrade_checkpoint.py +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -import pytorch_lightning as pl -from pytorch_lightning.utilities.migration.base import set_version, get_version +import pytorch_lightning as pl +from pytorch_lightning.utilities.migration.base import get_version, set_version from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint From d1365a1e87474c8dc8d623723d2d321cfb93c27d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 14:26:03 +0200 Subject: [PATCH 21/45] reset --- pytorch_lightning/trainer/callback_hook.py | 24 ++++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 4fae7edc2aa97..ffcac8f9073f6 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type import torch @@ -34,19 +34,19 @@ class TrainerCallbackHookMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] - lightning_module: 'pl.LightningModule' + lightning_module: "pl.LightningModule" - def on_before_accelerator_backend_setup(self, model: 'pl.LightningModule') -> None: + def on_before_accelerator_backend_setup(self, model: "pl.LightningModule") -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def configure_sharded_model(self, model: 'pl.LightningModule') -> None: + def configure_sharded_model(self, model: "pl.LightningModule") -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_configure_sharded_model(self, model) - def setup(self, model: 'pl.LightningModule', stage: Optional[str]) -> None: + def setup(self, model: "pl.LightningModule", stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage=stage) @@ -263,7 +263,7 @@ def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) return len(parameters) == 1 and parameters[0] != "args" - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: """Called when saving a model checkpoint.""" callback_states = {} for callback in self.callbacks: @@ -277,15 +277,16 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: else: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) if state: - callback_states[callback.state_id] = state + callback_states[type(callback)] = state return callback_states - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, 
checkpoint): """Called when loading a model checkpoint.""" + # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 - callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") + callback_states = checkpoint.get("callbacks") if callback_states is None: return @@ -297,11 +298,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: rank_zero_warn( "Be aware that when using ``resume_from_checkpoint``, " "callbacks used to create the checkpoint need to be provided. " - f"Please, add the following callbacks: {list(difference)}. ", UserWarning + f"Please, add the following callbacks: {list(difference)}. ", + UserWarning, ) for callback in self.callbacks: - state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id)) + state = callback_states.get(type(callback)) if state: state = deepcopy(state) if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): From 4b2f70c31720184ccb82111ac0db5f407a1a4de2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 14:27:29 +0200 Subject: [PATCH 22/45] apply black --- pytorch_lightning/callbacks/base.py | 2 +- .../functional/precision_recall_curve.py | 3 +- .../utilities/migration/migrations.py | 1 + .../checkpointing/test_legacy_checkpoints.py | 1 - tests/checkpointing/test_model_checkpoint.py | 16 ++--- tests/utilities/test_upgrade_checkpoint.py | 66 +++---------------- 6 files changed, 20 insertions(+), 69 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index de568a8c60522..851d2b953c8c6 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -41,7 +41,7 @@ def state_id(self) -> str: def _legacy_state_id(self) -> Type: return type(self) - def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: + def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called before configure sharded model""" def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index 93914c146e82f..93b203fae129b 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -27,7 +27,8 @@ def precision_recall_curve( pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, ) -> Union[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], ]: """ .. 
deprecated:: diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index 0c85f95e60a03..3f4fe974ed05e 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -6,6 +6,7 @@ def migrate_model_checkpoint_early_stopping(checkpoint: dict) -> dict: from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + keys_mapping = { "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index fa9397671320f..34df42a28e81f 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -100,7 +100,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): class OldStatefulCallback(Callback): - def __init__(self, state): self.state = state diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a84ca6bb1c852..270ecdbf51b9b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -148,10 +148,10 @@ def on_validation_epoch_end(self): assert chk["epoch"] == epoch + 1 assert chk["global_step"] == limit_train_batches * (epoch + 1) - mc_specific_data = chk['callbacks']["ModelCheckpoint"] - assert mc_specific_data['dirpath'] == checkpoint.dirpath - assert mc_specific_data['monitor'] == monitor - assert mc_specific_data['current_score'] == score + mc_specific_data = chk["callbacks"]["ModelCheckpoint"] + assert mc_specific_data["dirpath"] == checkpoint.dirpath + assert mc_specific_data["monitor"] == monitor + assert mc_specific_data["current_score"] == score if not reduce_lr_on_plateau: actual_step_count = chk["lr_schedulers"][0]["_step_count"] @@ -259,10 +259,10 @@ def _make_assertions(epoch, ix, version=""): expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num) assert chk["global_step"] == expected_global_step - mc_specific_data = chk['callbacks']["ModelCheckpoint"] - assert mc_specific_data['dirpath'] == checkpoint.dirpath - assert mc_specific_data['monitor'] == monitor - assert mc_specific_data['current_score'] == score + mc_specific_data = chk["callbacks"]["ModelCheckpoint"] + assert mc_specific_data["dirpath"] == checkpoint.dirpath + assert mc_specific_data["monitor"] == monitor + assert mc_specific_data["current_score"] == score if not reduce_lr_on_plateau: actual_step_count = chk["lr_schedulers"][0]["_step_count"] diff --git a/tests/utilities/test_upgrade_checkpoint.py b/tests/utilities/test_upgrade_checkpoint.py index 70ed8059d9361..163b6835b51b4 100644 --- a/tests/utilities/test_upgrade_checkpoint.py +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -22,70 +22,20 @@ "old_checkpoint, new_checkpoint", [ ( - { - "epoch": 1, - "global_step": 23, - "checkpoint_callback_best": 0.34 - }, - { - "epoch": 1, - "global_step": 23, - "callbacks": { - "ModelCheckpoint": { - "best_model_score": 0.34 - } - } - }, + {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, + {"epoch": 1, "global_step": 23, "callbacks": {"ModelCheckpoint": {"best_model_score": 0.34}}}, ), ( - { - "epoch": 1, - "global_step": 23, - "checkpoint_callback_best_model_score": 0.99 - }, - { - "epoch": 1, 
- "global_step": 23, - "callbacks": { - "ModelCheckpoint": { - "best_model_score": 0.99 - } - } - }, + {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, + {"epoch": 1, "global_step": 23, "callbacks": {"ModelCheckpoint": {"best_model_score": 0.99}}}, ), ( - { - "epoch": 1, - "global_step": 23, - "checkpoint_callback_best_model_path": 'path' - }, - { - "epoch": 1, - "global_step": 23, - "callbacks": { - "ModelCheckpoint": { - "best_model_path": 'path' - } - } - }, + {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, + {"epoch": 1, "global_step": 23, "callbacks": {"ModelCheckpoint": {"best_model_path": "path"}}}, ), ( - { - "epoch": 1, - "global_step": 23, - "early_stop_callback_wait": 2, - "early_stop_callback_patience": 4 - }, - { - "epoch": 1, - "global_step": 23, - "callbacks": { - "EarlyStopping": { - "wait_count": 2, - "patience": 4 - } - } - }, + {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, + {"epoch": 1, "global_step": 23, "callbacks": {"EarlyStopping": {"wait_count": 2, "patience": 4}}}, ), ], ) From f23b57b105e1b14cd7e3a348fb7875564a7be17a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jul 2021 12:30:28 +0000 Subject: [PATCH 23/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../metrics/functional/precision_recall_curve.py | 3 +-- .../utilities/migration/migrations.py | 2 +- tests/checkpointing/test_legacy_checkpoints.py | 15 +++------------ 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index 93b203fae129b..93914c146e82f 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -27,8 +27,7 @@ def precision_recall_curve( pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, ) -> Union[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], ]: """ .. deprecated:: diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index 3f4fe974ed05e..32a22a55a5773 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -36,7 +36,7 @@ def migrate_callback_state_identifiers(checkpoint): def migrate_checkpoint(checkpoint: dict): - """ Applies all the above migrations in order. """ + """Applies all the above migrations in order.""" if should_upgrade(checkpoint, "0.10.0"): migrate_model_checkpoint_early_stopping(checkpoint) if should_upgrade(checkpoint, "1.3.0"): diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 34df42a28e81f..242dba8830475 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -115,24 +115,15 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state): def test_resume_callback_state_saved_by_type(tmpdir): - """ Test that a legacy checkpoint that didn't use a state identifier before can still be loaded. 
""" + """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded.""" model = BoringModel() callback = OldStatefulCallback(state=111) - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=1, - callbacks=[callback], - ) + trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) trainer.fit(model) ckpt_path = Path(trainer.checkpoint_callback.best_model_path) assert ckpt_path.exists() callback = OldStatefulCallback(state=222) - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=2, - callbacks=[callback], - resume_from_checkpoint=ckpt_path, - ) + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) trainer.fit(model) assert callback.state == 111 From a1bed7e109900789a561fbb6db9c29bafcac1eb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 14:31:59 +0200 Subject: [PATCH 24/45] reset branch --- pytorch_lightning/callbacks/base.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 851d2b953c8c6..a492be314df26 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional import torch from torch.optim import Optimizer @@ -33,14 +33,6 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ - @property - def state_id(self) -> str: - return self.__class__.__qualname__ - - @property - def _legacy_state_id(self) -> Type: - return type(self) - def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called before configure sharded model""" From 60a9ca6af4ca2b86cd597ddfb1d982b1033e98d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 14:37:59 +0200 Subject: [PATCH 25/45] reset branch --- .../utilities/migration/migrations.py | 12 ------- tests/callbacks/test_early_stopping.py | 3 +- .../checkpointing/test_legacy_checkpoints.py | 34 +------------------ tests/checkpointing/test_model_checkpoint.py | 8 ++--- .../connectors/test_callback_connector.py | 6 ++-- tests/utilities/test_upgrade_checkpoint.py | 9 ++--- 6 files changed, 14 insertions(+), 58 deletions(-) diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index 32a22a55a5773..87e6974b0af2e 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -26,21 +26,9 @@ def migrate_model_checkpoint_early_stopping(checkpoint: dict) -> dict: return checkpoint -# v1.3.1 -def migrate_callback_state_identifiers(checkpoint): - if "callbacks" not in checkpoint: - return - callbacks = checkpoint["callbacks"] - checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items()) - return checkpoint - - def migrate_checkpoint(checkpoint: dict): """Applies all the above migrations in order.""" if should_upgrade(checkpoint, "0.10.0"): migrate_model_checkpoint_early_stopping(checkpoint) - if should_upgrade(checkpoint, "1.3.0"): - migrate_callback_state_identifiers(checkpoint) - set_version(checkpoint, "1.3.0") set_version(checkpoint, pl.__version__) return checkpoint diff --git a/tests/callbacks/test_early_stopping.py 
b/tests/callbacks/test_early_stopping.py index e5b28b9a29262..460d5d0731601 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -76,8 +76,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): checkpoint = torch.load(checkpoint_filepath) # the checkpoint saves "epoch + 1" early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] - assert 4 == len(early_stop_callback.saved_states) - assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state + assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss") diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 242dba8830475..8693965a52abc 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -14,13 +14,11 @@ import glob import os import sys -from pathlib import Path import pytest -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Trainer from tests import _PATH_LEGACY -from tests.helpers import BoringModel LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints") CHECKPOINT_EXTENSION = ".ckpt" @@ -97,33 +95,3 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): # trainer.fit(model) sys.path = orig_sys_paths - - -class OldStatefulCallback(Callback): - def __init__(self, state): - self.state = state - - @property - def state_id(self): - return type(self) - - def on_save_checkpoint(self, *args): - return {"state": self.state} - - def on_load_checkpoint(self, trainer, pl_module, callback_state): - self.state = callback_state["state"] - - -def test_resume_callback_state_saved_by_type(tmpdir): - """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded.""" - model = BoringModel() - callback = OldStatefulCallback(state=111) - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) - trainer.fit(model) - ckpt_path = Path(trainer.checkpoint_callback.best_model_path) - assert ckpt_path.exists() - - callback = OldStatefulCallback(state=222) - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) - trainer.fit(model) - assert callback.state == 111 diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 270ecdbf51b9b..2b0b636d4cfad 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -148,7 +148,7 @@ def on_validation_epoch_end(self): assert chk["epoch"] == epoch + 1 assert chk["global_step"] == limit_train_batches * (epoch + 1) - mc_specific_data = chk["callbacks"]["ModelCheckpoint"] + mc_specific_data = chk["callbacks"][type(checkpoint)] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -259,7 +259,7 @@ def _make_assertions(epoch, ix, version=""): expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num) assert chk["global_step"] == expected_global_step - mc_specific_data = chk["callbacks"]["ModelCheckpoint"] + mc_specific_data = chk["callbacks"][type(checkpoint)] assert mc_specific_data["dirpath"] == checkpoint.dirpath 
assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -858,7 +858,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - ch_type = "ModelCheckpoint" + ch_type = type(model_checkpoint) assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type] # it is easier to load the model objects than to iterate over the raw dict of tensors @@ -1097,7 +1097,7 @@ def training_step(self, *args): trainer.fit(TestModel()) assert model_checkpoint.current_score == 0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] - ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts] + ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 338de72a31fed..bdc19ee15aaad 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -76,11 +76,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): trainer.fit(model) ckpt = torch.load(str(tmpdir / "all_states.ckpt")) - state0 = ckpt["callbacks"]["StatefulCallback0"] - state1 = ckpt["callbacks"]["StatefulCallback1"] + state0 = ckpt["callbacks"][type(callback0)] + state1 = ckpt["callbacks"][type(callback1)] assert "content0" in state0 and state0["content0"] == 0 assert "content1" in state1 and state1["content1"] == 1 - assert "ModelCheckpoint" in ckpt["callbacks"] + assert type(checkpoint_callback) in ckpt["callbacks"] def test_attach_model_callbacks(): diff --git a/tests/utilities/test_upgrade_checkpoint.py b/tests/utilities/test_upgrade_checkpoint.py index 163b6835b51b4..255bc062be8b9 100644 --- a/tests/utilities/test_upgrade_checkpoint.py +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -14,6 +14,7 @@ import pytest import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration.base import get_version, set_version from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint @@ -23,19 +24,19 @@ [ ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, - {"epoch": 1, "global_step": 23, "callbacks": {"ModelCheckpoint": {"best_model_score": 0.34}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, - {"epoch": 1, "global_step": 23, "callbacks": {"ModelCheckpoint": {"best_model_score": 0.99}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {"ModelCheckpoint": {"best_model_path": "path"}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {"EarlyStopping": {"wait_count": 2, "patience": 4}}}, + {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, ), ], ) From 9c011a971ec5474283b4dcf5909e71440dbf0f00 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 14:45:31 +0200 Subject: [PATCH 26/45] add docs --- pytorch_lightning/utilities/migration/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index 46a5b5223efe2..a15dd2c4e7a1a 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -4,21 +4,32 @@ def get_version(checkpoint: dict) -> str: + """ Get the version of a Lightning checkpoint. """ return checkpoint["pytorch-lightning_version"] def set_version(checkpoint: dict, version: str): + """ Set the version of a Lightning checkpoint. """ checkpoint["pytorch-lightning_version"] = version def should_upgrade(checkpoint: dict, target: str) -> bool: + """ Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target. """ return LooseVersion(get_version(checkpoint)) < LooseVersion(target) class pl_legacy_patch: """ Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be - included for unpickling old checkpoints. + included for unpickling old checkpoints. The following patches apply. + + 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to + version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 + + Example: + + with pl_legacy_patch(): + torch.load("path/to/legacy/checkpoint.ckpt") """ def __enter__(self): From 314cc336938f57f6a83c62cc08a5b0ad2e639b17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 14:49:15 +0200 Subject: [PATCH 27/45] add docs and licence --- .../utilities/migration/__init__.py | 13 +++++++++++++ pytorch_lightning/utilities/migration/base.py | 13 +++++++++++++ .../utilities/migration/migrations.py | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/pytorch_lightning/utilities/migration/__init__.py b/pytorch_lightning/utilities/migration/__init__.py index e69de29bb2d1d..d7aa17d7f8468 100644 --- a/pytorch_lightning/utilities/migration/__init__.py +++ b/pytorch_lightning/utilities/migration/__init__.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index a15dd2c4e7a1a..a6eb1db7143c0 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from distutils.version import LooseVersion import pytorch_lightning.utilities.argparse diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index 87e6974b0af2e..1feaee8d380b3 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -1,3 +1,22 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. + +When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary +sequentially, see :func:`migrate_checkpoint`. +""" import pytorch_lightning as pl from pytorch_lightning.utilities.migration.base import set_version, should_upgrade From ade1bbd10b0e73bedea1006bb18db7ded351c472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 26 Jul 2021 15:07:16 +0200 Subject: [PATCH 28/45] remove obsolete warning --- .../trainer/connectors/checkpoint_connector.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index a3d76facdaebd..0dfb2c84e251b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -26,7 +26,6 @@ from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.migration.base import pl_legacy_patch from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint -from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -176,13 +175,6 @@ def restore_callbacks(self) -> None: if not self._loaded_checkpoint: return - if any(key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): - raise ValueError( - "The checkpoint you're attempting to load follows an" - " outdated schema. You can upgrade to the current schema by running" - " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" - " where `model.ckpt` is your checkpoint file." 
- ) self.trainer.on_load_checkpoint(self._loaded_checkpoint) def restore_loops(self) -> None: From fcdce4fbf50531fcb790175c5b41b2cdac2bcb00 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jul 2021 13:08:13 +0000 Subject: [PATCH 29/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/migration/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index a6eb1db7143c0..2c0c3c50ccb58 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -17,17 +17,17 @@ def get_version(checkpoint: dict) -> str: - """ Get the version of a Lightning checkpoint. """ + """Get the version of a Lightning checkpoint.""" return checkpoint["pytorch-lightning_version"] def set_version(checkpoint: dict, version: str): - """ Set the version of a Lightning checkpoint. """ + """Set the version of a Lightning checkpoint.""" checkpoint["pytorch-lightning_version"] = version def should_upgrade(checkpoint: dict, target: str) -> bool: - """ Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target. """ + """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" return LooseVersion(get_version(checkpoint)) < LooseVersion(target) From 900e41541280b084bc2fc1a8b85fdfe55ca28b89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 11:47:31 +0200 Subject: [PATCH 30/45] update legacy load --- .../trainer/connectors/checkpoint_connector.py | 1 - pytorch_lightning/utilities/argparse_utils.py | 7 ------- pytorch_lightning/utilities/migration/base.py | 13 ++++++++++++- 3 files changed, 12 insertions(+), 9 deletions(-) delete mode 100644 pytorch_lightning/utilities/argparse_utils.py diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5acb4f15b991d..fac5967ce2a5f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -24,7 +24,6 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.migration.base import pl_legacy_patch from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py deleted file mode 100644 index e797719eb54a9..0000000000000 --- a/pytorch_lightning/utilities/argparse_utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from pytorch_lightning.utilities import rank_zero_deprecation - -rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4") - -# for backward compatibility with old checkpoints (versions < 1.2.0) -# that need to be able to unpickle the function from the checkpoint -from pytorch_lightning.utilities.argparse import _gpus_arg_default # noqa: E402, F401 # isort: skip diff --git 
a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index 2c0c3c50ccb58..3c03503495c2c 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys from distutils.version import LooseVersion +from types import ModuleType import pytorch_lightning.utilities.argparse @@ -38,6 +40,8 @@ class pl_legacy_patch: 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 + 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, + but still needs to be available for import for legacy checkpoints. Example: @@ -46,9 +50,16 @@ class pl_legacy_patch: """ def __enter__(self): - setattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default", lambda x: x) + # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` + legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") + sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module + + # `_gpus_arg_default` used to be imported from these locations + legacy_argparse_module._gpus_arg_default = lambda x: x + pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x return self def __exit__(self, exc_type, exc_value, exc_traceback): if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") + del sys.modules["pytorch_lightning.utilities.argparse_utils"] From e177bf1d789128c687a41994cd25edb28b44736c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:01:05 +0200 Subject: [PATCH 31/45] reset branch --- pytorch_lightning/core/saving.py | 6 --- .../connectors/checkpoint_connector.py | 14 +++-- .../{migration/base.py => migration.py} | 16 ------ .../utilities/migration/__init__.py | 13 ----- .../utilities/migration/migrations.py | 53 ------------------- .../utilities/upgrade_checkpoint.py | 32 +++++++++-- tests/utilities/test_upgrade_checkpoint.py | 15 +++--- 7 files changed, 45 insertions(+), 104 deletions(-) rename pytorch_lightning/utilities/{migration/base.py => migration.py} (78%) delete mode 100644 pytorch_lightning/utilities/migration/__init__.py delete mode 100644 pytorch_lightning/utilities/migration/migrations.py diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index a979a3d218080..79608bfc1c5c1 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -29,7 +29,6 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint from pytorch_lightning.utilities.parsing import parse_class_init_keys log = logging.getLogger(__name__) @@ -131,9 +130,6 @@ def load_from_checkpoint( else: checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - # convert legacy checkpoints to the new format - checkpoint = migrate_checkpoint(checkpoint) - if hparams_file is not None: extension = 
hparams_file.split(".")[-1] if extension.lower() == "csv": @@ -148,7 +144,6 @@ def load_from_checkpoint( # overwrite hparams by the given file checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams - # TODO: make this a migration: # for past checkpoint need to add the new key if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} @@ -172,7 +167,6 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys - # TODO: make this a migration: for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fac5967ce2a5f..5a33585157290 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -24,8 +24,9 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.migration.base import pl_legacy_patch -from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint +from pytorch_lightning.utilities.imports import _fault_tolerant_training +from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -63,7 +64,6 @@ def resume_start(self) -> None: rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") with pl_legacy_patch(): self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path) - migrate_checkpoint(self._loaded_checkpoint) def resume_end(self) -> None: """Signal the connector that all states have resumed and memory for the checkpoint object can be released.""" @@ -148,7 +148,6 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> if checkpoint_path is not None: with pl_legacy_patch(): checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path) - migrate_checkpoint(checkpoint) self.trainer.lightning_module.on_load_checkpoint(checkpoint) self.trainer.training_type_plugin.load_model_state_dict(checkpoint) @@ -173,6 +172,13 @@ def restore_callbacks(self) -> None: if not self._loaded_checkpoint: return + if any(key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): + raise ValueError( + "The checkpoint you're attempting to load follows an" + " outdated schema. You can upgrade to the current schema by running" + " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" + " where `model.ckpt` is your checkpoint file." 
+ ) self.trainer.on_load_checkpoint(self._loaded_checkpoint) def restore_loops(self) -> None: diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration.py similarity index 78% rename from pytorch_lightning/utilities/migration/base.py rename to pytorch_lightning/utilities/migration.py index 3c03503495c2c..2eba0d2ea490a 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration.py @@ -12,27 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from distutils.version import LooseVersion from types import ModuleType import pytorch_lightning.utilities.argparse -def get_version(checkpoint: dict) -> str: - """Get the version of a Lightning checkpoint.""" - return checkpoint["pytorch-lightning_version"] - - -def set_version(checkpoint: dict, version: str): - """Set the version of a Lightning checkpoint.""" - checkpoint["pytorch-lightning_version"] = version - - -def should_upgrade(checkpoint: dict, target: str) -> bool: - """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - return LooseVersion(get_version(checkpoint)) < LooseVersion(target) - - class pl_legacy_patch: """ Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be diff --git a/pytorch_lightning/utilities/migration/__init__.py b/pytorch_lightning/utilities/migration/__init__.py deleted file mode 100644 index d7aa17d7f8468..0000000000000 --- a/pytorch_lightning/utilities/migration/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py deleted file mode 100644 index 1feaee8d380b3..0000000000000 --- a/pytorch_lightning/utilities/migration/migrations.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. - -When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary -sequentially, see :func:`migrate_checkpoint`. 
-""" -import pytorch_lightning as pl -from pytorch_lightning.utilities.migration.base import set_version, should_upgrade - - -# v0.10.0 -def migrate_model_checkpoint_early_stopping(checkpoint: dict) -> dict: - from pytorch_lightning.callbacks.early_stopping import EarlyStopping - from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint - - keys_mapping = { - "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), - "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), - "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), - "early_stop_callback_wait": (EarlyStopping, "wait_count"), - "early_stop_callback_patience": (EarlyStopping, "patience"), - } - checkpoint["callbacks"] = checkpoint.get("callbacks") or {} - - for key, new_path in keys_mapping.items(): - if key in checkpoint: - value = checkpoint[key] - callback_type, callback_key = new_path - checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} - checkpoint["callbacks"][callback_type][callback_key] = value - del checkpoint[key] - return checkpoint - - -def migrate_checkpoint(checkpoint: dict): - """Applies all the above migrations in order.""" - if should_upgrade(checkpoint, "0.10.0"): - migrate_model_checkpoint_early_stopping(checkpoint) - set_version(checkpoint, pl.__version__) - return checkpoint diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py index 46038701b1e52..34483ce39b925 100644 --- a/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,11 +17,35 @@ import torch -from pytorch_lightning.utilities.migration.base import pl_legacy_patch -from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.migration import pl_legacy_patch + +KEYS_MAPPING = { + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), + "early_stop_callback_wait": (EarlyStopping, "wait_count"), + "early_stop_callback_patience": (EarlyStopping, "patience"), +} log = logging.getLogger(__name__) + +def upgrade_checkpoint(filepath): + checkpoint = torch.load(filepath) + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} + + for key, new_path in KEYS_MAPPING.items(): + if key in checkpoint: + value = checkpoint[key] + callback_type, callback_key = new_path + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} + checkpoint["callbacks"][callback_type][callback_key] = value + del checkpoint[key] + + torch.save(checkpoint, filepath) + + if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -36,6 +60,4 @@ log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") copyfile(args.file, args.file + ".bak") with pl_legacy_patch(): - checkpoint = torch.load(args.file) - migrate_checkpoint(checkpoint) - torch.save(checkpoint, args.file) + upgrade_checkpoint(args.file) diff --git a/tests/utilities/test_upgrade_checkpoint.py b/tests/utilities/test_upgrade_checkpoint.py index 255bc062be8b9..a58bdb5721bc7 100644 --- a/tests/utilities/test_upgrade_checkpoint.py +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest +import torch -import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.migration.base import get_version, set_version -from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint +from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint @pytest.mark.parametrize( @@ -41,8 +42,8 @@ ], ) def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - set_version(old_checkpoint, "0.9.0") - set_version(new_checkpoint, pl.__version__) - updated_checkpoint = migrate_checkpoint(old_checkpoint) + filepath = os.path.join(tmpdir, "model.ckpt") + torch.save(old_checkpoint, filepath) + upgrade_checkpoint(filepath) + updated_checkpoint = torch.load(filepath) assert updated_checkpoint == new_checkpoint - assert get_version(updated_checkpoint) == pl.__version__ From d8d96c085ef1bca80b7fe0347eab3b2cc4e614f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:06:07 +0200 Subject: [PATCH 32/45] add tests --- tests/utilities/test_migration.py | 35 +++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/utilities/test_migration.py diff --git a/tests/utilities/test_migration.py b/tests/utilities/test_migration.py new file mode 100644 index 0000000000000..8ed6bac2611ae --- /dev/null +++ b/tests/utilities/test_migration.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys + +import pytorch_lightning +from pytorch_lightning.utilities.migration import pl_legacy_patch + + +def test_patch_legacy_argparse_utils(): + with pl_legacy_patch(): + from pytorch_lightning.utilities import argparse_utils + + assert "pytorch_lightning.utilities.argparse_utils" in sys.modules + + assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules + + +def test_patch_legacy_gpus_arg_default(): + with pl_legacy_patch(): + from pytorch_lightning.utilities.argparse import _gpus_arg_default + + assert callable(_gpus_arg_default) + assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") + assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") From 30203ba6bbe4a74c24aac89063691322533e4615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:09:33 +0200 Subject: [PATCH 33/45] update load from checkpoint call --- pytorch_lightning/core/saving.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 79608bfc1c5c1..525f7cdf9ad98 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -29,6 +29,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.parsing import parse_class_init_keys log = logging.getLogger(__name__) @@ -125,10 +126,11 @@ def load_from_checkpoint( pretrained_model.freeze() y_hat = pretrained_model(x) """ - if map_location is not None: - checkpoint = pl_load(checkpoint_path, map_location=map_location) - else: - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) if hparams_file is not None: extension = hparams_file.split(".")[-1] From 0babb5c87d374f453a38ce391dd40f1535aaa11b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:16:43 +0200 Subject: [PATCH 34/45] rm notebooks --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 29aea106edefc..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 29aea106edefc9d1904c0c17223a8ac2b15c48e7 From 2db27195f78c82ee644e7da0fbe746d2a6fd7376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:17:21 +0200 Subject: [PATCH 35/45] reset notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..6100885854c80 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 6100885854c803458886c731cddd6bd67498c0a1 From 3d29350e0118633fb0ccf7e667979e3d09a8103d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:18:35 +0200 Subject: [PATCH 36/45] Update pytorch_lightning/utilities/argparse.py --- pytorch_lightning/utilities/argparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 
ea8a6bdb01c68..10f8a15ed153f 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -295,7 +295,7 @@ def _gpus_allowed_type(x: str) -> Union[int, str]: return int(x) -def _int_or_float_type(x) -> Union[int, float]: +def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]: if "." in str(x): return float(x) return int(x) From 9cc44fe5c96e605e017b641af59eb5301426ead7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:24:46 +0200 Subject: [PATCH 37/45] update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bba7ed346980c..c7bfda97a6531 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -98,6 +98,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) +- Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) @@ -220,6 +223,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943)) +- Removed deprecated `pytorch_lighting.utilities.argparse_utils` module ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166)) + + ### Fixed - Fixed save/load/resume from checkpoint for DeepSpeed Plugin ( From 5e48576c02b95d5d3cab93c4df470f8927b5b202 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 12:50:30 +0200 Subject: [PATCH 38/45] remove test for 1.4 deprecation --- tests/deprecated_api/test_remove_1-4.py | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 tests/deprecated_api/test_remove_1-4.py diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py deleted file mode 100644 index a3a4a0b1b9180..0000000000000 --- a/tests/deprecated_api/test_remove_1-4.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Test deprecated functionality which will be removed in v1.4.0""" - -import pytest - -from tests.deprecated_api import _soft_unimport_module - - -def test_v1_4_0_deprecated_imports(): - _soft_unimport_module("pytorch_lightning.utilities.argparse_utils") - with pytest.deprecated_call(match="will be removed in v1.4"): - from pytorch_lightning.utilities.argparse_utils import _gpus_arg_default # noqa: F401 From a7cea9e95af834b50750ee0f230f3f48234fdcb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 14:26:08 +0200 Subject: [PATCH 39/45] update changelog --- CHANGELOG.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f5f2eeb27dc9..c7bfda97a6531 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -161,10 +161,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) -- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) - - - ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) From 493d018d99c0c3d0271564a102ebb75ea68023ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Aug 2021 14:55:19 +0200 Subject: [PATCH 40/45] add assert to prevent "unused import" complaint --- tests/utilities/test_migration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_migration.py b/tests/utilities/test_migration.py index 8ed6bac2611ae..ee94ee690e798 100644 --- a/tests/utilities/test_migration.py +++ b/tests/utilities/test_migration.py @@ -21,6 +21,7 @@ def test_patch_legacy_argparse_utils(): with pl_legacy_patch(): from pytorch_lightning.utilities import argparse_utils + assert callable(argparse_utils._gpus_arg_default) assert "pytorch_lightning.utilities.argparse_utils" in sys.modules assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules From bed88086972102109d60524d3e486d00fa0d6f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Sep 2021 10:41:35 +0200 Subject: [PATCH 41/45] update changelog --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2baacb93d827f..6516ab9c5b0fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -245,9 +245,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) -- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) - - - Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260)) From a4a8eb44845dd913b5b47b2027a8d79931ff557d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Sep 2021 10:44:34 +0200 Subject: [PATCH 42/45] rm notebook --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 6100885854c80..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6100885854c803458886c731cddd6bd67498c0a1 From 20ab34a6a1e0bfab854512bf4a38c61e429a2715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Sep 2021 10:44:47 +0200 Subject: [PATCH 43/45] update notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..a2fb6468112b7 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From dd4532a85d06806f26b3bd4e08035eb275c86012 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Sep 2021 08:46:13 +0000 Subject: [PATCH 44/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/utilities/migration.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 71eebcb803ea5..605a9849cf2bc 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -26,8 +26,8 @@ from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.types import _PATH from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.types import _PATH from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index 2eba0d2ea490a..68f20403bb0c7 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -18,9 +18,8 @@ class pl_legacy_patch: - """ - Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be - included for unpickling old checkpoints. The following patches apply. + """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for + unpickling old checkpoints. The following patches apply. 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to version 1.2.8. 
See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 From fae11ff506fb73bc2a290f7df3ceb76017e174d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Sep 2021 10:46:55 +0200 Subject: [PATCH 45/45] move context manager to _load_and_validate method --- .../trainer/connectors/checkpoint_connector.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 71eebcb803ea5..b8f7bfbe62f9a 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -63,11 +63,11 @@ def resume_start(self) -> None: return rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") - with pl_legacy_patch(): - self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path) + self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path) def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: - loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path) + with pl_legacy_patch(): + loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path) if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): raise ValueError( "The checkpoint you're attempting to load follows an" @@ -158,8 +158,7 @@ def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None: """Restore only the model weights.""" checkpoint = self._loaded_checkpoint if checkpoint_path is not None: - with pl_legacy_patch(): - checkpoint = self._load_and_validate_checkpoint(checkpoint_path) + checkpoint = self._load_and_validate_checkpoint(checkpoint_path) self.trainer.lightning_module.on_load_checkpoint(checkpoint) self.trainer.training_type_plugin.load_model_state_dict(checkpoint)