From 89131c2ac262a4ed21be35b8654e53a205f3ef68 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 03:06:51 +0200
Subject: [PATCH 01/51] class name as key

---
 pytorch_lightning/trainer/callback_hook.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 606f6b2e4b52b..62bfa00ddcf1a 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -243,7 +243,7 @@ def __is_old_signature(fn: Callable) -> bool:
             return True
         return False

-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
@@ -257,10 +257,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[type(callback)] = state
+                callback_states[callback.__class__.__name__] = state
         return callback_states

-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(type(callback))
+                state = callback_states.get(callback.__class__.__name__)
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

From 63fb9830fbc7b4f81807035c03d1d8343df07bfd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 03:37:22 +0200
Subject: [PATCH 02/51] string state identifier

---
 pytorch_lightning/callbacks/base.py                 | 4 ++++
 pytorch_lightning/trainer/callback_hook.py          | 5 +++--
 .../trainer/connectors/checkpoint_connector.py      | 2 +-
 tests/callbacks/test_early_stopping.py              | 2 +-
 tests/trainer/connectors/test_callback_connector.py | 6 +++---
 tests/trainer/logging_/test_logger_connector.py     | 2 +-
 6 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 768e4ebca30ee..a214904b1c31e 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -29,6 +29,10 @@ class Callback(abc.ABC):
     Subclass this class and override any of the relevant hooks
     """

+    @property
+    def state_identifier(self) -> str:
+        return self.__class__.__name__
+
     def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
         """Called before configure sharded model"""

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 62bfa00ddcf1a..dfc4922e14277 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -257,18 +257,19 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback.__class__.__name__] = state
+                callback_states[callback.state_identifier] = state
         return callback_states

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
+        version = checkpoint.get('pytorch-lightning_version')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(callback.__class__.__name__)
+                state = callback_states.get(callback.state_identifier)
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 4ae42e4bad6ac..887bd4064dc6c 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -244,7 +244,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         structured dictionary: {
             'epoch':                     training epoch
             'global_step':               training global step
-            'pytorch-lightning_version': PyTorch Lightning's version
+            'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
             'callbacks':                 "callback specific state"[]  # if not weights_only
             'optimizer_states':          "PT optim's state_dict"[]    # if not weights_only
             'lr_schedulers':             "PT sched's state_dict"[]    # if not weights_only

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index cc619077ee136..7768e112d27cc 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -76,7 +76,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
     assert 4 == len(early_stop_callback.saved_states)
-    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss')

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 34149e2231bf5..aa9faa59a3188 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -68,11 +68,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
     trainer.fit(model)

     ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
-    state0 = ckpt["callbacks"][type(callback0)]
-    state1 = ckpt["callbacks"][type(callback1)]
+    state0 = ckpt["callbacks"]["StatefulCallback0"]
+    state1 = ckpt["callbacks"]["StatefulCallback1"]
     assert "content0" in state0 and state0["content0"] == 0
     assert "content1" in state1 and state1["content1"] == 1
-    assert type(checkpoint_callback) in ckpt["callbacks"]
+    assert "ModelCheckpoint" in ckpt["callbacks"]


 def test_attach_model_callbacks():

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 923821a5e50e4..310f43f94a177 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -268,7 +268,7 @@ def test_dataloader(self):

 def test_call_back_validator(tmpdir):

-    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
+    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_') and callable(getattr(Callback, f))])

     callbacks_func = [
         'on_after_backward',

From 7dc218a7b491879b2d745be40ba6a1f18da7cd6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 03:51:04 +0200
Subject: [PATCH 03/51] add legacy state loading

---
 pytorch_lightning/callbacks/base.py            | 6 +++++-
 pytorch_lightning/trainer/callback_hook.py     | 6 +++++-
 tests/checkpointing/test_legacy_checkpoints.py | 1 +
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index a214904b1c31e..bf82899373230 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
 """

 import abc
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type

 from pytorch_lightning.core.lightning import LightningModule

@@ -33,6 +33,10 @@ class Callback(abc.ABC):
     def state_identifier(self) -> str:
         return self.__class__.__name__

+    @property
+    def _legacy_state_identifier(self) -> Type:
+        return type(self)
+
     def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
         """Called before configure sharded model"""

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index dfc4922e14277..0f7d6a6fe5066 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,6 +14,7 @@

 from abc import ABC
 from copy import deepcopy
+from distutils.version import LooseVersion
 from inspect import signature
 from typing import Any, Callable, Dict, List, Optional, Type

@@ -269,7 +270,10 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(callback.state_identifier)
+                state = (
+                    callback_states.get(callback.state_identifier)
+                    or callback_states.get(callback._legacy_state_identifier)
+                )
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 7d1284ee0d329..4080eb1deb788 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -60,6 +60,7 @@
         "1.2.5",
         "1.2.6",
         "1.2.7",
+        "1.2.8",
     ]
 )
 def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

From 04b588b7ebf11894ebefa2b7bd8cdfc9936b6464 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 21:03:11 +0200
Subject: [PATCH 04/51] update test

---
 tests/checkpointing/test_model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f58ff768759e8..5143e4e1eb33f 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -840,7 +840,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     ckpt_last = torch.load(path_last)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

-    ch_type = type(model_checkpoint)
+    ch_type = "ModelCheckpoint"
     assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
@@ -1098,7 +1098,7 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]

From bb11e2872356ce8bbcf1f333cdec59cbe554d501 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 21:07:56 +0200
Subject: [PATCH 05/51] update tests

---
 tests/checkpointing/test_model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 5143e4e1eb33f..d2a57a6c4d2b8 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -147,7 +147,7 @@ def configure_optimizers(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
@@ -251,7 +251,7 @@ def configure_optimizers(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == per_epoch_steps * (global_ix + 1)

-        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score

From 0259ecbdf72fe3fdbecd1524bcae62e9e328b871 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 21 Apr 2021 11:20:23 +0200
Subject: [PATCH 06/51] flake8

---
 pytorch_lightning/trainer/callback_hook.py   | 4 +---
 tests/checkpointing/test_model_checkpoint.py | 8 ++++----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 0f7d6a6fe5066..fb5c82a2330db 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,9 +14,8 @@

 from abc import ABC
 from copy import deepcopy
-from distutils.version import LooseVersion
 from inspect import signature
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional

 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
@@ -264,7 +263,6 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
-        version = checkpoint.get('pytorch-lightning_version')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 17c1f0ea589a5..13236d32ea522 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -264,10 +264,10 @@ def _make_assertions(epoch, ix, add=''):
         expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
         assert chk['global_step'] == expected_global_step

-            mc_specific_data = chk['callbacks']["ModelCheckpoint"]
-            assert mc_specific_data['dirpath'] == checkpoint.dirpath
-            assert mc_specific_data['monitor'] == monitor
-            assert mc_specific_data['current_score'] == score
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
+        assert mc_specific_data['dirpath'] == checkpoint.dirpath
+        assert mc_specific_data['monitor'] == monitor
+        assert mc_specific_data['current_score'] == score

         if not reduce_lr_on_plateau:
             lr_scheduler_specific_data = chk['lr_schedulers'][0]

From d56e5e47d7cf74199650eebe7d409acea7c29285 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 21 Apr 2021 12:33:07 +0200
Subject: [PATCH 07/51] add test

---
 .../checkpointing/test_legacy_checkpoints.py | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 4080eb1deb788..d4d053f6136ac 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,11 +14,18 @@
 import glob
 import os
 import sys
+from copy import deepcopy
+from pathlib import Path

 import pytest
+import torch
+from pytorch_lightning.callbacks.base import Callback

+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning import Trainer
 from tests import PATH_LEGACY
+from tests.helpers import BoringModel

 LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints')
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -87,3 +94,34 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

     # assert result
     sys.path = orig_sys_paths
+
+
+class StatefulCallback(Callback):
+
+    def on_save_checkpoint(self, *args):
+        return {"content": 123}
+
+
+def test_callback_state_loading_by_type(tmpdir):
+    """ Test that legacy checkpoints that don't use a state identifier can still be loaded. """
+    model = BoringModel()
+    callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback],
+    )
+    trainer.fit(model)
+    # simulate old format where type(callback) was the key
+    new_checkpoint = torch.load(Path(tmpdir, "last.ckpt"))
+    old_checkpiont = deepcopy(new_checkpoint)
+    old_checkpiont["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
+    torch.save(old_checkpiont, Path(tmpdir, "old.ckpt"))
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=2,
+        callbacks=[callback],
+        resume_from_checkpoint=Path(tmpdir, "old.ckpt"),
+    )
+    trainer.fit(model)

From 72ba44026c411476bae3544ebcd67d2961c6ffa3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Apr 2021 21:39:11 +0200
Subject: [PATCH 08/51] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/trainer/callback_hook.py     | 7 ++-----
 tests/checkpointing/test_legacy_checkpoints.py | 6 +++---
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 839d78d276efe..4f35e30f7ee5d 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -293,16 +293,13 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-        callback_states = checkpoint.get('callbacks')
+        callback_states: Dict[str, dict] = checkpoint.get('callbacks')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = (
-                    callback_states.get(callback.state_identifier)
-                    or callback_states.get(callback._legacy_state_identifier)
-                )
+                state = callback_states.get(callback.state_identifier, callback_states.get(callback._legacy_state_identifier))
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index d4d053f6136ac..c44dc526a1314 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -114,9 +114,9 @@ def test_callback_state_loading_by_type(tmpdir):
     trainer.fit(model)
     # simulate old format where type(callback) was the key
     new_checkpoint = torch.load(Path(tmpdir, "last.ckpt"))
-    old_checkpiont = deepcopy(new_checkpoint)
-    old_checkpiont["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
-    torch.save(old_checkpiont, Path(tmpdir, "old.ckpt"))
+    old_checkpoint = deepcopy(new_checkpoint)
+    old_checkpoint["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
+    torch.save(old_checkpoint, Path(tmpdir, "old.ckpt"))

From 79d85684d80443d449e937aba433e5425a736020 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Apr 2021 22:21:37 +0200
Subject: [PATCH 09/51] improve test

---
 .../checkpointing/test_legacy_checkpoints.py | 35 +++++++++++--------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index c44dc526a1314..b3e639c0ee3d5 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,15 +14,11 @@
 import glob
 import os
 import sys
-from copy import deepcopy
 from pathlib import Path

 import pytest
-import torch
 from pytorch_lightning.callbacks.base import Callback

-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-
 from pytorch_lightning import Trainer
 from tests import PATH_LEGACY
 from tests.helpers import BoringModel
@@ -96,32 +92,41 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
     sys.path = orig_sys_paths


-class StatefulCallback(Callback):
+class OldStatefulCallback(Callback):
+
+    def __init__(self, state):
+        self.state = state
+
+    @property
+    def state_identifier(self):
+        return type(self)

     def on_save_checkpoint(self, *args):
-        return {"content": 123}
+        return {"state": self.state}
+
+    def on_load_checkpoint(self, callback_state):
+        self.state = callback_state["state"]


-def test_callback_state_loading_by_type(tmpdir):
-    """ Test that legacy checkpoints that don't use a state identifier can still be loaded. """
+def test_resume_callback_state_saved_by_type(tmpdir):
+    """ Test that a legacy checkpoint that didn't use a state identifier before can still be loaded. """
     model = BoringModel()
-    callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    callback = OldStatefulCallback(state=111)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=1,
         callbacks=[callback],
     )
     trainer.fit(model)
-    # simulate old format where type(callback) was the key
-    new_checkpoint = torch.load(Path(tmpdir, "last.ckpt"))
-    old_checkpoint = deepcopy(new_checkpoint)
-    old_checkpoint["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
-    torch.save(old_checkpoint, Path(tmpdir, "old.ckpt"))
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()

+    callback = OldStatefulCallback(state=222)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=2,
         callbacks=[callback],
-        resume_from_checkpoint=Path(tmpdir, "old.ckpt"),
+        resume_from_checkpoint=ckpt_path,
     )
     trainer.fit(model)
+    assert callback.state == 111

From d9ea8c165fc7983b6611e5c769a3fdbe70d59575 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Apr 2021 22:29:34 +0200
Subject: [PATCH 10/51] flake

---
 pytorch_lightning/trainer/callback_hook.py     | 4 +++-
 tests/checkpointing/test_legacy_checkpoints.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 4f35e30f7ee5d..82bfd644ca3bc 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -299,7 +299,9 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(callback.state_identifier, callback_states.get(callback._legacy_state_identifier))
+                state = callback_states.get(
+                    callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
+                )
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index b3e639c0ee3d5..40d755d76f2c6 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -17,9 +17,9 @@
 from pathlib import Path

 import pytest
-from pytorch_lightning.callbacks.base import Callback

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks.base import Callback
 from tests import PATH_LEGACY
 from tests.helpers import BoringModel

From 0851f0d439d30e1cda22310ba76eeb5c22d92e7f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 11:16:11 +0200
Subject: [PATCH 11/51] fix merge

---
 pytorch_lightning/trainer/callback_hook.py     | 17 +++++------------
 tests/checkpointing/test_legacy_checkpoints.py |  5 +++--
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 2ecdb5159d146..343019171f406 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type, Union

 import torch

@@ -282,19 +282,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
-        if callback_states is not None:
-            for callback in self.callbacks:
-                state = callback_states.get(
-                    callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
-                )
-                if state:
-                    state = deepcopy(state)
-                    callback.on_load_checkpoint(state)
-
         if callback_states is None:
             return

@@ -309,7 +300,9 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

         for callback in self.callbacks:
-            state = callback_states.get(type(callback))
+            state = callback_states.get(
+                callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
+            )
             if state:
                 state = deepcopy(state)
                 if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 8c5adf19357dc..4bbccb5c85006 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -18,8 +18,9 @@

 import pytest

-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from tests import _PATH_LEGACY
+from tests.helpers import BoringModel

 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, 'checkpoints')
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -110,7 +111,7 @@ def state_identifier(self):
     def on_save_checkpoint(self, *args):
         return {"state": self.state}

-    def on_load_checkpoint(self, callback_state):
+    def on_load_checkpoint(self, trainer, pl_module, callback_state):
         self.state = callback_state["state"]

From 82d5658a80d63974d9d4fe5dc1227cd0d8051c4c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jul 2021 09:17:52 +0000
Subject: [PATCH 12/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/source/governance.rst                     | 2 +-
 tests/checkpointing/test_legacy_checkpoints.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/governance.rst b/docs/source/governance.rst
index 5c29f7d0da544..4114ccdb8a818 100644
--- a/docs/source/governance.rst
+++ b/docs/source/governance.rst
@@ -39,7 +39,7 @@ Board

 Alumni
 ------
-- Jeff Yang (`ydcjeff <https://github.com/ydcjeff>`_) 
+- Jeff Yang (`ydcjeff <https://github.com/ydcjeff>`_)
 - Jeff Ling (`jeffling <https://github.com/jeffling>`_)
 - Teddy Koker (`teddykoker <https://github.com/teddykoker>`_)
 - Nate Raw (`nateraw <https://github.com/nateraw>`_)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 4bbccb5c85006..2dc499c1882d6 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -18,7 +18,7 @@

 import pytest

-from pytorch_lightning import Trainer, Callback
+from pytorch_lightning import Callback, Trainer
 from tests import _PATH_LEGACY
 from tests.helpers import BoringModel

From 334fd4a9d191db99d4cb10048fc9be82c64f27cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 11:21:18 +0200
Subject: [PATCH 13/51] use qualname

---
 pytorch_lightning/callbacks/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 05685f7e9a688..a0af628c0aefc 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -35,7 +35,7 @@ class Callback(abc.ABC):

     @property
     def state_identifier(self) -> str:
-        return self.__class__.__name__
+        return self.__class__.__qualname__

From f144fd188af921b245eea9177c43d559aec9a818 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Mon, 26 Jul 2021 11:44:57 +0200
Subject: [PATCH 14/51] rename state_id

---
 pytorch_lightning/callbacks/base.py            | 4 ++--
 pytorch_lightning/trainer/callback_hook.py     | 4 ++--
 tests/checkpointing/test_legacy_checkpoints.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index a0af628c0aefc..f05ddd9217844 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -34,11 +34,11 @@ class Callback(abc.ABC):
     """

     @property
-    def state_identifier(self) -> str:
+    def state_id(self) -> str:
         return self.__class__.__qualname__

     @property
-    def _legacy_state_identifier(self) -> Type:
+    def _legacy_state_id(self) -> Type:
         return type(self)

     def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 343019171f406..5083bcae51263 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -277,7 +277,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback.state_identifier] = state
+                callback_states[callback.state_id] = state
         return callback_states

@@ -301,7 +301,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         for callback in self.callbacks:
             state = callback_states.get(
-                callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
+                callback.state_id, callback_states.get(callback._legacy_state_id)
             )
             if state:
                 state = deepcopy(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 2dc499c1882d6..cb43a65ac5610 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -105,7 +105,7 @@ def __init__(self, state):
         self.state = state

     @property
-    def state_identifier(self):
+    def state_id(self):
         return type(self)

From 615498670d3f9ae886c5f617af39963b4fd22df6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Mon, 26 Jul 2021 11:45:33 +0200
Subject: [PATCH 15/51] fix diff

---
 pytorch_lightning/trainer/callback_hook.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 5083bcae51263..3075a3cfa8c77 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -282,10 +282,11 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+
         if callback_states is None:
             return

From 2c0c707e297a4671febfae849459bb8ead0c4cf8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Fri, 23 Apr 2021 11:12:34 +0200
Subject: [PATCH 16/51] unique identifiers

---
 pytorch_lightning/callbacks/base.py           |  5 +++
 pytorch_lightning/callbacks/early_stopping.py |  4 ++
 .../callbacks/model_checkpoint.py             |  4 ++
 tests/callbacks/test_callbacks.py             |  2 +-
 tests/callbacks/test_early_stopping.py        |  5 +++
 tests/checkpointing/test_model_checkpoint.py  |  5 +++
 .../connectors/test_callback_connector.py     | 37 +++++++++++++++----
 7 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index f05ddd9217844..9c0542fe2a584 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -41,6 +41,11 @@ def state_id(self) -> str:
     def _legacy_state_id(self) -> Type:
         return type(self)

+    def _generate_state_identifier(self, **kwargs: Any) -> str:
+        attrs = ", ".join(f"{k}={v}" for k, v in kwargs.items())
+        identifier = f"{self.__class__.__name__}[{attrs}]"
+        return identifier
+
     def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         """Called before configure sharded model"""

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 0015ac47f0d41..6f0972a20f308 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -126,6 +126,10 @@ def __init__(
         )
         self.monitor = monitor or "early_stop_on"

+    @property
+    def state_id(self) -> str:
+        return self._generate_state_identifier(monitor=self.monitor)
+
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e0153c05732fa..a87c80c6b1e4b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -249,6 +249,10 @@ def __init__(
         self.__validate_init_configuration()
         self._save_function = None

+    @property
+    def state_id(self) -> str:
+        return self._generate_state_identifier(monitor=self.monitor)
+
     def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         """
         When pretrain routine starts we build the ckpt dir on the fly

diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 57fdd1bf66322..c38479b505482 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from unittest.mock import call, Mock

-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from tests.helpers import BoringModel

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 861a2a1b0ae1b..4d8420710dc27 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -33,6 +33,11 @@
 _logger = logging.getLogger(__name__)


+def test_early_stopping_state_identifier():
+    early_stopping = EarlyStopping(monitor="val_loss")
+    assert early_stopping.state_identifier == "EarlyStopping[monitor=val_loss]"
+
+
 class EarlyStoppingTestRestore(EarlyStopping):
     # this class has to be defined outside the test function, otherwise we get pickle error
     def __init__(self, expected_state, *args, **kwargs):

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index dd075a8a2afa9..f1e386f82e695 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -43,6 +43,11 @@
 from tests.helpers.runif import RunIf


+def test_model_checkpoint_state_identifier():
+    early_stopping = ModelCheckpoint(monitor="val_loss")
+    assert early_stopping.state_identifier == "ModelCheckpoint[monitor=val_loss]"
+
+
 class LogInTwoMethods(BoringModel):

     def training_step(self, batch, batch_idx):

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 6bfe9954d8104..9e0f165c48d3a 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -61,31 +61,52 @@ def on_save_checkpoint(self, *args):

 class StatefulCallback1(Callback):

+    def __init__(self, unique=None, other=None):
+        self._unique = unique
+        self._other = other
+
+    @property
+    def state_identifier(self):
+        return self._generate_state_identifier(unique=self._unique)
+
     def on_save_checkpoint(self, *args):
-        return {"content1": 1}
+        return {"content1": self._unique}


-def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
-    """ Test that all callback states get saved even if the ModelCheckpoint is not given as last. """
+def test_all_callback_states_saved(tmpdir):
+    """
+    Test that all callback states get saved even if the ModelCheckpoint is not given as last
+    and when there are multiple callbacks of the same type.
+    """

     callback0 = StatefulCallback0()
-    callback1 = StatefulCallback1()
+    callback1 = StatefulCallback1(unique="one")
+    callback2 = StatefulCallback1(unique="two", other=2)
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=1,
         limit_val_batches=1,
-        callbacks=[callback0, checkpoint_callback, callback1]
+        callbacks=[
+            callback0,
+            # checkpoint callback does not have to be at the end
+            checkpoint_callback,
+            # callback1 and callback2 have the same type
+            callback1,
+            callback2,
+        ]
     )
     trainer.fit(model)

     ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
     state0 = ckpt["callbacks"]["StatefulCallback0"]
-    state1 = ckpt["callbacks"]["StatefulCallback1"]
+    state1 = ckpt["callbacks"]["StatefulCallback1[unique=one]"]
+    state2 = ckpt["callbacks"]["StatefulCallback1[unique=two]"]
     assert "content0" in state0 and state0["content0"] == 0
-    assert "content1" in state1 and state1["content1"] == 1
-    assert "ModelCheckpoint" in ckpt["callbacks"]
+    assert "content1" in state1 and state1["content1"] == "one"
+    assert "content1" in state2 and state2["content1"] == "two"
+    assert "ModelCheckpoint[monitor=None]" in ckpt["callbacks"]


 def test_attach_model_callbacks():

From 9f9a76dbad3d9b8cb88a797082e780e41fade53c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Fri, 23 Apr 2021 11:28:38 +0200
Subject: [PATCH 17/51] update tests

---
 tests/callbacks/test_early_stopping.py       |  3 ++-
 tests/callbacks/test_lambda_function.py      |  2 +-
 tests/checkpointing/test_model_checkpoint.py | 10 +++++-----
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 4d8420710dc27..489cec50f268f 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -82,7 +82,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
     assert 4 == len(early_stop_callback.saved_states)
-    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
+    print(checkpoint["callbacks"])
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss]"] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss')

diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
index 845846dfd1cfc..7daba6dfad97c 100644
--- a/tests/callbacks/test_lambda_function.py
+++ b/tests/callbacks/test_lambda_function.py
@@ -33,7 +33,7 @@ def on_train_epoch_start(self):
     def call(hook, *_, **__):
         checker.add(hook)

-    hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
+    hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction) if not m.startswith("_")]

     hooks_args = {h: partial(call, h) for h in hooks}
     hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')]

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f1e386f82e695..7b29bb07d7b35 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -159,7 +159,7 @@ def on_validation_epoch_end(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
+        mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
@@ -275,7 +275,7 @@ def _make_assertions(epoch, ix, version=''):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk['global_step'] == expected_global_step

-        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
+        mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
@@ -914,8 +914,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]

-    ch_type = "ModelCheckpoint"
-    assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
+    ckpt_id = "ModelCheckpoint[monitor=early_stop_on]"
+    assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1167,7 +1167,7 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint[monitor=foo]"] for ckpt in ckpts]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]

From 31a7737f81b3e7c5267f8659c7c159734fd5b57c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Fri, 23 Apr 2021 11:36:39 +0200
Subject: [PATCH 18/51] unused import

---
 tests/callbacks/test_callbacks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index c38479b505482..57fdd1bf66322 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from unittest.mock import call, Mock

-from pytorch_lightning import Trainer, Callback
+from pytorch_lightning import Trainer
 from tests.helpers import BoringModel

From 8eec798a20d3ec18fe82fd45d3b9d3f47b7b8665 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Mon, 26 Jul 2021 12:18:47 +0200
Subject: [PATCH 19/51] rename state_id

---
 pytorch_lightning/callbacks/base.py                 | 4 ++--
 pytorch_lightning/callbacks/early_stopping.py       | 2 +-
 pytorch_lightning/callbacks/model_checkpoint.py     | 2 +-
 tests/trainer/connectors/test_callback_connector.py | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 9c0542fe2a584..6051a81970653 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -41,9 +41,9 @@ def state_id(self) -> str:
     def _legacy_state_id(self) -> Type:
         return type(self)

-    def _generate_state_identifier(self, **kwargs: Any) -> str:
+    def _generate_state_id(self, **kwargs: Any) -> str:
         attrs = ", ".join(f"{k}={v}" for k, v in kwargs.items())
-        identifier = f"{self.__class__.__name__}[{attrs}]"
+        identifier = f"{self.__class__.__qualname__}[{attrs}]"
         return identifier

     def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 6f0972a20f308..15cdf2eba8da2 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -128,7 +128,7 @@ def __init__(

     @property
     def state_id(self) -> str:
-        return self._generate_state_identifier(monitor=self.monitor)
+        return self._generate_state_id(monitor=self.monitor)

     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a87c80c6b1e4b..82594bb4408ba 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -251,7 +251,7 @@ def __init__(

     @property
     def state_id(self) -> str:
-        return self._generate_state_identifier(monitor=self.monitor)
+        return self._generate_state_id(monitor=self.monitor)

     def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         """

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 9e0f165c48d3a..7cd35919f2ad1 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -67,7 +67,7 @@ def __init__(self, unique=None, other=None):

     @property
     def state_identifier(self):
-        return self._generate_state_identifier(unique=self._unique)
+        return self._generate_state_id(unique=self._unique)

     def on_save_checkpoint(self, *args):
         return {"content1": self._unique}

From 291c9feda1e34579441724c0536dfb76ec2d154c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Mon, 26 Jul 2021 12:20:20 +0200
Subject: [PATCH 20/51] rename state_id

---
 tests/checkpointing/test_model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 7b29bb07d7b35..3e947fb6326a8 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -45,7 +45,7 @@

 def test_model_checkpoint_state_identifier():
     early_stopping = ModelCheckpoint(monitor="val_loss")
-    assert early_stopping.state_identifier == "ModelCheckpoint[monitor=val_loss]"
+    assert early_stopping.state_id == "ModelCheckpoint[monitor=val_loss]"

From 2308c5d60cfc2d41278627b0726c21cce96c275e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Mon, 26 Jul 2021 12:20:59 +0200
Subject: [PATCH 21/51] rename state_id

---
 tests/callbacks/test_early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 489cec50f268f..d2263fcb35a14 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -35,7 +35,7 @@

 def test_early_stopping_state_identifier():
     early_stopping = EarlyStopping(monitor="val_loss")
-    assert early_stopping.state_identifier == "EarlyStopping[monitor=val_loss]"
+    assert early_stopping.state_id == "EarlyStopping[monitor=val_loss]"

From 4cda7233784a17601be520d41f9169cc6d6672d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Mon, 26 Jul 2021 12:22:47 +0200
Subject: [PATCH 22/51] fix merge error

---
 tests/callbacks/test_lambda_function.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
index 7daba6dfad97c..62eb887a73238 100644
--- a/tests/callbacks/test_lambda_function.py
+++ b/tests/callbacks/test_lambda_function.py
@@ -33,7 +33,7 @@ def on_train_epoch_start(self):
     def call(hook, *_, **__):
         checker.add(hook)

-    hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction) if not m.startswith("_")]
+    hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction) if not m.startswith("_")}

     hooks_args = {h: partial(call, h) for h in hooks}
     hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')]

From a92110e9142135b4c06d4df1e4e7072f6628d1ac Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jul 2021 10:22:46 +0000
Subject: [PATCH 23/51] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/trainer/callback_hook.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 3075a3cfa8c77..4fae7edc2aa97 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -286,7 +286,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
-
+
         if callback_states is None:
             return

@@ -301,9 +301,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

         for callback in self.callbacks:
-            state = callback_states.get(
-                callback.state_id, callback_states.get(callback._legacy_state_id)
-            )
+            state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
             if state:
                 state = deepcopy(state)
                 if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):

From 40e8827d1abccb8406df2ff6e47f5910f215c00a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Mon, 26 Jul 2021 12:25:49 +0200
Subject: [PATCH 24/51] remove print statements

---
 tests/callbacks/test_early_stopping.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index d2263fcb35a14..d16036bcab04a 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -82,7 +82,6 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
     assert 4 == len(early_stop_callback.saved_states)
-    print(checkpoint["callbacks"])
     assert checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss]"] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)

From 9472d1fe2a27aac5315d82c0e7e788cf30de9bb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Mon, 26 Jul 2021 12:48:03 +0200
Subject: [PATCH 25/51] update fx validator

---
 .../trainer/connectors/logger_connector/fx_validator.py | 1 +
 tests/trainer/logging_/test_logger_connector.py         | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 3604574fd1e81..7ad74001ea686 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -65,6 +65,7 @@ class FxValidator:
         on_save_checkpoint=None,
         on_load_checkpoint=None,
         setup=None,
+        state_id=None,
         teardown=None,
         configure_sharded_model=None,
         training_step=dict(on_step=(False, True), on_epoch=(False, True)),

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 27598b40fbd31..64543c4357de0 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -78,6 +78,7 @@ def test_fx_validator(tmpdir):
         "on_predict_epoch_start",
         "on_predict_start",
         'setup',
+        "state_id",
         'teardown',
     ]

@@ -105,6 +106,7 @@
         "on_train_end",
         "on_validation_end",
         "setup",
+        "state_id",
         "teardown",
     ]

From 7579e92d9772236f6d020aebbcf538d7ce318a1a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Thu, 29 Jul 2021 00:24:53 +0200
Subject: [PATCH 26/51] formatting, updates from master

---
 pytorch_lightning/callbacks/base.py            |  2 +-
 .../callbacks/model_checkpoint.py              |  2 +-
 .../checkpointing/test_legacy_checkpoints.py   | 27 -------------------
 tests/checkpointing/test_model_checkpoint.py   | 16 +++++------
 .../connectors/test_callback_connector.py      |  3 +--
 5 files changed, 11 insertions(+), 39 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 53361f92872de..aaa6612626223 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -53,7 +53,7 @@ def _generate_state_id(self, **kwargs: Any) -> str:
         identifier = f"{self.__class__.__qualname__}[{attrs}]"
         return identifier

-    def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+    def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called before configure sharded model"""

     def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 2c35773c9661d..1d7e0c4f7451c 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -253,7 +253,7 @@ def state_id(self) -> str:
         return self._generate_state_id(monitor=self.monitor)

-    def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """
         When pretrain routine starts we build the ckpt dir on the fly

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index fa9397671320f..e1af46e61cfec 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,13 +14,11 @@
 import glob
 import os
 import sys
-from pathlib import Path

 import pytest

 from pytorch_lightning import Callback, Trainer
 from tests import _PATH_LEGACY
-from tests.helpers import BoringModel

 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -100,7 +98,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):


 class OldStatefulCallback(Callback):
-
     def __init__(self, state):
         self.state = state

@@ -113,27 +110,3 @@ def on_save_checkpoint(self, *args):

     def on_load_checkpoint(self, trainer, pl_module, callback_state):
         self.state = callback_state["state"]
-
-
-def test_resume_callback_state_saved_by_type(tmpdir):
-    """ Test that a legacy checkpoint that didn't use a state identifier before can still be loaded. """
-    model = BoringModel()
-    callback = OldStatefulCallback(state=111)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=1,
-        callbacks=[callback],
-    )
-    trainer.fit(model)
-    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-    assert ckpt_path.exists()
-
-    callback = OldStatefulCallback(state=222)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=2,
-        callbacks=[callback],
-        resume_from_checkpoint=ckpt_path,
-    )
-    trainer.fit(model)
-    assert callback.state == 111

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 172e1276eb89e..11453d3f41e27 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -153,10 +153,10 @@ def on_validation_epoch_end(self):
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
-        assert mc_specific_data['dirpath'] == checkpoint.dirpath
-        assert mc_specific_data['monitor'] == monitor
-        assert mc_specific_data['current_score'] == score
+        mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}]"]
+        assert mc_specific_data["dirpath"] == checkpoint.dirpath
+        assert mc_specific_data["monitor"] == monitor
+        assert mc_specific_data["current_score"] == score

         if not reduce_lr_on_plateau:
             actual_step_count = chk["lr_schedulers"][0]["_step_count"]
@@ -264,10 +264,10 @@ def _make_assertions(epoch, ix, version=""):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk["global_step"] == expected_global_step

-        mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
-        assert mc_specific_data['dirpath'] == checkpoint.dirpath
-        assert mc_specific_data['monitor'] == monitor
-        assert mc_specific_data['current_score'] == score
+        mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}]"]
+        assert mc_specific_data["dirpath"] == checkpoint.dirpath
+        assert mc_specific_data["monitor"] == monitor
+        assert mc_specific_data["current_score"] == score

         if not reduce_lr_on_plateau:
             actual_step_count = chk["lr_schedulers"][0]["_step_count"]

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 103d0d3ccd6b2..c9ceaeaba50fd 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -59,7 +59,6 @@ def on_save_checkpoint(self, *args):

 class StatefulCallback1(Callback):
-
     def __init__(self, unique=None, other=None):
         self._unique = unique
         self._other = other
@@ -94,7 +93,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
             # callback1 and callback2 have the same type
             callback1,
             callback2,
-        ]
+        ],
     )
     trainer.fit(model)

From 2fa8ff904580bc4d9eeb51cd9ab1787eabcfd562 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Thu, 29 Jul 2021 00:25:34 +0200
Subject: [PATCH 27/51] fix test

---
 tests/trainer/logging_/test_logger_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 752d627982aea..518a401a72037 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -106,7 +106,6 @@ def test_fx_validator(tmpdir):
         "on_train_end",
         "on_validation_end",
         "setup",
-        "state_id",
         "teardown",
     ]

From ad9ba3ee21bfc746471beb75a5b1b5284dedcfa3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Thu, 29 Jul 2021 00:26:50 +0200
Subject: [PATCH 28/51] remove redundant change to logger connector

---
 .../trainer/connectors/logger_connector/fx_validator.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 8b449f0323429..f2ad8f1130993 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -65,7 +65,6 @@ class FxValidator:
         on_save_checkpoint=None,
         on_load_checkpoint=None,
         setup=None,
-        state_id=None,
         teardown=None,
         configure_sharded_model=None,
         training_step=dict(on_step=(False, True), on_epoch=(False, True)),

From c136662badc01ba7bbe1f8697de57fb804f8ca33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W%C3%A4lchli?=
Date: Thu, 29 Jul 2021 00:29:35 +0200
Subject: [PATCH 29/51] use helper function for get members

---
 tests/callbacks/test_lambda_function.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
index 3d64583b6ea9c..88752d56bf697 100644
--- a/tests/callbacks/test_lambda_function.py
+++ b/tests/callbacks/test_lambda_function.py
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect from functools import partial from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.helpers.boring_model import BoringModel +from tests.models.test_hooks import get_members def test_lambda_call(tmpdir): @@ -32,7 +32,7 @@ def on_train_epoch_start(self): def call(hook, *_, **__): checker.add(hook) - hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction) if not m.startswith("_")} + hooks = get_members(Callback) hooks_args = {h: partial(call, h) for h in hooks} hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")] From 7f9aa47f2ae72444fee2f9173680e1d785de23b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 00:29:46 +0200 Subject: [PATCH 30/51] update with master --- tests/checkpointing/test_legacy_checkpoints.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index e1af46e61cfec..8693965a52abc 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -17,7 +17,7 @@ import pytest -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Trainer from tests import _PATH_LEGACY LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints") CHECKPOINT_EXTENSION = ".ckpt" @@ -95,18 +95,3 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): # trainer.fit(model) sys.path = orig_sys_paths - - -class OldStatefulCallback(Callback): - def __init__(self, state): - self.state = state - - @property - def state_id(self): - return type(self) - - def on_save_checkpoint(self, *args): - return {"state": self.state} - - def on_load_checkpoint(self, trainer, pl_module, callback_state): - self.state = callback_state["state"] From c4a0f15f30033613005fa9826633111e3bbb3d85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 00:35:08 +0200 Subject: [PATCH 31/51] rename state_identifier -> state_id --- tests/callbacks/test_early_stopping.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 2 +- tests/trainer/connectors/test_callback_connector.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index ff7f640cd5a71..cad5315261817 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -33,7 +33,7 @@ _logger = logging.getLogger(__name__) -def test_early_stopping_state_identifier(): +def test_early_stopping_state_id(): early_stopping = EarlyStopping(monitor="val_loss") assert early_stopping.state_id == "EarlyStopping[monitor=val_loss]" class EarlyStoppingTestRestore(EarlyStopping): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 11453d3f41e27..02bbfe69c75b8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -43,7 +43,7 @@ from tests.helpers.runif import RunIf -def test_model_checkpoint_state_identifier(): +def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") assert early_stopping.state_id == "ModelCheckpoint[monitor=val_loss]" class LogInTwoMethods(BoringModel): diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index c9ceaeaba50fd..f40437ae8fd09 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++
b/tests/trainer/connectors/test_callback_connector.py @@ -64,7 +64,7 @@ def __init__(self, unique=None, other=None): self._other = other @property - def state_identifier(self): + def state_id(self): return self._generate_state_id(unique=self._unique) def on_save_checkpoint(self, *args): From f741bcd83494e017dbf23fe45add33db3ee6e2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 00:36:29 +0200 Subject: [PATCH 32/51] add changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a33cdd17031bd..d68ef0a7b03e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) -- +- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187)) - From ace9c1d24fc10523a349a1fb1b1017517bef8b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 01:42:15 +0200 Subject: [PATCH 33/51] add docs for persisting state --- docs/source/extensions/callbacks.rst | 57 ++++++++++++++++++++++++++-- pytorch_lightning/callbacks/base.py | 7 ++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 2c9ee612ceb22..58347d9cb28a2 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -119,10 +119,61 @@ Persisting State Some callbacks require internal state in order to function properly. You can optionally choose to persist your callback's state as part of model checkpoint files using the callback hooks :meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`. -However, you must follow two constraints: +Note that the returned state must be able to be pickled. -1. Your returned state must be able to be pickled. -2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class. +When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough +to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then +the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_id` property in order for Lightning +to be able to distinguish the different states when loading the callback state. This concept is best illustrated by +the following example. + +.. 
testcode:: + + class Counter(Callback): + def __init__(self, what="epochs", verbose=True): + self.what = what + self.verbose = verbose + self.state = {"epochs": 0, "batches": 0} + + @property + def state_id(self): + # note: we do not include `verbose` here on purpose + return self._generate_state_id(what=self.what) + + def on_train_epoch_end(self, *args, **kwargs): + if self.what == "epochs": + self.state["epochs"] += 1 + + def on_train_batch_end(self, *args, **kwargs): + if self.what == "batches": + self.state["batches"] += 1 + + def on_load_checkpoint(self, trainer, pl_module, callback_state): + self.state.update(callback_state) + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + return self.state.copy() + + + # two callbacks of the same type are being used + trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")]) + +A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information: + +.. code-block:: python + + { + "state_dict": ... + "callbacks": { + "Counter[what=batches]": {"batches": 32, "epochs": 0}, + "Counter[what=epochs]": {"batches": 0, "epochs": 2}, + ... + } + } + +The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_id` is essential here. If it were missing, +Lightning would not be able to disambiguate the state for these two callbacks, and :attr:`~pytorch_lightning.callbacks.Callback.state_id` +by default only defines the class name as the key, e.g., here ``Counter``. Best Practices diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index aaa6612626223..7295c73c37e8f 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -49,6 +49,13 @@ def _legacy_state_id(self) -> Type["Callback"]: return type(self) def _generate_state_id(self, **kwargs: Any) -> str: + """ + Formats a set of key-value pairs into a state id string by joining the string representation of each pair + into a comma separated list with the callback class name prefixed. Useful for defining a :attr:`state_id`. + + Args: + **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. 
+ """ attrs = ", ".join(f"{k}={v}" for k, v in kwargs.items()) identifier = f"{self.__class__.__qualname__}[{attrs}]" return identifier From 65747969ca1bee2a605278133ef2fad1d7aa3e26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 02:02:37 +0200 Subject: [PATCH 34/51] update test --- tests/checkpointing/test_model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 02bbfe69c75b8..375cb76cdbee3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -862,7 +862,6 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"] ckpt_id = "ModelCheckpoint[monitor=early_stop_on]" assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] From 87ce42080b34b7a19d3eaed683c9032b97701cf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 31 Jul 2021 12:31:42 +0200 Subject: [PATCH 35/51] repr string representation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/callbacks/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 7295c73c37e8f..1b7eff3eedd61 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -56,7 +56,7 @@ def _generate_state_id(self, **kwargs: Any) -> str: Args: **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. """ - attrs = ", ".join(f"{k}={v}" for k, v in kwargs.items()) + attrs = ", ".join(f"{repr(k)}={repr(v)}" for k, v in kwargs.items()) identifier = f"{self.__class__.__qualname__}[{attrs}]" return identifier From eac19216d77c2a23c82d109d3ae6b5ae5d303372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 31 Jul 2021 12:31:42 +0200 Subject: [PATCH 36/51] Revert "repr string representation" This reverts commit 87ce42080b34b7a19d3eaed683c9032b97701cf6. --- pytorch_lightning/callbacks/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 1b7eff3eedd61..7295c73c37e8f 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -56,7 +56,7 @@ def _generate_state_id(self, **kwargs: Any) -> str: Args: **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. 
""" - attrs = ", ".join(f"{repr(k)}={repr(v)}" for k, v in kwargs.items()) + attrs = ", ".join(f"{k}={v}" for k, v in kwargs.items()) identifier = f"{self.__class__.__qualname__}[{attrs}]" return identifier From 9023a1f01e86ead83446654e4d331e010e33d331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Aug 2021 13:31:59 +0200 Subject: [PATCH 37/51] add mode --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/callbacks/test_early_stopping.py | 6 ++++-- tests/checkpointing/test_model_checkpoint.py | 6 +++--- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f9e839ea770fb..f04b1d5a0bfe8 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -122,7 +122,7 @@ def __init__( @property def state_id(self) -> str: - return self._generate_state_id(monitor=self.monitor) + return self._generate_state_id(monitor=self.monitor, mode=self.mode) def _validate_condition_metric(self, logs): monitor_val = logs.get(self.monitor) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 23a24be915037..513fa9aa64abc 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -250,7 +250,7 @@ def __init__( @property def state_id(self) -> str: - return self._generate_state_id(monitor=self.monitor) + return self._generate_state_id(monitor=self.monitor, mode=self.mode) def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """ diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index cad5315261817..14e06417acc82 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -35,7 +35,7 @@ def test_early_stopping_state_id(): early_stopping = EarlyStopping(monitor="val_loss") - assert early_stopping.state_id == "EarlyStopping[monitor=val_loss]" + assert early_stopping.state_id == "EarlyStopping[monitor=val_loss, mode=min]" class EarlyStoppingTestRestore(EarlyStopping): @@ -82,7 +82,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): # the checkpoint saves "epoch + 1" early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] assert 4 == len(early_stop_callback.saved_states) - assert checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss]"] == early_stop_callback_state + assert ( + checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss, mode=min]"] == early_stop_callback_state + ) # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss") diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 375cb76cdbee3..b9dc6a626a02c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -45,7 +45,7 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") - assert early_stopping.state_id == "ModelCheckpoint[monitor=val_loss]" + assert early_stopping.state_id == "ModelCheckpoint[monitor=val_loss, mode=min]" class LogInTwoMethods(BoringModel): @@ -153,7 +153,7 @@ def on_validation_epoch_end(self): assert chk["epoch"] == epoch + 1 
assert chk["global_step"] == limit_train_batches * (epoch + 1) - mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}]"] + mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}, mode=min]"] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -1102,7 +1102,7 @@ def training_step(self, *args): trainer.fit(TestModel()) assert model_checkpoint.current_score == 0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] - ckpts = [ckpt["callbacks"]["ModelCheckpoint[monitor=foo]"] for ckpt in ckpts] + ckpts = [ckpt["callbacks"]["ModelCheckpoint[monitor=foo, mode=min]"] for ckpt in ckpts] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] From 76a81e389aa95d2bf9280b9ac5700885644f3ad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Aug 2021 14:21:02 +0200 Subject: [PATCH 38/51] update mode=min in test --- tests/checkpointing/test_model_checkpoint.py | 4 ++-- tests/trainer/connectors/test_callback_connector.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b9dc6a626a02c..574c1c6934375 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -264,7 +264,7 @@ def _make_assertions(epoch, ix, version=""): expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num) assert chk["global_step"] == expected_global_step - mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}]"] + mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}, mode=min]"] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -863,7 +863,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - ckpt_id = "ModelCheckpoint[monitor=early_stop_on]" + ckpt_id = "ModelCheckpoint[monitor=early_stop_on, mode=min]" assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 0fd82f6e10730..59eb2505764cc 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -106,7 +106,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content0" in state0 and state0["content0"] == 0 assert "content1" in state1 and state1["content1"] == "one" assert "content1" in state2 and state2["content1"] == "two" - assert "ModelCheckpoint[monitor=None]" in ckpt["callbacks"] + assert "ModelCheckpoint[monitor=None, mode=min]" in ckpt["callbacks"] def test_attach_model_callbacks(): From 5bb4a48f16d9ccc48d96772f95190b16f8301228 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Aug 2021 15:00:42 +0200 Subject: [PATCH 39/51] repr everywhere --- pytorch_lightning/callbacks/base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 
7295c73c37e8f..4c3360b6bb17a 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -50,15 +50,13 @@ def _legacy_state_id(self) -> Type["Callback"]: def _generate_state_id(self, **kwargs: Any) -> str: """ - Formats a set of key-value pairs into a state id string by joining the string representation of each pair - into a comma separated list with the callback class name prefixed. Useful for defining a :attr:`state_id`. + Formats a set of key-value pairs into a state id string with the callback class name prefixed. + Useful for defining a :attr:`state_id`. Args: **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. """ - attrs = ", ".join(f"{k}={v}" for k, v in kwargs.items()) - identifier = f"{self.__class__.__qualname__}[{attrs}]" - return identifier + return f"{self.__class__.__qualname__}{repr(kwargs)}" def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called before configure sharded model""" From 8314c62a396d76802b2429c7730e83be6a58f458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Aug 2021 15:20:58 +0200 Subject: [PATCH 40/51] adapt tests to repr --- tests/callbacks/test_early_stopping.py | 5 +++-- tests/checkpointing/test_model_checkpoint.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 14e06417acc82..729a3a5fe4976 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -35,7 +35,7 @@ def test_early_stopping_state_id(): early_stopping = EarlyStopping(monitor="val_loss") - assert early_stopping.state_id == "EarlyStopping[monitor=val_loss, mode=min]" + assert early_stopping.state_id == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" class EarlyStoppingTestRestore(EarlyStopping): @@ -83,7 +83,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] assert 4 == len(early_stop_callback.saved_states) assert ( - checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss, mode=min]"] == early_stop_callback_state + checkpoint["callbacks"]["EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"] + == early_stop_callback_state ) # ensure state is reloaded properly (assertion in the callback) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 574c1c6934375..db0e3dd6d9c2c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -45,7 +45,7 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") - assert early_stopping.state_id == "ModelCheckpoint[monitor=val_loss, mode=min]" + assert early_stopping.state_id == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min'}" class LogInTwoMethods(BoringModel): @@ -153,7 +153,7 @@ def on_validation_epoch_end(self): assert chk["epoch"] == epoch + 1 assert chk["global_step"] == limit_train_batches * (epoch + 1) - mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}, mode=min]"] + mc_specific_data = chk["callbacks"][f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min'}}"] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -264,7 +264,7 @@ def _make_assertions(epoch, 
ix, version=""): expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num) assert chk["global_step"] == expected_global_step - mc_specific_data = chk["callbacks"][f"ModelCheckpoint[monitor={monitor}, mode=min]"] + mc_specific_data = chk["callbacks"][f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min'}}"] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -863,7 +863,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - ckpt_id = "ModelCheckpoint[monitor=early_stop_on, mode=min]" + ckpt_id = "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min'}" assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors @@ -1102,7 +1102,7 @@ def training_step(self, *args): trainer.fit(TestModel()) assert model_checkpoint.current_score == 0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] - ckpts = [ckpt["callbacks"]["ModelCheckpoint[monitor=foo, mode=min]"] for ckpt in ckpts] + ckpts = [ckpt["callbacks"]["ModelCheckpoint{'monitor': 'foo', 'mode': 'min'}"] for ckpt in ckpts] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] From 6e7c4b38059a648f8d824a5fc0f4161e42ba378a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Aug 2021 15:40:29 +0200 Subject: [PATCH 41/51] adjust test --- tests/trainer/connectors/test_callback_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 59eb2505764cc..8921142122ac8 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -101,12 +101,12 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): ckpt = torch.load(str(tmpdir / "all_states.ckpt")) state0 = ckpt["callbacks"]["StatefulCallback0"] - state1 = ckpt["callbacks"]["StatefulCallback1[unique=one]"] - state2 = ckpt["callbacks"]["StatefulCallback1[unique=two]"] + state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"] + state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"] assert "content0" in state0 and state0["content0"] == 0 assert "content1" in state1 and state1["content1"] == "one" assert "content1" in state2 and state2["content1"] == "two" - assert "ModelCheckpoint[monitor=None, mode=min]" in ckpt["callbacks"] + assert "ModelCheckpoint{'monitor': None, 'mode': 'min'}" in ckpt["callbacks"] def test_attach_model_callbacks(): From 8ea9deafb76d60de51eda097fe86ec45a18bcc47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 13 Aug 2021 02:18:16 +0200 Subject: [PATCH 42/51] add every_n_train_steps, every_n_epochs, train_time_interval --- .../callbacks/model_checkpoint.py | 8 ++++++- tests/checkpointing/test_model_checkpoint.py | 22 ++++++++++++++----- .../connectors/test_callback_connector.py | 5 ++++- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 513fa9aa64abc..bde0f51c5861f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -250,7 +250,13 @@ def __init__( @property def state_id(self) -> str: - return self._generate_state_id(monitor=self.monitor, mode=self.mode) + return self._generate_state_id( + monitor=self.monitor, + mode=self.mode, + every_n_train_steps=self._every_n_train_steps, + every_n_epochs=self._every_n_epochs, + train_time_interval=self._train_time_interval, + ) def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """ diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index db0e3dd6d9c2c..f6c67c06c8ea7 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -45,7 +45,10 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") - assert early_stopping.state_id == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min'}" + assert ( + early_stopping.state_id + == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + ) class LogInTwoMethods(BoringModel): @@ -153,7 +156,9 @@ def on_validation_epoch_end(self): assert chk["epoch"] == epoch + 1 assert chk["global_step"] == limit_train_batches * (epoch + 1) - mc_specific_data = chk["callbacks"][f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min'}}"] + mc_specific_data = chk["callbacks"][ + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}}" + ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -264,7 +269,9 @@ def _make_assertions(epoch, ix, version=""): expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num) assert chk["global_step"] == expected_global_step - mc_specific_data = chk["callbacks"][f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min'}}"] + mc_specific_data = chk["callbacks"][ + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}}" + ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -863,7 +870,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - ckpt_id = "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min'}" + ckpt_id = "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors @@ -1102,7 +1109,12 @@ def training_step(self, *args): trainer.fit(TestModel()) assert model_checkpoint.current_score == 0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] - ckpts = [ckpt["callbacks"]["ModelCheckpoint{'monitor': 'foo', 'mode': 'min'}"] for ckpt in ckpts] + ckpts = [ + ckpt["callbacks"][ + "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + ] + for ckpt in ckpts + ] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 
0.3] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 8921142122ac8..2d195fff4872c 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -106,7 +106,10 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content0" in state0 and state0["content0"] == 0 assert "content1" in state1 and state1["content1"] == "one" assert "content1" in state2 and state2["content1"] == "two" - assert "ModelCheckpoint{'monitor': None, 'mode': 'min'}" in ckpt["callbacks"] + assert ( + "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + in ckpt["callbacks"] + ) def test_attach_model_callbacks(): From 62864ec2aaf59a3e2bffa4c59235d610a01427c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 13 Aug 2021 02:33:29 +0200 Subject: [PATCH 43/51] update docs --- docs/source/extensions/callbacks.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 58347d9cb28a2..0b9bc9159f541 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -163,10 +163,10 @@ A Lightning checkpoint from this Trainer with the two stateful callbacks will in .. code-block:: python { - "state_dict": ... + "state_dict": ..., "callbacks": { - "Counter[what=batches]": {"batches": 32, "epochs": 0}, - "Counter[what=epochs]": {"batches": 0, "epochs": 2}, + "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0}, + "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2}, ... } } From 15a6492170d90b9e0a55d3742e36d9c5f972ca4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 13 Aug 2021 13:10:04 +0200 Subject: [PATCH 44/51] add save_on_train_epoch_end --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + tests/checkpointing/test_model_checkpoint.py | 10 +++++----- tests/trainer/connectors/test_callback_connector.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index bde0f51c5861f..1964f7d910176 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -256,6 +256,7 @@ def state_id(self) -> str: every_n_train_steps=self._every_n_train_steps, every_n_epochs=self._every_n_epochs, train_time_interval=self._train_time_interval, + save_on_train_epoch_end=self._save_on_train_epoch_end, ) def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f6c67c06c8ea7..b71df2b635b6c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -47,7 +47,7 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") assert ( early_stopping.state_id - == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': None}" ) @@ -157,7 +157,7 @@ def on_validation_epoch_end(self):
assert chk["global_step"] == limit_train_batches * (epoch + 1) mc_specific_data = chk["callbacks"][ - f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}}" + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -270,7 +270,7 @@ def _make_assertions(epoch, ix, version=""): assert chk["global_step"] == expected_global_step mc_specific_data = chk["callbacks"][ - f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}}" + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': False}}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -870,7 +870,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - ckpt_id = "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + ckpt_id = "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}" assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors @@ -1111,7 +1111,7 @@ def training_step(self, *args): ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] ckpts = [ ckpt["callbacks"][ - "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] for ckpt in ckpts ] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 2d195fff4872c..9485e614162b9 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -107,7 +107,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content1" in state1 and state1["content1"] == "one" assert "content1" in state2 and state2["content1"] == "two" assert ( - "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}" + "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] ) From 49f376b4f05b67d40595496c9be35d7f789eea85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 13 Aug 2021 13:32:17 +0200 Subject: [PATCH 45/51] update docs with tip --- docs/source/extensions/callbacks.rst | 2 ++ pytorch_lightning/callbacks/early_stopping.py | 7 +++++++ pytorch_lightning/callbacks/model_checkpoint.py | 6 ++++++ 3 files changed, 15 insertions(+) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 
0b9bc9159f541..c770f53348346 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -113,6 +113,8 @@ Lightning has a few built-in callbacks. ---------- +.. _Persisting Callback State: + Persisting State ---------------- diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f04b1d5a0bfe8..1b43ab373b033 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -75,6 +75,13 @@ class EarlyStopping(Callback): >>> from pytorch_lightning.callbacks import EarlyStopping >>> early_stopping = EarlyStopping('val_loss') >>> trainer = Trainer(callbacks=[early_stopping]) + + .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported, provided the + instances differ in the following arguments: + + *monitor, mode* + + Read more: :ref:`Persisting Callback State` """ mode_dict = {"min": torch.lt, "max": torch.gt} diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1964f7d910176..a9a14119e6c91 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -194,6 +194,12 @@ class ModelCheckpoint(Callback): trainer.fit(model) checkpoint_callback.best_model_path + + .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported, provided the + instances differ in the following arguments: + + *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end* + + Read more: :ref:`Persisting Callback State` """ CHECKPOINT_JOIN_CHAR = "-" From 77e7027248b6f6cf3f7503c2b804c55a59a7c180 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 13 Aug 2021 13:46:36 +0200 Subject: [PATCH 46/51] black formatting with line breaks --- tests/checkpointing/test_model_checkpoint.py | 17 ++++++++++++----- .../connectors/test_callback_connector.py | 3 ++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b71df2b635b6c..9723dd46b712d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -47,7 +47,8 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") assert ( early_stopping.state_id - == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': None}" + == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ) @@ -157,7 +158,8 @@ def on_validation_epoch_end(self): assert chk["global_step"] == limit_train_batches * (epoch + 1) mc_specific_data = chk["callbacks"][ - f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}}" + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -270,7 +272,8 @@ def _make_assertions(epoch, ix, version=""): assert chk["global_step"] == expected_global_step mc_specific_data = chk["callbacks"][ -
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': False}}" + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': False}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -870,7 +873,10 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - ckpt_id = "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}" + ckpt_id = ( + "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + ) assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors @@ -1111,7 +1117,8 @@ def training_step(self, *args): ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] ckpts = [ ckpt["callbacks"][ - "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}" + "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] for ckpt in ckpts ] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 9485e614162b9..d424eedb3ff32 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -107,7 +107,8 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content1" in state1 and state1["content1"] == "one" assert "content1" in state2 and state2["content1"] == "two" assert ( - "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}" + "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] ) From 333b1b36412e1cdf301f71a51bb3b188bc1529c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Aug 2021 11:48:17 +0000 Subject: [PATCH 47/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/connectors/test_callback_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index d424eedb3ff32..010a476ff82ed 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -108,8 +108,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content1" in state2 and state2["content1"] == "two" assert ( "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 
'save_on_train_epoch_end': True}" - in ckpt["callbacks"] + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] ) From f0cc8b448db2b5be65595e2aab21d61006b04bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 16 Aug 2021 10:09:05 +0200 Subject: [PATCH 48/51] Apply suggestions from code review Co-authored-by: Jirka Borovec --- tests/callbacks/test_early_stopping.py | 6 ++---- tests/checkpointing/test_model_checkpoint.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 729a3a5fe4976..69a3718c31f29 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -82,10 +82,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): # the checkpoint saves "epoch + 1" early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] assert 4 == len(early_stop_callback.saved_states) - assert ( - checkpoint["callbacks"]["EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"] - == early_stop_callback_state - ) + es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}" + assert checkpoint["callbacks"][es_name] == early_stop_callback_state # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss") diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 9723dd46b712d..1cbf3d5484395 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -45,11 +45,9 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") - assert ( - early_stopping.state_id - == "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + expected_id = "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': None}" - ) + assert early_stopping.state_id == expected_id class LogInTwoMethods(BoringModel): From 049857770f28da2eef48af303634ff9aeab61fb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 16 Aug 2021 10:17:22 +0200 Subject: [PATCH 49/51] fix syntax error --- tests/checkpointing/test_model_checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 1cbf3d5484395..31c28b075322e 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -45,8 +45,10 @@ def test_model_checkpoint_state_id(): early_stopping = ModelCheckpoint(monitor="val_loss") - expected_id = "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + expected_id = ( + "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': None}" + ) assert early_stopping.state_id == expected_id From 9260725d84acae74cdd8dd1c36251f2a5a43d845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 19 Aug 2021 13:18:35 +0200 Subject: [PATCH 50/51] rename state_id -> state_key --- CHANGELOG.md | 2 +- docs/source/extensions/callbacks.rst | 10 +++++----- 
pytorch_lightning/callbacks/base.py | 16 ++++++++-------- pytorch_lightning/callbacks/early_stopping.py | 4 ++-- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/trainer/callback_hook.py | 4 ++-- tests/callbacks/test_callbacks.py | 4 ++-- tests/callbacks/test_early_stopping.py | 4 ++-- tests/checkpointing/test_model_checkpoint.py | 4 ++-- .../connectors/test_callback_connector.py | 4 ++-- 10 files changed, 28 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff043f8c6baa4..8a833e658eab1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807)) -- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) +- Added `state_key` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) - Progress tracking diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index c770f53348346..a7195ac248742 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -125,7 +125,7 @@ Note that the returned state must be able to be pickled. When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then -the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_id` property in order for Lightning +the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_key` property in order for Lightning to be able to distinguish the different states when loading the callback state. This concept is best illustrated by the following example. @@ -138,9 +138,9 @@ the following example. self.state = {"epochs": 0, "batches": 0} @property - def state_id(self): + def state_key(self): # note: we do not include `verbose` here on purpose - return self._generate_state_id(what=self.what) + return self._generate_state_key(what=self.what) def on_train_epoch_end(self, *args, **kwargs): if self.what == "epochs": @@ -173,8 +173,8 @@ A Lightning checkpoint from this Trainer with the two stateful callbacks will in } } -The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_id` is essential here. If it were missing, -Lightning would not be able to disambiguate the state for these two callbacks, and :attr:`~pytorch_lightning.callbacks.Callback.state_id` +The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_key` is essential here. If it were missing, +Lightning would not be able to disambiguate the state for these two callbacks, and :attr:`~pytorch_lightning.callbacks.Callback.state_key` by default only defines the class name as the key, e.g., here ``Counter``. diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 4c3360b6bb17a..fdb22a44ed307 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -34,24 +34,24 @@ class Callback(abc.ABC): """ @property - def state_id(self) -> str: + def state_key(self) -> str: """ Identifier for the state of the callback. 
Used to store and retrieve a callback's state from the - checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to - provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of + checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to + provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback. """ return self.__class__.__qualname__ @property - def _legacy_state_id(self) -> Type["Callback"]: - """State identifier for checkpoints saved prior to version 1.5.0.""" + def _legacy_state_key(self) -> Type["Callback"]: + """State key for checkpoints saved prior to version 1.5.0.""" return type(self) - def _generate_state_id(self, **kwargs: Any) -> str: + def _generate_state_key(self, **kwargs: Any) -> str: """ - Formats a set of key-value pairs into a state id string with the callback class name prefixed. - Useful for defining a :attr:`state_id`. + Formats a set of key-value pairs into a state key string with the callback class name prefixed. + Useful for defining a :attr:`state_key`. Args: **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f44acb85c05bf..b6ab4d2d9dacd 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -128,8 +128,8 @@ def __init__( self.monitor = monitor or "early_stop_on" @property - def state_id(self) -> str: - return self._generate_state_id(monitor=self.monitor, mode=self.mode) + def state_key(self) -> str: + return self._generate_state_key(monitor=self.monitor, mode=self.mode) def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._check_on_train_epoch_end is None: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4a618fc7b96ed..414a92af6a66c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -255,8 +255,8 @@ def __init__( self.__validate_init_configuration() @property - def state_id(self) -> str: - return self._generate_state_id( + def state_key(self) -> str: + return self._generate_state_key( monitor=self.monitor, mode=self.mode, every_n_train_steps=self._every_n_train_steps, diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5aac1acb6c572..472497ef2b376 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -261,7 +261,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: else: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) if state: - callback_states[callback.state_id] = state + callback_states[callback.state_key] = state return callback_states def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: @@ -286,7 +286,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ) for callback in self.callbacks: - state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id)) + state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key)) if state: state = deepcopy(state) if 
self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index d190feed7e1f7..dc191f4853cc1 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -109,7 +109,7 @@ def __init__(self, state): self.state = state @property - def state_id(self): + def state_key(self): return type(self) def on_save_checkpoint(self, *args): @@ -120,7 +120,7 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state): def test_resume_callback_state_saved_by_type(tmpdir): - """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded.""" + """Test that a legacy checkpoint that didn't use a state key before can still be loaded.""" model = BoringModel() callback = OldStatefulCallback(state=111) trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index ac71a590276cb..ad343cdf329f5 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -33,9 +33,9 @@ _logger = logging.getLogger(__name__) -def test_early_stopping_state_id(): +def test_early_stopping_state_key(): early_stopping = EarlyStopping(monitor="val_loss") - assert early_stopping.state_id == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" + assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" class EarlyStoppingTestRestore(EarlyStopping): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 7ab40b2948284..f49fa16598fd2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -43,13 +43,13 @@ from tests.helpers.runif import RunIf -def test_model_checkpoint_state_id(): +def test_model_checkpoint_state_key(): early_stopping = ModelCheckpoint(monitor="val_loss") expected_id = ( "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ) - assert early_stopping.state_id == expected_id + assert early_stopping.state_key == expected_id class LogInTwoMethods(BoringModel): diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 010a476ff82ed..66cad1d613955 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -66,8 +66,8 @@ def __init__(self, unique=None, other=None): self._other = other @property - def state_id(self): - return self._generate_state_id(unique=self._unique) + def state_key(self): + return self._generate_state_key(unique=self._unique) def on_save_checkpoint(self, *args): return {"content1": self._unique} From ebb2a73be0c6d4137595c479dbfaa22b4b152d55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 19 Aug 2021 13:24:46 +0200 Subject: [PATCH 51/51] fix blacken-docs precommit complaints --- docs/source/extensions/callbacks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index a7195ac248742..b007fd479b0d0 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -162,7 +162,7 @@ the following example. 
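# Illustrative sketch, not part of the patch series: after the rename in this
# commit, a stateful callback distinguishes multiple instances through the
# state_key property, mirroring the Counter example from the documentation:
from pytorch_lightning.callbacks import Callback


class Counter(Callback):
    def __init__(self, what="epochs", verbose=True):
        self.what = what
        self.verbose = verbose

    @property
    def state_key(self):
        # `verbose` is excluded on purpose: it does not affect the saved state.
        # _generate_state_key formats the kwargs with their repr, producing
        # keys such as "Counter{'what': 'epochs'}".
        return self._generate_state_key(what=self.what)


# On load, the trainer prefers this key and falls back to the legacy type-based
# key used by checkpoints written before v1.5.0, as in the hunk above:
#     state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))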
A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information: -.. code-block:: python +.. code-block:: { "state_dict": ...,
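# Illustrative sketch, not part of the patch series: with callback state stored
# under these string keys, a saved checkpoint can be inspected directly. The
# file name below is a hypothetical example.
import torch

ckpt = torch.load("example.ckpt", map_location="cpu")
for state_key, state in ckpt.get("callbacks", {}).items():
    # e.g. "Counter{'what': 'batches'}" -> {"batches": 32, "epochs": 0}
    print(state_key, state)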