
Commit 8c27fa7

awaelchli, carmocca, and pre-commit-ci[bot] authored

[1 / 3] improvements to saving and loading callback state (#6886)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0c0b24c commit 8c27fa7

File tree

9 files changed (+68, -26 lines)

- CHANGELOG.md
- pytorch_lightning/callbacks/base.py
- pytorch_lightning/trainer/callback_hook.py
- pytorch_lightning/trainer/connectors/checkpoint_connector.py
- tests/callbacks/test_callbacks.py
- tests/callbacks/test_early_stopping.py
- tests/checkpointing/test_model_checkpoint.py
- tests/trainer/connectors/test_callback_connector.py
- tests/trainer/logging_/test_logger_connector.py

CHANGELOG.md

Lines changed: 2 additions & 5 deletions

@@ -9,10 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
--
--
--
+- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
 
 
 -
@@ -32,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 
--
+- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
 
 
 -
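Both entries describe the same underlying change: the checkpoint's "callbacks" dictionary is now keyed by a string (the callback's `state_id`) rather than by the callback's class object. A minimal sketch of the before/after layout, with illustrative values (the full checkpoint schema contains more keys):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Before this change: keyed by the class object itself. Unpickling such a
# checkpoint requires the callback class to be importable, which breaks for
# user-defined callbacks that were moved, renamed, or deleted.
old_style = {"callbacks": {ModelCheckpoint: {"best_model_score": 0.25}}}

# After this change: keyed by a plain string, so the "callbacks" entry can
# be read without importing any callback class.
new_style = {"callbacks": {"ModelCheckpoint": {"best_model_score": 0.25}}}
```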

pytorch_lightning/callbacks/base.py

Lines changed: 16 additions & 1 deletion

@@ -17,7 +17,7 @@
 """
 
 import abc
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from torch.optim import Optimizer
@@ -33,6 +33,21 @@ class Callback(abc.ABC):
     Subclass this class and override any of the relevant hooks
     """
 
+    @property
+    def state_id(self) -> str:
+        """
+        Identifier for the state of the callback. Used to store and retrieve a callback's state from the
+        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
+        provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
+        multiple instances of that callback.
+        """
+        return self.__class__.__qualname__
+
+    @property
+    def _legacy_state_id(self) -> Type["Callback"]:
+        """State identifier for checkpoints saved prior to version 1.5.0."""
+        return type(self)
+
     def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called before configure sharded model"""

pytorch_lightning/trainer/callback_hook.py

Lines changed: 6 additions & 7 deletions

@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import torch
 
@@ -247,7 +247,7 @@ def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
         parameters = list(signature(fn).parameters)
         return len(parameters) == 1 and parameters[0] != "args"
 
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
@@ -261,16 +261,15 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[type(callback)] = state
+                callback_states[callback.state_id] = state
         return callback_states
 
-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
-        callback_states = checkpoint.get("callbacks")
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
 
         if callback_states is None:
             return
@@ -287,7 +286,7 @@ def on_load_checkpoint(self, checkpoint):
             )
 
         for callback in self.callbacks:
-            state = callback_states.get(type(callback))
+            state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
             if state:
                 state = deepcopy(state)
                 if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):
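The two-level `get` in the loading loop is what keeps old checkpoints working: the new string key is tried first, then the pre-1.5.0 key, which was the callback's type. A minimal standalone sketch of that lookup order, assuming a hypothetical `restore_callback_states` helper outside the Trainer:

```python
from copy import deepcopy


def restore_callback_states(callbacks, checkpoint):
    """Hypothetical helper mirroring the lookup order shown above."""
    callback_states = checkpoint.get("callbacks")
    if callback_states is None:
        return
    for callback in callbacks:
        # Prefer the new string key; fall back to the legacy key (the
        # callback's type) for checkpoints written before version 1.5.0.
        state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
        if state:
            # Deep-copy so a callback mutating its state does not corrupt
            # the loaded checkpoint dictionary.
            callback.on_load_checkpoint(None, None, deepcopy(state))
```

Passing `None` for the trainer and module arguments is a simplification for the sketch; the real hook receives the live `Trainer` and `LightningModule`.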

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 1 deletion

@@ -308,7 +308,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         structured dictionary: {
             'epoch': training epoch
             'global_step': training global step
-            'pytorch-lightning_version': PyTorch Lightning's version
+            'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
             'callbacks': "callback specific state"[] # if not weights_only
             'optimizer_states': "PT optim's state_dict"[] # if not weights_only
             'lr_schedulers': "PT sched's state_dict"[] # if not weights_only
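Because callback states are now stored under plain strings, a saved checkpoint can be inspected without importing any callback classes. A quick sketch ("example.ckpt" is a placeholder path):

```python
import torch

# Load a checkpoint produced by Trainer; the path is illustrative.
checkpoint = torch.load("example.ckpt", map_location="cpu")

print(checkpoint["pytorch-lightning_version"])
# With this change, the keys are strings such as "ModelCheckpoint"
# or "EarlyStopping" rather than class objects.
print(list(checkpoint["callbacks"].keys()))
```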

tests/callbacks/test_callbacks.py

Lines changed: 32 additions & 1 deletion

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
 from unittest.mock import call, Mock
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from tests.helpers import BoringModel
 
 
@@ -101,3 +102,33 @@ def configure_callbacks(self):
     trainer_fn(model)
     callbacks_after = trainer.callbacks.copy()
     assert callbacks_after == callbacks_after_fit
+
+
+class OldStatefulCallback(Callback):
+    def __init__(self, state):
+        self.state = state
+
+    @property
+    def state_id(self):
+        return type(self)
+
+    def on_save_checkpoint(self, *args):
+        return {"state": self.state}
+
+    def on_load_checkpoint(self, trainer, pl_module, callback_state):
+        self.state = callback_state["state"]
+
+
+def test_resume_callback_state_saved_by_type(tmpdir):
+    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
+    model = BoringModel()
+    callback = OldStatefulCallback(state=111)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
+    trainer.fit(model)
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()
+
+    callback = OldStatefulCallback(state=222)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
+    trainer.fit(model)
+    assert callback.state == 111

tests/callbacks/test_early_stopping.py

Lines changed: 2 additions & 1 deletion

@@ -76,7 +76,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     checkpoint = torch.load(checkpoint_filepath)
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
-    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
+    assert 4 == len(early_stop_callback.saved_states)
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
 
     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")

tests/checkpointing/test_model_checkpoint.py

Lines changed: 4 additions & 6 deletions

@@ -148,7 +148,7 @@ def on_validation_epoch_end(self):
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)
 
-        mc_specific_data = chk["callbacks"][type(checkpoint)]
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -259,7 +259,7 @@ def _make_assertions(epoch, ix, version=""):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk["global_step"] == expected_global_step
 
-        mc_specific_data = chk["callbacks"][type(checkpoint)]
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -857,9 +857,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
 
     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
-
-    ch_type = type(model_checkpoint)
-    assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
+    assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]
 
     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1097,7 +1095,7 @@ def training_step(self, *args):
     trainer.fit(TestModel())
    assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]

tests/trainer/connectors/test_callback_connector.py

Lines changed: 3 additions & 3 deletions

@@ -76,11 +76,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
     trainer.fit(model)
 
     ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
-    state0 = ckpt["callbacks"][type(callback0)]
-    state1 = ckpt["callbacks"][type(callback1)]
+    state0 = ckpt["callbacks"]["StatefulCallback0"]
+    state1 = ckpt["callbacks"]["StatefulCallback1"]
     assert "content0" in state0 and state0["content0"] == 0
     assert "content1" in state1 and state1["content1"] == 1
-    assert type(checkpoint_callback) in ckpt["callbacks"]
+    assert "ModelCheckpoint" in ckpt["callbacks"]
 
 
 def test_attach_model_callbacks():

tests/trainer/logging_/test_logger_connector.py

Lines changed: 2 additions & 1 deletion

@@ -26,10 +26,11 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
+from tests.models.test_hooks import get_members
 
 
 def test_fx_validator(tmpdir):
-    funcs_name = sorted(f for f in dir(Callback) if not f.startswith("_"))
+    funcs_name = sorted(get_members(Callback))
 
     callbacks_func = [
         "on_before_backward",
