Skip to content

Commit 48c6759

Browse files
committed
Revert "remove save_on_train_epoch_end from state_key"
This reverts commit 8ad08ca.
1 parent 27789fb commit 48c6759

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class ModelCheckpoint(Callback):
197197
.. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
198198
following arguments:
199199
200-
*monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*
200+
*monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*
201201
202202
Read more: :ref:`Persisting Callback State`
203203
"""
@@ -262,6 +262,7 @@ def state_key(self) -> str:
262262
every_n_train_steps=self._every_n_train_steps,
263263
every_n_epochs=self._every_n_epochs,
264264
train_time_interval=self._train_time_interval,
265+
save_on_train_epoch_end=self._save_on_train_epoch_end,
265266
)
266267

267268
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

tests/checkpointing/test_model_checkpoint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_model_checkpoint_state_key():
4747
early_stopping = ModelCheckpoint(monitor="val_loss")
4848
expected_id = (
4949
"ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
50-
" 'train_time_interval': None}"
50+
" 'train_time_interval': None, 'save_on_train_epoch_end': None}"
5151
)
5252
assert early_stopping.state_key == expected_id
5353

@@ -159,7 +159,7 @@ def on_validation_epoch_end(self):
159159

160160
mc_specific_data = chk["callbacks"][
161161
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
162-
" 'train_time_interval': None}"
162+
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
163163
]
164164
assert mc_specific_data["dirpath"] == checkpoint.dirpath
165165
assert mc_specific_data["monitor"] == monitor
@@ -273,7 +273,7 @@ def _make_assertions(epoch, ix, version=""):
273273

274274
mc_specific_data = chk["callbacks"][
275275
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
276-
" 'train_time_interval': None}"
276+
" 'train_time_interval': None, 'save_on_train_epoch_end': False}"
277277
]
278278
assert mc_specific_data["dirpath"] == checkpoint.dirpath
279279
assert mc_specific_data["monitor"] == monitor
@@ -875,7 +875,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
875875

876876
ckpt_id = (
877877
"ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
878-
" 'train_time_interval': None}"
878+
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
879879
)
880880
assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]
881881

@@ -1118,7 +1118,7 @@ def training_step(self, *args):
11181118
ckpts = [
11191119
ckpt["callbacks"][
11201120
"ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
1121-
" 'train_time_interval': None}"
1121+
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
11221122
]
11231123
for ckpt in ckpts
11241124
]

tests/trainer/connectors/test_callback_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
111111
assert "content1" in state2 and state2["content1"] == "two"
112112
assert (
113113
"ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
114-
" 'train_time_interval': None}" in ckpt["callbacks"]
114+
" 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"]
115115
)
116116

117117

0 commit comments

Comments
 (0)