
Commit 79565be

Authored by rohitgr7, carmocca, tchaton, and s-rog
Fix saved filename in ModelCheckpoint if it already exists (#4861)
* disable version if not required
* disable version if not required
* pep
* chlog
* improve test
* improve test
* parametrize test and update del_list
* Update pytorch_lightning/callbacks/model_checkpoint.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* try appending version to already saved ckpt_file
* Revert "try appending version to already saved ckpt_file"
  This reverts commit 710e05e.
* add more assertions
* use BoringModel

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
1 parent fde972f commit 79565be

File tree

3 files changed (+71, -21 lines)


CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -27,9 +27,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `LightningOptimizer` exposes optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095))


+- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
+
+
 - Do not warn when the `name` key is used in the `lr_scheduler` dict ([#5057](https://github.com/PyTorchLightning/pytorch-lightning/pull/5057))


+
 ## [1.1.0] - 2020-12-09

 ### Added

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 28 additions & 21 deletions

@@ -240,17 +240,14 @@ def save_checkpoint(self, trainer, pl_module):
         # what can be monitored
         monitor_candidates = self._monitor_candidates(trainer)

-        # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
-
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
+            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -444,6 +441,7 @@ def format_checkpoint_name(
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
+
         ckpt_name = f"{filename}{self.FILE_EXTENSION}"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

@@ -515,13 +513,20 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)

-    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+    def _get_metric_interpolated_filepath_name(
+        self,
+        ckpt_name_metrics: Dict[str, Any],
+        epoch: int,
+        step: int,
+        del_filepath: Optional[str] = None
+    ) -> str:
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+
         version_cnt = 0
-        while self._fs.exists(filepath):
+        while self._fs.exists(filepath) and filepath != del_filepath:
             filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
-            # this epoch called before
             version_cnt += 1
+
         return filepath

     def _monitor_candidates(self, trainer):
@@ -531,13 +536,11 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics

-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return

-        last_filepath = filepath
-
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
@@ -548,6 +551,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
                 prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
+        else:
+            last_filepath = self._get_metric_interpolated_filepath_name(
+                ckpt_name_metrics, trainer.current_epoch, trainer.global_step
+            )

         accelerator_backend = trainer.accelerator_backend

@@ -568,7 +575,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
         if self.monitor is None:
             self.best_model_path = self.last_model_path

-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
+    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         current = metrics.get(self.monitor)
         epoch = metrics.get("epoch")
         step = metrics.get("step")
@@ -577,7 +584,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
             current = torch.tensor(current, device=pl_module.device)

         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
+            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
         elif self.verbose:
             rank_zero_info(
                 f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
@@ -588,25 +595,26 @@ def _is_valid_monitor_key(self, metrics):

     def _update_best_and_save(
         self,
-        filepath: str,
         current: torch.Tensor,
         epoch: int,
         step: int,
         trainer,
         pl_module,
+        ckpt_name_metrics
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

-        del_list = []
+        del_filepath = None
         if len(self.best_k_models) == k and k > 0:
-            delpath = self.kth_best_model_path
-            self.best_k_models.pop(self.kth_best_model_path)
-            del_list.append(delpath)
+            del_filepath = self.kth_best_model_path
+            self.best_k_models.pop(del_filepath)

         # do not save nan, replace with +/- inf
         if torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

+        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
+
         # save the current score
         self.current_score = current
         self.best_k_models[filepath] = current
@@ -630,9 +638,8 @@ def _update_best_and_save(
         )
         self._save_model(filepath, trainer, pl_module)

-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if del_filepath is not None and filepath != del_filepath:
+            self._del_model(del_filepath)

     def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
         """

tests/checkpointing/test_model_checkpoint.py

Lines changed: 39 additions & 0 deletions

@@ -938,3 +938,42 @@ def __init__(self, hparams):
     else:
         # make sure it's not AttributeDict
         assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
+
+
+@pytest.mark.parametrize('max_epochs', [3, 4])
+@pytest.mark.parametrize(
+    'save_top_k, expected',
+    [
+        (1, ['curr_epoch.ckpt']),
+        (2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
+    ]
+)
+def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
+    """
+    Test that version is added to filename if required and it already exists in dirpath.
+    """
+    model_checkpoint = ModelCheckpoint(
+        dirpath=tmpdir,
+        filename='curr_epoch',
+        save_top_k=save_top_k,
+        monitor='epoch',
+        mode='max',
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[model_checkpoint],
+        max_epochs=max_epochs,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        logger=None,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+    )
+
+    model = BoringModel()
+    trainer.fit(model)
+    ckpt_files = os.listdir(tmpdir)
+    assert set(ckpt_files) == set(expected)
+
+    epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
+    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
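
For context, here is a hedged user-level sketch of the scenario the test exercises: a static `filename` that would collide on every epoch, saved with `save_top_k=2`. The `dirpath` value and the `BoringModel` import path are assumptions for illustration and may differ from your setup.

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.base import BoringModel  # assumed import path for the test helper

# The same static filename is produced every epoch; after this fix, a
# "-v0" suffix is appended only when a different checkpoint genuinely
# already holds that name.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",  # placeholder directory
    filename="curr_epoch",
    save_top_k=2,
    monitor="epoch",
    mode="max",
)
trainer = Trainer(
    max_epochs=3,
    callbacks=[checkpoint_cb],
    limit_train_batches=2,
    limit_val_batches=2,
)
trainer.fit(BoringModel())

# Expected contents after the fix (per the parametrized test above):
# ['curr_epoch-v0.ckpt', 'curr_epoch.ckpt']
print(sorted(os.listdir("checkpoints/")))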
