
Commit b9443a0

awaelchli authored, with carmocca, pre-commit-ci[bot], and Borda

[2 / 3] improvements to saving and loading callback state (#7187)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 376734a commit b9443a0

File tree

11 files changed: +180 -27 lines changed


CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807))

-- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
+- Added `state_key` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

 - Progress tracking

@@ -60,6 +60,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 * Refactored CheckpointConnector to offload validating logic to the checkpoint IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045))

+- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
+
 - Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))

docs/source/extensions/callbacks.rst

Lines changed: 56 additions & 3 deletions
@@ -113,16 +113,69 @@ Lightning has a few built-in callbacks.

 ----------

+.. _Persisting Callback State:
+
 Persisting State
 ----------------

 Some callbacks require internal state in order to function properly. You can optionally
 choose to persist your callback's state as part of model checkpoint files using the callback hooks
 :meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
-However, you must follow two constraints:
+Note that the returned state must be able to be pickled.
+
+When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
+to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then
+the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_key` property in order for Lightning
+to be able to distinguish the different states when loading the callback state. This concept is best illustrated by
+the following example.
+
+.. testcode::
+
+    class Counter(Callback):
+        def __init__(self, what="epochs", verbose=True):
+            self.what = what
+            self.verbose = verbose
+            self.state = {"epochs": 0, "batches": 0}
+
+        @property
+        def state_key(self):
+            # note: we do not include `verbose` here on purpose
+            return self._generate_state_key(what=self.what)
+
+        def on_train_epoch_end(self, *args, **kwargs):
+            if self.what == "epochs":
+                self.state["epochs"] += 1
+
+        def on_train_batch_end(self, *args, **kwargs):
+            if self.what == "batches":
+                self.state["batches"] += 1
+
+        def on_load_checkpoint(self, trainer, pl_module, callback_state):
+            self.state.update(callback_state)
+
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            return self.state.copy()
+
+
+    # two callbacks of the same type are being used
+    trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])
+
+A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:
+
+.. code-block::
+
+    {
+        "state_dict": ...,
+        "callbacks": {
+            "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
+            "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
+            ...
+        }
+    }

-1. Your returned state must be able to be pickled.
-2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class.
+The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_key` is essential here. If it were missing,
+Lightning would not be able to disambiguate the state for these two callbacks, and :attr:`~pytorch_lightning.callbacks.Callback.state_key`
+by default only defines the class name as the key, e.g., here ``Counter``.


 Best Practices
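
The docs example above covers saving; for loading, a minimal resume sketch (assuming the `Counter` callback defined in the diff above, with a hypothetical checkpoint path) shows why the distinct keys matter:

    from pytorch_lightning import Trainer

    # both instances are matched back to their saved state by state_key, not by class name
    epoch_counter = Counter(what="epochs")
    batch_counter = Counter(what="batches")
    trainer = Trainer(
        callbacks=[epoch_counter, batch_counter],
        resume_from_checkpoint="path/to/last.ckpt",  # hypothetical path
    )
    # After loading, epoch_counter.state is restored from "Counter{'what': 'epochs'}"
    # and batch_counter.state from "Counter{'what': 'batches'}".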

pytorch_lightning/callbacks/base.py

Lines changed: 15 additions & 5 deletions
@@ -34,20 +34,30 @@ class Callback(abc.ABC):
     """

     @property
-    def state_id(self) -> str:
+    def state_key(self) -> str:
         """
         Identifier for the state of the callback. Used to store and retrieve a callback's state from the
-        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
-        provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
+        checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to
+        provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of
         multiple instances of that callback.
         """
         return self.__class__.__qualname__

     @property
-    def _legacy_state_id(self) -> Type["Callback"]:
-        """State identifier for checkpoints saved prior to version 1.5.0."""
+    def _legacy_state_key(self) -> Type["Callback"]:
+        """State key for checkpoints saved prior to version 1.5.0."""
         return type(self)

+    def _generate_state_key(self, **kwargs: Any) -> str:
+        """
+        Formats a set of key-value pairs into a state key string with the callback class name prefixed.
+        Useful for defining a :attr:`state_key`.
+
+        Args:
+            **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
+        """
+        return f"{self.__class__.__qualname__}{repr(kwargs)}"
+
     def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called before configure sharded model"""
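
Since `_generate_state_key` simply appends `repr(kwargs)` to the class qualname, the resulting keys are deterministic strings. A minimal sketch of the behavior (the `MyCounter` subclass is illustrative, not part of the commit):

    from pytorch_lightning.callbacks import Callback

    class MyCounter(Callback):  # hypothetical subclass for illustration
        def __init__(self, what: str):
            self.what = what

        @property
        def state_key(self) -> str:
            # the persisted state varies with `what`, so include it in the key
            return self._generate_state_key(what=self.what)

    # the repr of the keyword arguments is appended to the class name
    assert MyCounter("batches").state_key == "MyCounter{'what': 'batches'}"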

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 11 additions & 0 deletions
@@ -75,6 +75,13 @@ class EarlyStopping(Callback):
     >>> from pytorch_lightning.callbacks import EarlyStopping
     >>> early_stopping = EarlyStopping('val_loss')
     >>> trainer = Trainer(callbacks=[early_stopping])
+
+    .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
+        following arguments:
+
+        *monitor, mode*
+
+        Read more: :ref:`Persisting Callback State`
     """
     mode_dict = {"min": torch.lt, "max": torch.gt}

@@ -120,6 +127,10 @@ def __init__(
         )
         self.monitor = monitor or "early_stop_on"

+    @property
+    def state_key(self) -> str:
+        return self._generate_state_key(monitor=self.monitor, mode=self.mode)
+
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._check_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch, we try to check after
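
With `state_key` varying over `monitor` and `mode`, two early stopping callbacks can now be checkpointed side by side. A short sketch of what this enables:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    trainer = Trainer(
        callbacks=[
            EarlyStopping(monitor="val_loss", mode="min"),
            EarlyStopping(monitor="val_acc", mode="max"),
        ]
    )
    # A checkpoint saved by this Trainer keeps the two states apart under
    # "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" and
    # "EarlyStopping{'monitor': 'val_acc', 'mode': 'max'}".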

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 17 additions & 0 deletions
@@ -194,6 +194,12 @@ class ModelCheckpoint(Callback):
         trainer.fit(model)
         checkpoint_callback.best_model_path

+    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
+        following arguments:
+
+        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*
+
+        Read more: :ref:`Persisting Callback State`
     """

     CHECKPOINT_JOIN_CHAR = "-"
@@ -248,6 +254,17 @@ def __init__(
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
         self.__validate_init_configuration()

+    @property
+    def state_key(self) -> str:
+        return self._generate_state_key(
+            monitor=self.monitor,
+            mode=self.mode,
+            every_n_train_steps=self._every_n_train_steps,
+            every_n_epochs=self._every_n_epochs,
+            train_time_interval=self._train_time_interval,
+            save_on_train_epoch_end=self._save_on_train_epoch_end,
+        )
+
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """
         When pretrain routine starts we build the ckpt dir on the fly
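
The key can also be inspected directly on an instance, which helps when reading `checkpoint["callbacks"]` entries by hand. A quick sketch; the printed key matches the default constructor arguments asserted in the tests below:

    from pytorch_lightning.callbacks import ModelCheckpoint

    mc = ModelCheckpoint(monitor="val_loss")
    print(mc.state_key)
    # ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0,
    #  'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': None}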

pytorch_lightning/trainer/callback_hook.py

Lines changed: 2 additions & 2 deletions
@@ -242,7 +242,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         for callback in self.callbacks:
             state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback.state_id] = state
+                callback_states[callback.state_key] = state
         return callback_states

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
@@ -267,7 +267,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         )

         for callback in self.callbacks:
-            state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
+            state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
             if state:
                 state = deepcopy(state)
                 callback.on_load_checkpoint(self, self.lightning_module, state)
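
The chained `.get(...)` above is what keeps pre-1.5.0 checkpoints loadable: the new string key is tried first, then the legacy key, which is the callback class itself. A standalone sketch of that lookup, with a hypothetical helper name:

    from copy import deepcopy
    from typing import Any, Dict, Optional

    def lookup_callback_state(callback_states: Dict[Any, dict], callback) -> Optional[dict]:
        """Hypothetical helper mirroring the fallback logic above."""
        # new checkpoints index state by the string state_key; legacy ones by the class
        state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
        # deepcopy so the callback cannot mutate the checkpoint contents in place
        return deepcopy(state) if state else None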

tests/callbacks/test_callbacks.py

Lines changed: 2 additions & 2 deletions
@@ -109,7 +109,7 @@ def __init__(self, state):
         self.state = state

     @property
-    def state_id(self):
+    def state_key(self):
         return type(self)

     def on_save_checkpoint(self, *args):
@@ -120,7 +120,7 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state):


 def test_resume_callback_state_saved_by_type(tmpdir):
-    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
+    """Test that a legacy checkpoint that didn't use a state key before can still be loaded."""
     model = BoringModel()
     callback = OldStatefulCallback(state=111)
     trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])

tests/callbacks/test_early_stopping.py

Lines changed: 7 additions & 1 deletion
@@ -33,6 +33,11 @@
 _logger = logging.getLogger(__name__)


+def test_early_stopping_state_key():
+    early_stopping = EarlyStopping(monitor="val_loss")
+    assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"
+
+
 class EarlyStoppingTestRestore(EarlyStopping):
     # this class has to be defined outside the test function, otherwise we get pickle error
     def __init__(self, expected_state, *args, **kwargs):
@@ -77,7 +82,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
     assert 4 == len(early_stop_callback.saved_states)
-    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
+    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
+    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")

tests/callbacks/test_lambda_function.py

Lines changed: 2 additions & 2 deletions
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 from functools import partial

 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import Callback, LambdaCallback
 from tests.helpers.boring_model import BoringModel
+from tests.models.test_hooks import get_members


 def test_lambda_call(tmpdir):
@@ -32,7 +32,7 @@ def on_train_epoch_start(self):
     def call(hook, *_, **__):
         checker.add(hook)

-    hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
+    hooks = get_members(Callback)
     hooks_args = {h: partial(call, h) for h in hooks}
     hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]

tests/checkpointing/test_model_checkpoint.py

Lines changed: 30 additions & 4 deletions
@@ -43,6 +43,15 @@
 from tests.helpers.runif import RunIf


+def test_model_checkpoint_state_key():
+    early_stopping = ModelCheckpoint(monitor="val_loss")
+    expected_id = (
+        "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+        " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
+    )
+    assert early_stopping.state_key == expected_id
+
+
 class LogInTwoMethods(BoringModel):
     def training_step(self, batch, batch_idx):
         out = super().training_step(batch, batch_idx)
@@ -148,7 +157,10 @@ def on_validation_epoch_end(self):
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+        mc_specific_data = chk["callbacks"][
+            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+        ]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -259,7 +271,10 @@ def _make_assertions(epoch, ix, version=""):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk["global_step"] == expected_global_step

-        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+        mc_specific_data = chk["callbacks"][
+            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
+        ]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -857,7 +872,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
-    assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]
+
+    ckpt_id = (
+        "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+    )
+    assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1095,7 +1115,13 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
+    ckpts = [
+        ckpt["callbacks"][
+            "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+        ]
+        for ckpt in ckpts
+    ]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]
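
Rather than hardcoding these long key strings, a test could also derive the key from the instance itself; a hedged alternative sketch (not what the commit does):

    from pytorch_lightning.callbacks import ModelCheckpoint

    mc = ModelCheckpoint(monitor="val_loss")
    # after a training run with `mc` attached (assumed), the saved entry can be
    # looked up without spelling out the literal string:
    # mc_specific_data = checkpoint["callbacks"][mc.state_key]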
