Skip to content

Commit bbbe7c6

Browse files
authored
Merge branch 'release/1.2-dev' into refactor/legacy-accel-plug
2 parents 94c3728 + 9d165f6 commit bbbe7c6

File tree

3 files changed

+92
-55
lines changed

3 files changed

+92
-55
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9191
- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
9292

9393

94+
- Changed `ModelCheckpoint` version suffixes to start at 1 ([#5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))
95+
96+
9497
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
9598

9699

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ class ModelCheckpoint(Callback):
8080
the quantity monitored will be saved.
8181
if ``save_top_k == 0``, no models are saved.
8282
if ``save_top_k == -1``, all models are saved.
83-
Please note that the monitors are checked every `period` epochs.
83+
Please note that the monitors are checked every ``period`` epochs.
8484
if ``save_top_k >= 2`` and the callback is called multiple
8585
times inside an epoch, the name of the saved file will be
86-
appended with a version count starting with `v0`.
86+
appended with a version count starting with ``v1``.
8787
mode: one of {auto, min, max}.
8888
If ``save_top_k != 0``, the decision
8989
to overwrite the current save file is made
@@ -105,6 +105,17 @@ class ModelCheckpoint(Callback):
105105
.. warning::
106106
This argument has been deprecated in v1.1 and will be removed in v1.3
107107
108+
Note:
109+
For extra customization, ModelCheckpoint includes the following attributes:
110+
111+
- ``CHECKPOINT_JOIN_CHAR = "-"``
112+
- ``CHECKPOINT_NAME_LAST = "last"``
113+
- ``FILE_EXTENSION = ".ckpt"``
114+
- ``STARTING_VERSION = 1``
115+
116+
For example, you can change the default last checkpoint name by doing
117+
``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
118+
108119
Example::
109120
110121
>>> from pytorch_lightning import Trainer
@@ -128,11 +139,13 @@ class ModelCheckpoint(Callback):
128139
model = ...
129140
trainer.fit(model)
130141
checkpoint_callback.best_model_path
142+
131143
"""
132144

133145
CHECKPOINT_JOIN_CHAR = "-"
134146
CHECKPOINT_NAME_LAST = "last"
135147
FILE_EXTENSION = ".ckpt"
148+
STARTING_VERSION = 1
136149

137150
def __init__(
138151
self,
@@ -485,28 +498,24 @@ def _validate_monitor_key(self, trainer):
485498

486499
def _get_metric_interpolated_filepath_name(
487500
self,
488-
ckpt_name_metrics: Dict[str, Any],
501+
monitor_candidates: Dict[str, Any],
489502
epoch: int,
490503
step: int,
491504
del_filepath: Optional[str] = None
492505
) -> str:
493-
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
494-
495-
version_cnt = 0
506+
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
507+
version = self.STARTING_VERSION
496508
while self._fs.exists(filepath) and filepath != del_filepath:
497-
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
498-
version_cnt += 1
499-
509+
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
510+
version += 1
500511
return filepath
501512

502513
def _monitor_candidates(self, trainer):
503-
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
504-
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
505-
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
506-
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
507-
return ckpt_name_metrics
514+
monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
515+
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
516+
return monitor_candidates
508517

509-
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
518+
def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
510519
should_save_last = self.monitor is None or self.save_last
511520
if not should_save_last:
512521
return
@@ -517,13 +526,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
517526
self.CHECKPOINT_NAME_LAST,
518527
trainer.current_epoch,
519528
trainer.global_step,
520-
ckpt_name_metrics,
521-
prefix=self.prefix
529+
monitor_candidates,
530+
prefix=self.prefix,
522531
)
523532
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
524533
else:
525534
last_filepath = self._get_metric_interpolated_filepath_name(
526-
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
535+
monitor_candidates, trainer.current_epoch, trainer.global_step
527536
)
528537

529538
accelerator_backend = trainer.accelerator_backend
@@ -534,10 +543,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
534543
else:
535544
self._save_model(last_filepath, trainer, pl_module)
536545
if (
537-
self.last_model_path
538-
and self.last_model_path != last_filepath
539-
and (self.save_top_k != -1 or self.save_last)
540-
and trainer.is_global_zero
546+
self.last_model_path
547+
and self.last_model_path != last_filepath
548+
and (self.save_top_k != -1 or self.save_last)
549+
and trainer.is_global_zero
541550
):
542551
self._del_model(self.last_model_path)
543552
self.last_model_path = last_filepath

tests/checkpointing/test_model_checkpoint.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -738,18 +738,20 @@ def test_val_check_interval_checkpoint_files(tmpdir):
738738
save_top_k=-1,
739739
monitor="val_acc",
740740
mode="max",
741-
verbose=True
742741
)
743742
trainer = Trainer(
744743
default_root_dir=tmpdir,
745744
val_check_interval=0.2,
746745
max_epochs=1,
747746
limit_train_batches=10,
748-
callbacks=[model_checkpoint]
747+
callbacks=[model_checkpoint],
748+
logger=False,
749+
weights_summary=None,
750+
progress_bar_refresh_rate=0,
749751
)
750752
trainer.fit(model)
751-
files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")])
752-
assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]]
753+
files = {p.basename for p in tmpdir.listdir()}
754+
assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]}
753755

754756

755757
def test_current_score(tmpdir):
@@ -844,43 +846,66 @@ def __init__(self, hparams):
844846
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
845847

846848

847-
@pytest.mark.parametrize('max_epochs', [3, 4])
848-
@pytest.mark.parametrize(
849-
'save_top_k, expected',
850-
[
851-
(1, ['curr_epoch.ckpt']),
852-
(2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
853-
]
854-
)
855-
def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
849+
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
856850
"""
857-
Test that version is added to filename if required and it already exists in dirpath.
851+
Check that previous checkpoints are renamed to have the correct
852+
version suffix when new trainer instances are used
858853
"""
859-
model_checkpoint = ModelCheckpoint(
860-
dirpath=tmpdir,
861-
filename='curr_epoch',
862-
save_top_k=save_top_k,
863-
monitor='epoch',
864-
mode='max',
865-
)
854+
epochs = 2
855+
for i in range(epochs):
856+
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
857+
trainer = Trainer(
858+
max_epochs=epochs,
859+
limit_train_batches=1,
860+
limit_val_batches=1,
861+
default_root_dir=tmpdir,
862+
callbacks=[mc],
863+
logger=False,
864+
weights_summary=None,
865+
progress_bar_refresh_rate=0,
866+
)
867+
trainer.fit(BoringModel())
868+
869+
# check best_k_models state
870+
expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"}
871+
assert {Path(f).name for f in mc.best_k_models.keys()} == expected
872+
873+
# check created ckpts
874+
assert set(f.basename for f in tmpdir.listdir()) == {
875+
"epoch=0.ckpt",
876+
"epoch=1.ckpt",
877+
"epoch=0-v1.ckpt",
878+
"epoch=1-v1.ckpt",
879+
}
880+
881+
882+
def test_ckpt_version_after_rerun_same_trainer(tmpdir):
883+
"""
884+
Check that previous checkpoints are renamed to have the correct
885+
version suffix when the same trainer instance is used
886+
"""
887+
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
888+
mc.STARTING_VERSION = 9
866889
trainer = Trainer(
890+
max_epochs=2,
891+
limit_train_batches=1,
892+
limit_val_batches=1,
867893
default_root_dir=tmpdir,
868-
callbacks=[model_checkpoint],
869-
max_epochs=max_epochs,
870-
limit_train_batches=2,
871-
limit_val_batches=2,
872-
logger=None,
894+
callbacks=[mc],
895+
logger=False,
873896
weights_summary=None,
874897
progress_bar_refresh_rate=0,
875898
)
899+
trainer.fit(BoringModel())
900+
trainer.max_epochs = 4
901+
trainer.fit(BoringModel())
876902

877-
model = BoringModel()
878-
trainer.fit(model)
879-
ckpt_files = os.listdir(tmpdir)
880-
assert set(ckpt_files) == set(expected)
881-
882-
epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
883-
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
903+
ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION)
904+
expected = {'test.ckpt', *[f"test-v{i}.ckpt" for i in ckpt_range]}
905+
# check best_k_models state
906+
assert {Path(f).name for f in mc.best_k_models.keys()} == expected
907+
# check created ckpts
908+
assert set(sorted(os.listdir(tmpdir))) == expected
884909

885910

886911
def test_model_checkpoint_mode_options():

0 commit comments

Comments
 (0)