
Commit cea170e

ananthsub and Borda authored
[feat] Support iteration-based checkpointing in model checkpoint callback (#6146)
* Update model_checkpoint.py
* add tests
* Update model_checkpoint.py
* Update test_model_checkpoint.py
* fix tests
* every_n_batches
* Update test_model_checkpoint.py
* defaults
* rm tests
* Update model_checkpoint.py
* Update test_model_checkpoint.py
* Prune deprecated metrics for 1.3 (#6161)
* prune deprecated metrics for 1.3
* isort / yapf
* Update model_checkpoint.py
* add tests
* defaults
* Update CHANGELOG.md
* pre-commit
* Update model_checkpoint.py
* update defaults
* Update test_remove_1-5.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* fix tests
* Update test_model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update test_model_checkpoint.py
* ckpt-callback
* Update test_model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* validation-end
* Update model_checkpoint.py
* Update test_model_checkpoint.py
* Update test_model_checkpoint.py
* Update test_model_checkpoint.py
* Update test_model_checkpoint.py
* clarify-names
  - Make names explicit as to which hooks they apply to
  - Use step instead of batch for consistency with global step
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* Update model_checkpoint.py
* mutual-exclusive
  Make every_n_train_steps and every_n_val_epochs mutually exclusive
* fix-default-0
* Update CHANGELOG.md
* formatting
* make-private
  make attributes private to the class
* rebase

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 62d4304 commit cea170e

File tree: 4 files changed (+239, -32 lines changed)

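In user terms, this commit lets `ModelCheckpoint` fire on a training-step cadence instead of only at validation end. A minimal usage sketch (the directory, filename pattern, and interval are illustrative, and `MyLightningModule` is a placeholder, not part of the diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save every 100 optimizer steps; the two triggers are mutually exclusive,
# so validation-end saving is explicitly disabled with 0 here.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{step}",
    every_n_train_steps=100,
    every_n_val_epochs=0,
    save_top_k=-1,
)

trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(MyLightningModule())  # placeholder LightningModule
```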

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
+- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
@@ -55,6 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
+
 
 - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
```

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 105 additions & 24 deletions
```diff
@@ -93,8 +93,25 @@ class ModelCheckpoint(Callback):
         save_weights_only: if ``True``, then only the model's weights will be
             saved (``model.save_weights(filepath)``), else the full model
             is saved (``model.save(filepath)``).
+        every_n_train_steps: Number of training steps between checkpoints.
+            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
+            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``every_n_val_epochs``.
+        every_n_val_epochs: Number of validation epochs between checkpoints.
+            If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
+            To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``every_n_train_steps``.
+            Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
+            ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+            will only save checkpoints at epochs 0 < E <= N
+            where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
         period: Interval (number of epochs) between checkpoints.
 
+            .. warning::
+                This argument has been deprecated in v1.3 and will be removed in v1.5.
+
+                Use ``every_n_val_epochs`` instead.
+
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
```
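The divisibility rule in the new docstring is easiest to see on concrete numbers. A minimal sketch (the interval and epoch values are illustrative, not taken from the diff):

```python
from math import gcd

every_n_val_epochs, check_val_every_n_epoch, max_epochs = 2, 3, 12

# Validation runs at 1-indexed epochs E that check_val_every_n_epoch divides;
# the callback then saves only if every_n_val_epochs also divides E.
# Saves therefore land on multiples of lcm(2, 3) = 6: epochs 6 and 12.
lcm = every_n_val_epochs * check_val_every_n_epoch // gcd(every_n_val_epochs, check_val_every_n_epoch)
save_epochs = [e for e in range(1, max_epochs + 1) if e % lcm == 0]
assert save_epochs == [6, 12]
```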
```diff
@@ -165,16 +182,17 @@ def __init__(
         save_top_k: Optional[int] = None,
         save_weights_only: bool = False,
         mode: str = "min",
-        period: int = 1,
-        auto_insert_metric_name: bool = True
+        auto_insert_metric_name: bool = True,
+        every_n_train_steps: Optional[int] = None,
+        every_n_val_epochs: Optional[int] = None,
+        period: Optional[int] = None,
     ):
         super().__init__()
         self.monitor = monitor
         self.verbose = verbose
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
-        self.period = period
         self.auto_insert_metric_name = auto_insert_metric_name
         self._last_global_step_saved = -1
         self.current_score = None
@@ -188,6 +206,7 @@ def __init__(
 
         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
+        self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
         self.__validate_init_configuration()
 
     def on_pretrain_routine_start(self, trainer, pl_module):
```
```diff
@@ -197,10 +216,26 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.__resolve_ckpt_dir(trainer)
         self.save_function = trainer.save_checkpoint
 
-    def on_validation_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
+        """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+        step = trainer.global_step
+        skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
+        if skip_batch:
+            return
+        self.save_checkpoint(trainer)
+
+    def on_validation_end(self, trainer, *args, **kwargs) -> None:
         """
         checkpoints can be saved at the end of the val loop
         """
+        skip = (
+            self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
+            or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
+        )
+        if skip:
+            return
         self.save_checkpoint(trainer)
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
```
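The `(step + 1) % n` arithmetic above saves on zero-indexed global steps n-1, 2n-1, and so on. A standalone sketch of the resulting schedule (values illustrative; this mirrors the hook logic rather than calling it):

```python
every_n_train_steps = 16
total_steps = 64

# trainer.global_step is zero-indexed when on_train_batch_end fires,
# so (step + 1) % 16 == 0 selects steps 15, 31, 47, 63.
saved_steps = [
    step for step in range(total_steps)
    if every_n_train_steps >= 1 and (step + 1) % every_n_train_steps == 0
]
assert saved_steps == [15, 31, 47, 63]
```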
```diff
@@ -228,20 +263,8 @@ def save_checkpoint(self, trainer, unused: Optional = None):
                 " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
             )
 
-        epoch = trainer.current_epoch
         global_step = trainer.global_step
 
-        from pytorch_lightning.trainer.states import TrainerState
-        if (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
-            or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self.period < 1  # no models are saved
-            or (epoch + 1) % self.period  # skip epoch
-            or self._last_global_step_saved == global_step  # already saved at the last step
-        ):
-            return
-
         self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)
 
```
```diff
@@ -260,9 +283,32 @@ def save_checkpoint(self, trainer, unused: Optional = None):
         # Mode 3: save last checkpoints
         self._save_last_checkpoint(trainer, monitor_candidates)
 
+    def _should_skip_saving_checkpoint(self, trainer) -> bool:
+        from pytorch_lightning.trainer.states import TrainerState
+        return (
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved == trainer.global_step  # already saved at the last step
+        )
+
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
+        if self._every_n_train_steps < 0:
+            raise MisconfigurationException(
+                f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
+            )
+        if self._every_n_val_epochs < 0:
+            raise MisconfigurationException(
+                f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
+            )
+        if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
+            raise MisconfigurationException(
+                f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
+                f' and every_n_val_epochs={self._every_n_val_epochs}.'
+                ' Both cannot be enabled at the same time.'
+            )
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
```
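A quick standalone illustration of these validation paths (a sketch; it assumes a build that includes this commit, and the `match` strings are substrings of the exception messages above):

```python
import pytest
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Negative intervals are rejected outright.
with pytest.raises(MisconfigurationException, match="Must be >= 0"):
    ModelCheckpoint(every_n_train_steps=-1)

# Enabling both triggers at once is a misconfiguration.
with pytest.raises(MisconfigurationException, match="Both cannot be enabled"):
    ModelCheckpoint(every_n_train_steps=10, every_n_val_epochs=2)

# One trigger enabled with the other disabled (0) is accepted.
ModelCheckpoint(every_n_train_steps=10, every_n_val_epochs=0)
```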
```diff
@@ -309,6 +355,46 @@ def __init_monitor_mode(self, monitor, mode):
 
         self.kth_value, self.mode = mode_dict[mode]
 
+    def __init_triggers(
+        self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
+    ) -> None:
+
+        # Default to running once after each validation epoch if neither
+        # every_n_train_steps nor every_n_val_epochs is set
+        if every_n_train_steps is None and every_n_val_epochs is None:
+            self._every_n_val_epochs = 1
+            self._every_n_train_steps = 0
+            log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
+        else:
+            self._every_n_val_epochs = every_n_val_epochs or 0
+            self._every_n_train_steps = every_n_train_steps or 0
+
+        # period takes precedence over every_n_val_epochs for backwards compatibility
+        if period is not None:
+            rank_zero_warn(
+                'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+                ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+            )
+            self._every_n_val_epochs = period
+
+        self._period = self._every_n_val_epochs
+
+    @property
+    def period(self) -> Optional[int]:
+        rank_zero_warn(
+            'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+        )
+        return self._period
+
+    @period.setter
+    def period(self, value: Optional[int]) -> None:
+        rank_zero_warn(
+            'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+        )
+        self._period = value
+
     @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
```
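How `__init_triggers` resolves the three arguments can be summarized with a few constructor calls. A hedged sketch inspecting the private attributes shown above (assumes this commit is installed; the interval values are illustrative):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Neither trigger set: default to saving once per validation epoch.
ckpt = ModelCheckpoint()
assert (ckpt._every_n_train_steps, ckpt._every_n_val_epochs) == (0, 1)

# Step-based trigger only: validation-end saving stays disabled.
ckpt = ModelCheckpoint(every_n_train_steps=50)
assert (ckpt._every_n_train_steps, ckpt._every_n_val_epochs) == (50, 0)

# Deprecated `period` takes precedence and emits a DeprecationWarning.
ckpt = ModelCheckpoint(every_n_val_epochs=4, period=2)
assert ckpt._every_n_val_epochs == 2
```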
```diff
@@ -422,11 +508,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
 
         """
         filename = self._format_checkpoint_name(
-            self.filename,
-            epoch,
-            step,
-            metrics,
-            auto_insert_metric_name=self.auto_insert_metric_name)
+            self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name
+        )
 
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
```
```diff
@@ -581,9 +664,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A
             self._save_model(trainer, filepath)
 
         if (
-            self.save_top_k is None
-            and self.best_model_path
-            and self.best_model_path != filepath
+            self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
             and trainer.is_global_zero
         ):
             self._del_model(self.best_model_path)
```

tests/checkpointing/test_model_checkpoint.py

Lines changed: 124 additions & 8 deletions
```diff
@@ -434,11 +434,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
 
     # auto_insert_metric_name=False
     ckpt_name = ModelCheckpoint._format_checkpoint_name(
-        'epoch={epoch:03d}-val_acc={val/acc}',
-        3,
-        2,
-        {'val/acc': 0.03},
-        auto_insert_metric_name=False)
+        'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False
+    )
     assert ckpt_name == 'epoch=003-val_acc=0.03'
 
```
```diff
@@ -524,6 +521,45 @@ def test_none_monitor_save_last(tmpdir):
         ModelCheckpoint(dirpath=tmpdir, save_last=False)
 
 
+def test_invalid_every_n_val_epochs(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
+
+
+def test_invalid_every_n_train_steps(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. """
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2)
+
+
+def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
+    """
+    Test that a MisconfigurationException is raised if both
+    every_n_val_epochs and every_n_train_steps are enabled together.
+    """
+    with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
+
+
+def test_none_every_n_train_steps_val_epochs(tmpdir):
+    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
+    assert checkpoint_callback.period == 1
+    assert checkpoint_callback._every_n_val_epochs == 1
+    assert checkpoint_callback._every_n_train_steps == 0
+
+
 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     """ Test that it is possible to save all checkpoints when monitor=None. """
     seed_everything()
```
```diff
@@ -567,9 +603,8 @@ def test_model_checkpoint_period(tmpdir, period: int):
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
         max_epochs=epochs,
-        limit_train_batches=0.1,
-        limit_val_batches=0.1,
-        val_check_interval=1.0,
+        limit_train_batches=1,
+        limit_val_batches=1,
         logger=False,
     )
     trainer.fit(model)
```
```diff
@@ -579,6 +614,87 @@ def test_model_checkpoint_period(tmpdir, period: int):
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
+@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
+def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
+    model = LogInTwoMethods()
+    epochs = 5
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint_callback],
+        max_epochs=epochs,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        logger=False,
+    )
+    trainer.fit(model)
+
+    # check that the correct ckpts were created
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
+@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
+def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
+    """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
+    model = LogInTwoMethods()
+    epochs = 5
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir,
+        filename='{epoch}',
+        save_top_k=-1,
+        every_n_val_epochs=(2 * every_n_val_epochs),
+        period=every_n_val_epochs
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint_callback],
+        max_epochs=epochs,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        logger=False,
+    )
+    trainer.fit(model)
+
+    # check that the correct ckpts were created
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
+def test_ckpt_every_n_train_steps(tmpdir):
+    """ Tests that the checkpoints are saved every n training steps. """
+
+    model = LogInTwoMethods()
+    every_n_train_steps = 16
+    max_epochs = 2
+    epoch_length = 64
+    checkpoint_callback = ModelCheckpoint(
+        filename="{step}",
+        every_n_val_epochs=0,
+        every_n_train_steps=every_n_train_steps,
+        dirpath=tmpdir,
+        save_top_k=-1,
+        save_last=False,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=max_epochs,
+        progress_bar_refresh_rate=0,
+        callbacks=[checkpoint_callback],
+        logger=False,
+    )
+
+    trainer.fit(model)
+    expected = [
+        f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps)
+    ]
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
 def test_model_checkpoint_topk_zero(tmpdir):
     """ Test that no checkpoints are saved when save_top_k=0. """
     model = LogInTwoMethods()
```

tests/deprecated_api/test_remove_1-5.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -104,3 +104,10 @@ def configure_optimizers(self):
 
     with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
         trainer.fit(model)
+
+
+def test_v1_5_0_model_checkpoint_period(tmpdir):
+    with no_warning_call(DeprecationWarning):
+        ModelCheckpoint(dirpath=tmpdir)
+    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
+        ModelCheckpoint(dirpath=tmpdir, period=1)
```
