Skip to content

Commit 0abd6e9

Browse files
awaelchli and carmocca authored
[3 / 3] improvements to saving and loading callback state (#7161)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 366fb39 commit 0abd6e9

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,16 @@ def state_key(self) -> str:
265265
save_on_train_epoch_end=self._save_on_train_epoch_end,
266266
)
267267

268-
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
269-
"""
270-
When pretrain routine starts we build the ckpt dir on the fly
271-
"""
272-
self.__resolve_ckpt_dir(trainer)
268+
def on_init_end(self, trainer: "pl.Trainer") -> None:
273269
if self._save_on_train_epoch_end is None:
274270
# if the user runs validation multiple times per training epoch, we try to save checkpoint after
275271
# validation instead of on train epoch end
276272
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
277273

274+
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
275+
"""When pretrain routine starts we build the ckpt dir on the fly."""
276+
self.__resolve_ckpt_dir(trainer)
277+
278278
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
279279
self._last_time_checked = time.monotonic()
280280

pytorch_lightning/trainer/callback_hook.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from abc import ABC
1615
from copy import deepcopy
1716
from typing import Any, Dict, List, Optional, Type, Union
1817

1918
import torch
19+
from packaging.version import Version
2020

2121
import pytorch_lightning as pl
2222
from pytorch_lightning.callbacks import Callback
@@ -255,14 +255,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
255255
if callback_states is None:
256256
return
257257

258-
current_callbacks_type = {type(cb) for cb in self.callbacks}
259-
saved_callbacks_type = set(callback_states.keys())
260-
difference = saved_callbacks_type.difference(current_callbacks_type)
258+
is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
259+
current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
260+
difference = callback_states.keys() - current_callbacks_keys
261261
if difference:
262262
rank_zero_warn(
263-
"Be aware that when using ``resume_from_checkpoint``, "
264-
"callbacks used to create the checkpoint need to be provided. "
265-
f"Please, add the following callbacks: {list(difference)}. ",
263+
"Be aware that when using `resume_from_checkpoint`,"
264+
" callbacks used to create the checkpoint need to be provided."
265+
f" Please add the following callbacks: {list(difference)}.",
266266
UserWarning,
267267
)
268268

tests/callbacks/test_callbacks.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from pathlib import Path
15+
from re import escape
1516
from unittest.mock import call, Mock
1617

18+
import pytest
19+
1720
from pytorch_lightning import Callback, Trainer
21+
from pytorch_lightning.callbacks import ModelCheckpoint
1822
from tests.helpers import BoringModel
23+
from tests.helpers.utils import no_warning_call
1924

2025

2126
def test_callbacks_configured_in_model(tmpdir):
@@ -132,3 +137,34 @@ def test_resume_callback_state_saved_by_type(tmpdir):
132137
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
133138
trainer.fit(model)
134139
assert callback.state == 111
140+
141+
142+
def test_resume_incomplete_callbacks_list_warning(tmpdir):
143+
model = BoringModel()
144+
callback0 = ModelCheckpoint(monitor="epoch")
145+
callback1 = ModelCheckpoint(monitor="global_step")
146+
trainer = Trainer(
147+
default_root_dir=tmpdir,
148+
max_steps=1,
149+
callbacks=[callback0, callback1],
150+
)
151+
trainer.fit(model)
152+
ckpt_path = trainer.checkpoint_callback.best_model_path
153+
154+
trainer = Trainer(
155+
default_root_dir=tmpdir,
156+
max_steps=1,
157+
callbacks=[callback1], # one callback is missing!
158+
resume_from_checkpoint=ckpt_path,
159+
)
160+
with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")):
161+
trainer.fit(model)
162+
163+
trainer = Trainer(
164+
default_root_dir=tmpdir,
165+
max_steps=1,
166+
callbacks=[callback1, callback0], # all callbacks here, order switched
167+
resume_from_checkpoint=ckpt_path,
168+
)
169+
with no_warning_call(UserWarning, match="Please add the following callbacks:"):
170+
trainer.fit(model)

0 commit comments

Comments (0)