10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -265,16 +265,16 @@ def state_key(self) -> str:
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )

-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """
-        When pretrain routine starts we build the ckpt dir on the fly
-        """
-        self.__resolve_ckpt_dir(trainer)
+    def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch, we try to save checkpoint after
             # validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0

+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
+        self.__resolve_ckpt_dir(trainer)
+
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._last_time_checked = time.monotonic()
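For context on the moved logic: `on_init_end` fires as soon as the `Trainer` is constructed, so the callback can resolve `save_on_train_epoch_end` from `val_check_interval` before the pretrain routine runs. The rule is small enough to state as a standalone sketch; the function name and the assert-style usage below are illustrative, not part of the Lightning API:

from typing import Optional

def resolve_save_on_train_epoch_end(user_choice: Optional[bool], val_check_interval: float) -> bool:
    # An explicit user setting always wins.
    if user_choice is not None:
        return user_choice
    # val_check_interval == 1.0 means validation runs once, at the end of each
    # training epoch, so saving on train-epoch end is safe; any smaller value
    # means mid-epoch validation, so save after validation instead.
    return val_check_interval == 1.0

assert resolve_save_on_train_epoch_end(None, 1.0) is True    # default cadence
assert resolve_save_on_train_epoch_end(None, 0.25) is False  # validate 4x per epoch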
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/callback_hook.py
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from abc import ABC
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Type, Union

-import torch
+from packaging.version import Version

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
@@ -255,14 +255,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         if callback_states is None:
             return

-        current_callbacks_type = {type(cb) for cb in self.callbacks}
-        saved_callbacks_type = set(callback_states.keys())
-        difference = saved_callbacks_type.difference(current_callbacks_type)
+        is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
+        current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
+        difference = callback_states.keys() - current_callbacks_keys
         if difference:
             rank_zero_warn(
-                "Be aware that when using ``resume_from_checkpoint``, "
-                "callbacks used to create the checkpoint need to be provided. "
-                f"Please, add the following callbacks: {list(difference)}. ",
+                "Be aware that when using `resume_from_checkpoint`,"
+                " callbacks used to create the checkpoint need to be provided."
+                f" Please add the following callbacks: {list(difference)}.",
                 UserWarning,
             )
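The substantive change here is the lookup key: callback states used to be stored and matched by `type(cb)`, which collapses two instances of the same callback class into one entry, while `state_key` folds the distinguishing constructor arguments into a string and `_legacy_state_key` keeps the old class-based lookup for checkpoints written before 1.5.0dev. A toy sketch of why the new key disambiguates; these are simplified stand-ins, not Lightning's actual classes:

from typing import Optional

class Callback:
    @property
    def state_key(self) -> str:
        return type(self).__name__

    @property
    def _legacy_state_key(self) -> type:
        return type(self)

class ModelCheckpoint(Callback):
    def __init__(self, monitor: Optional[str] = None) -> None:
        self.monitor = monitor

    @property
    def state_key(self) -> str:
        # Fold identifying init arguments into the key,
        # e.g. "ModelCheckpoint{'monitor': 'epoch'}"
        return f"{type(self).__name__}{{'monitor': {self.monitor!r}}}"

a = ModelCheckpoint(monitor="epoch")
b = ModelCheckpoint(monitor="global_step")
assert a._legacy_state_key == b._legacy_state_key  # old scheme: the two collide
assert a.state_key != b.state_key                  # new scheme: separate state entries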
36 changes: 36 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pathlib import Path
+from re import escape
 from unittest.mock import call, Mock

+import pytest
+
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call


 def test_callbacks_configured_in_model(tmpdir):
@@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
     trainer.fit(model)
     assert callback.state == 111
+
+
+def test_resume_incomplete_callbacks_list_warning(tmpdir):
+    model = BoringModel()
+    callback0 = ModelCheckpoint(monitor="epoch")
+    callback1 = ModelCheckpoint(monitor="global_step")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback0, callback1],
+    )
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback1],  # one callback is missing!
+        resume_from_checkpoint=ckpt_path,
+    )
+    with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")):
+        trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback1, callback0],  # all callbacks here, order switched
+        resume_from_checkpoint=ckpt_path,
+    )
+    with no_warning_call(UserWarning, match="Please add the following callbacks:"):
+        trainer.fit(model)
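`no_warning_call` comes from the test suite's `tests.helpers.utils` and asserts the inverse of `pytest.warns`: the block must emit no warning matching the given category and pattern. Its implementation is not part of this diff; a minimal sketch of such a context manager, assuming it only needs category and substring matching, could look like:

import re
import warnings
from contextlib import contextmanager

@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    """Fail the test if a matching warning is emitted inside the block."""
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")  # record everything, swallow nothing
        yield
    for warning in record:
        if not issubclass(warning.category, expected_warning):
            continue
        if match is None or re.search(match, str(warning.message)):
            raise AssertionError(f"Unexpected warning raised: {warning.message!r}")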