From 5ac921e9bba8a85ac9ee215a60042af0b1758e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Aug 2021 14:08:42 +0200 Subject: [PATCH 1/9] fix warning --- pytorch_lightning/trainer/callback_hook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 36a3e9abb7b7a..3d0d6e507745e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -255,14 +255,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if callback_states is None: return - current_callbacks_type = {type(cb) for cb in self.callbacks} + current_callbacks_type = {cb.state_id for cb in self.callbacks} saved_callbacks_type = set(callback_states.keys()) difference = saved_callbacks_type.difference(current_callbacks_type) if difference: rank_zero_warn( - "Be aware that when using ``resume_from_checkpoint``, " - "callbacks used to create the checkpoint need to be provided. " - f"Please, add the following callbacks: {list(difference)}. ", + "Be aware that when using ``resume_from_checkpoint``," + " callbacks used to create the checkpoint need to be provided." + f" Please, add the following callbacks: {list(difference)}.", UserWarning, ) From 8ad08ca9ea8b04c9fcf1fd1d46ea881f2e1708e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 22:32:44 +0200 Subject: [PATCH 2/9] remove save_on_train_epoch_end from state_key --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- tests/checkpointing/test_model_checkpoint.py | 10 +++++----- tests/trainer/connectors/test_callback_connector.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 414a92af6a66c..0bdc7edf86923 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -197,7 +197,7 @@ class ModelCheckpoint(Callback): .. 
tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the following arguments: - *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end* + *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval* Read more: :ref:`Persisting Callback State` """ @@ -262,7 +262,6 @@ def state_key(self) -> str: every_n_train_steps=self._every_n_train_steps, every_n_epochs=self._every_n_epochs, train_time_interval=self._train_time_interval, - save_on_train_epoch_end=self._save_on_train_epoch_end, ) def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f49fa16598fd2..259ff2a3d9c4f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -47,7 +47,7 @@ def test_model_checkpoint_state_key(): early_stopping = ModelCheckpoint(monitor="val_loss") expected_id = ( "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': None}" + " 'train_time_interval': None}" ) assert early_stopping.state_key == expected_id @@ -159,7 +159,7 @@ def on_validation_epoch_end(self): mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + " 'train_time_interval': None}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -273,7 +273,7 @@ def _make_assertions(epoch, ix, version=""): mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': False}" + " 'train_time_interval': None}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -875,7 +875,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_id = ( "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + " 'train_time_interval': None}" ) assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] @@ -1118,7 +1118,7 @@ def training_step(self, *args): ckpts = [ ckpt["callbacks"][ "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + " 'train_time_interval': None}" ] for ckpt in ckpts ] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 455e08dc10ad5..ce6c1c544784f 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -111,7 +111,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content1" in state2 and state2["content1"] == "two" assert ( "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] + " 'train_time_interval': None}" in ckpt["callbacks"] ) From 
e775881fae27ecb4f312b8946474c3b5c3661609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 12:13:26 +0200 Subject: [PATCH 3/9] update legacy checkpoint handling --- pytorch_lightning/trainer/callback_hook.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 3d0d6e507745e..be5c05c26bfac 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import operator from abc import ABC from copy import deepcopy from typing import Any, Dict, List, Optional, Type, Union import torch +import pytorch_lightning import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.imports import _compare_version from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -255,12 +257,13 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if callback_states is None: return - current_callbacks_type = {cb.state_id for cb in self.callbacks} - saved_callbacks_type = set(callback_states.keys()) - difference = saved_callbacks_type.difference(current_callbacks_type) + is_legacy_ckpt = operator.lt(checkpoint["pytorch-lightning_version"], "1.5") + current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks} + saved_callbacks_keys = set(callback_states.keys()) + difference = saved_callbacks_keys.difference(current_callbacks_keys) if difference: rank_zero_warn( - "Be aware that when using ``resume_from_checkpoint``," + "Be aware that when using `resume_from_checkpoint`," " callbacks used to create the checkpoint need to be provided." f" Please, add the following callbacks: {list(difference)}.", UserWarning, From 0d820745ea9bbf6d6f0d8c99db45d292efaaa721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 15:26:45 +0200 Subject: [PATCH 4/9] add test --- pytorch_lightning/trainer/callback_hook.py | 2 +- tests/callbacks/test_callbacks.py | 36 ++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index be5c05c26bfac..2c8b2faba7943 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -265,7 +265,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: rank_zero_warn( "Be aware that when using `resume_from_checkpoint`," " callbacks used to create the checkpoint need to be provided." - f" Please, add the following callbacks: {list(difference)}.", + f" Please add the following callbacks: {list(difference)}.", UserWarning, ) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index dc191f4853cc1..6e8bb1fd3594b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pathlib import Path +from re import escape from unittest.mock import call, Mock +import pytest + from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from tests.helpers import BoringModel +from tests.helpers.utils import no_warning_call def test_callbacks_configured_in_model(tmpdir): @@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) trainer.fit(model) assert callback.state == 111 + + +def test_resume_incomplete_callbacks_list_warning(tmpdir): + model = BoringModel() + callback0 = ModelCheckpoint(monitor="epoch") + callback1 = ModelCheckpoint(monitor="global_step") + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback0, callback1], + ) + trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback1], # one callback is missing! + resume_from_checkpoint=ckpt_path, + ) + with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")): + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + callbacks=[callback1, callback0], # all callbacks here, order switched + resume_from_checkpoint=ckpt_path, + ) + with no_warning_call(UserWarning, match=escape(f"Please add the following callbacks:")): + trainer.fit(model) From 9ddb21160ddc8ecedbb37df02eb44b4f9c251960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 15:37:22 +0200 Subject: [PATCH 5/9] fix precommit --- pytorch_lightning/trainer/callback_hook.py | 4 +--- tests/callbacks/test_callbacks.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 2c8b2faba7943..31c1dff89502c 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -18,11 +18,9 @@ import torch -import pytorch_lightning import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn -from pytorch_lightning.utilities.imports import _compare_version +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 6e8bb1fd3594b..c363638d565d2 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -18,7 +18,7 @@ import pytest from pytorch_lightning import Callback, Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.callbacks import ModelCheckpoint from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call @@ -166,5 +166,5 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir): callbacks=[callback1, callback0], # all callbacks here, order switched resume_from_checkpoint=ckpt_path, ) - with no_warning_call(UserWarning, match=escape(f"Please add the following callbacks:")): + with no_warning_call(UserWarning, match="Please add the following callbacks:"): trainer.fit(model) From 8137f4ef1e8a5cc787fe19998b769cffd343eefd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 02:45:21 +0200 Subject: 
[PATCH 6/9] Update pytorch_lightning/trainer/callback_hook.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/callback_hook.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 31c1dff89502c..c4ac1bebff71e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -257,8 +257,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: is_legacy_ckpt = operator.lt(checkpoint["pytorch-lightning_version"], "1.5") current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks} - saved_callbacks_keys = set(callback_states.keys()) - difference = saved_callbacks_keys.difference(current_callbacks_keys) + difference = callback_states.keys() - current_callbacks_keys if difference: rank_zero_warn( "Be aware that when using `resume_from_checkpoint`," From 27789fb54deac39ce60d9f276a602503331a916c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 02:58:29 +0200 Subject: [PATCH 7/9] update version parsing --- pytorch_lightning/trainer/callback_hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index c4ac1bebff71e..bbfcbb22802a8 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import operator from abc import ABC from copy import deepcopy from typing import Any, Dict, List, Optional, Type, Union import torch +from packaging.version import Version import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback @@ -255,7 +255,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if callback_states is None: return - is_legacy_ckpt = operator.lt(checkpoint["pytorch-lightning_version"], "1.5") + is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev") current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks} difference = callback_states.keys() - current_callbacks_keys if difference: From 48c675980a3d77ea410f3346aba92d751949708d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 22:32:44 +0200 Subject: [PATCH 8/9] Revert "remove save_on_train_epoch_end from state_key" This reverts commit 8ad08ca9ea8b04c9fcf1fd1d46ea881f2e1708e2. --- pytorch_lightning/callbacks/model_checkpoint.py | 3 ++- tests/checkpointing/test_model_checkpoint.py | 10 +++++----- tests/trainer/connectors/test_callback_connector.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0bdc7edf86923..414a92af6a66c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -197,7 +197,7 @@ class ModelCheckpoint(Callback): .. 
tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the following arguments: - *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval* + *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end* Read more: :ref:`Persisting Callback State` """ @@ -262,6 +262,7 @@ def state_key(self) -> str: every_n_train_steps=self._every_n_train_steps, every_n_epochs=self._every_n_epochs, train_time_interval=self._train_time_interval, + save_on_train_epoch_end=self._save_on_train_epoch_end, ) def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 259ff2a3d9c4f..f49fa16598fd2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -47,7 +47,7 @@ def test_model_checkpoint_state_key(): early_stopping = ModelCheckpoint(monitor="val_loss") expected_id = ( "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None}" + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ) assert early_stopping.state_key == expected_id @@ -159,7 +159,7 @@ def on_validation_epoch_end(self): mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None}" + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -273,7 +273,7 @@ def _make_assertions(epoch, ix, version=""): mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None}" + " 'train_time_interval': None, 'save_on_train_epoch_end': False}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -875,7 +875,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_id = ( "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None}" + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ) assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] @@ -1118,7 +1118,7 @@ def training_step(self, *args): ckpts = [ ckpt["callbacks"][ "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None}" + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" ] for ckpt in ckpts ] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index ce6c1c544784f..455e08dc10ad5 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -111,7 +111,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content1" in state2 and state2["content1"] == "two" assert ( "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None}" in ckpt["callbacks"] + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] ) From 
28ac360c20666fe883ed11e57dca95e5e4002b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 03:24:33 +0200 Subject: [PATCH 9/9] move the val_check_interval query to on_init_end --- pytorch_lightning/callbacks/model_checkpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 414a92af6a66c..e7daa4ee53cde 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -265,16 +265,16 @@ def state_key(self) -> str: save_on_train_epoch_end=self._save_on_train_epoch_end, ) - def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """ - When pretrain routine starts we build the ckpt dir on the fly - """ - self.__resolve_ckpt_dir(trainer) + def on_init_end(self, trainer: "pl.Trainer") -> None: if self._save_on_train_epoch_end is None: # if the user runs validation multiple times per training epoch, we try to save checkpoint after # validation instead of on train epoch end self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 + def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """When pretrain routine starts we build the ckpt dir on the fly.""" + self.__resolve_ckpt_dir(trainer) + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._last_time_checked = time.monotonic()
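
Taken together, patches 3, 6 and 7 leave `on_load_checkpoint` comparing the keys stored in the checkpoint against either `Callback.state_key` (checkpoints written by 1.5+) or `Callback._legacy_state_key` (older checkpoints, which keyed callback states by type). Below is a minimal standalone sketch of that matching logic, useful for reasoning about the warning introduced in patch 1 and exercised by the test in patch 4. It is illustrative only: `IllustrativeCallback`, `warn_on_missing_callbacks`, the abbreviated state-key strings and the use of `warnings.warn` in place of `rank_zero_warn` are stand-ins invented for the sketch, not Lightning APIs; only the `packaging` version comparison, the legacy-key selection and the key-set difference mirror the patched code in pytorch_lightning/trainer/callback_hook.py.

import warnings
from typing import Any, Dict, Iterable

from packaging.version import Version


class IllustrativeCallback:
    """Stand-in exposing the two identifiers the Trainer hook chooses between."""

    def __init__(self, state_key: str) -> None:
        self._state_key = state_key

    @property
    def state_key(self) -> str:
        # 1.5+ checkpoints key callback states by this string (class name plus init kwargs).
        return self._state_key

    @property
    def _legacy_state_key(self) -> type:
        # Pre-1.5 checkpoints keyed callback states by the callback's type.
        return type(self)


def warn_on_missing_callbacks(checkpoint: Dict[str, Any], callbacks: Iterable) -> None:
    callback_states = checkpoint.get("callbacks")
    if callback_states is None:
        return
    # Checkpoints written before 1.5 used type(cb) as the key, so compare against the
    # legacy key for those and against state_key otherwise (patches 3 and 7).
    is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
    current_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in callbacks}
    difference = callback_states.keys() - current_keys  # set difference, as in patch 6
    if difference:
        warnings.warn(
            "Be aware that when using `resume_from_checkpoint`,"
            " callbacks used to create the checkpoint need to be provided."
            f" Please add the following callbacks: {list(difference)}.",
            UserWarning,
        )


if __name__ == "__main__":
    checkpoint = {
        "pytorch-lightning_version": "1.5.0",
        "callbacks": {
            "ModelCheckpoint{'monitor': 'epoch', ...}": {},
            "ModelCheckpoint{'monitor': 'global_step', ...}": {},
        },
    }
    partial = [IllustrativeCallback("ModelCheckpoint{'monitor': 'global_step', ...}")]
    complete = partial + [IllustrativeCallback("ModelCheckpoint{'monitor': 'epoch', ...}")]
    warn_on_missing_callbacks(checkpoint, partial)   # warns: one saved key has no matching callback
    warn_on_missing_callbacks(checkpoint, complete)  # silent: only key membership matters, not order

Run as a script, the first call warns because only one of the two saved ModelCheckpoint-style keys has a matching callback, while the second call stays silent even though the callbacks are supplied in a different order than they were saved, which is the same behaviour the two Trainer runs in the patch-4 test assert with pytest.warns and no_warning_call.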