diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 414a92af6a66c..e7daa4ee53cde 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -265,16 +265,16 @@ def state_key(self) -> str:
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """
-        When pretrain routine starts we build the ckpt dir on the fly
-        """
-        self.__resolve_ckpt_dir(trainer)
+    def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch, we try to save checkpoint after
             # validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
 
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
+        self.__resolve_ckpt_dir(trainer)
+
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._last_time_checked = time.monotonic()
 
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 36a3e9abb7b7a..bbfcbb22802a8 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from abc import ABC
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Type, Union
 
 import torch
+from packaging.version import Version
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
@@ -255,14 +255,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         if callback_states is None:
             return
 
-        current_callbacks_type = {type(cb) for cb in self.callbacks}
-        saved_callbacks_type = set(callback_states.keys())
-        difference = saved_callbacks_type.difference(current_callbacks_type)
+        is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
+        current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
+        difference = callback_states.keys() - current_callbacks_keys
         if difference:
             rank_zero_warn(
-                "Be aware that when using ``resume_from_checkpoint``, "
-                "callbacks used to create the checkpoint need to be provided. "
-                f"Please, add the following callbacks: {list(difference)}. ",
+                "Be aware that when using `resume_from_checkpoint`,"
+                " callbacks used to create the checkpoint need to be provided."
+                f" Please add the following callbacks: {list(difference)}.",
                 UserWarning,
             )
 
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index dc191f4853cc1..c363638d565d2 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pathlib import Path
+from re import escape
 from unittest.mock import call, Mock
 
+import pytest
+
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
 
 
 def test_callbacks_configured_in_model(tmpdir):
@@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
     trainer.fit(model)
     assert callback.state == 111
+
+
+def test_resume_incomplete_callbacks_list_warning(tmpdir):
+    model = BoringModel()
+    callback0 = ModelCheckpoint(monitor="epoch")
+    callback1 = ModelCheckpoint(monitor="global_step")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback0, callback1],
+    )
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback1],  # one callback is missing!
+        resume_from_checkpoint=ckpt_path,
+    )
+    with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")):
+        trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback1, callback0],  # all callbacks here, order switched
+        resume_from_checkpoint=ckpt_path,
+    )
+    with no_warning_call(UserWarning, match="Please add the following callbacks:"):
+        trainer.fit(model)
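
Context for the change above: before this patch, on_load_checkpoint matched saved callback states to the running callbacks by type, so two instances of the same callback class could not be told apart on resume. Matching by Callback.state_key, which folds the callback's __init__ options into the key, makes the match instance-aware; _legacy_state_key (the callback's type) is kept so checkpoints written before 1.5.0dev still resolve. A minimal sketch of the distinction, runnable on its own (the exact key string printed is illustrative, not guaranteed):

    from pytorch_lightning.callbacks import ModelCheckpoint

    cb0 = ModelCheckpoint(monitor="epoch")
    cb1 = ModelCheckpoint(monitor="global_step")

    # Pre-1.5 matching: both instances collapse to the same key, type(cb).
    assert type(cb0) is type(cb1)

    # state_key matching: ModelCheckpoint.state_key embeds monitor, mode and
    # the other __init__ options (see the state_key context in the first hunk),
    # so each instance keeps its own entry in checkpoint["callbacks"].
    assert cb0.state_key != cb1.state_key
    print(cb0.state_key)  # e.g. "ModelCheckpoint{'monitor': 'epoch', ...}"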