Skip to content

Commit 710e05e

Browse files
committed
try appending version to already saved ckpt_file
1 parent 41ae295 commit 710e05e

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,24 @@ def _get_metric_interpolated_filepath_name(
504 504
) -> str:
505 505
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
506 506

507-
version_cnt = 0
508-
while self._fs.exists(filepath) and filepath != del_filepath:
507+
version_cnt = 1
508+
old_ckpt_ver_0 = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=0)
509+
while (
510+
self._fs.exists(filepath)
511+
or (self._fs.exists(old_ckpt_ver_0) and version_cnt == 1)
512+
):
513+
if del_filepath == filepath:
514+
return filepath
515+
516+
if del_filepath == old_ckpt_ver_0:
517+
return old_ckpt_ver_0
518+
519+
if self._fs.exists(filepath):
520+
self._fs.rename(filepath, old_ckpt_ver_0)
521+
old_ckpt_score = self.best_k_models[filepath]
522+
self.best_k_models.pop(filepath)
523+
self.best_k_models[old_ckpt_ver_0] = old_ckpt_score
524+
509 525
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
510 526
version_cnt += 1
511 527

@@ -523,10 +539,6 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
523 539
if not should_save_last:
524 540
return
525 541

526-
last_filepath = self._get_metric_interpolated_filepath_name(
527-
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
528-
)
529-
530 542
# when user ALSO asked for the 'last.ckpt' change the name
531 543
if self.save_last:
532 544
last_filepath = self._format_checkpoint_name(
@@ -537,6 +549,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
537 549
prefix=self.prefix
538 550
)
539 551
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
552+
else:
553+
last_filepath = self._get_metric_interpolated_filepath_name(
554+
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
555+
)
540 556

541 557
self._save_model(last_filepath, trainer, pl_module)
542 558
if (

tests/checkpointing/test_model_checkpoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ def __init__(self, hparams):
993 993
'save_top_k, expected',
994 994
[
995 995
(1, ['curr_epoch.ckpt']),
996-
(2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
996+
(2, ['curr_epoch-v0.ckpt', 'curr_epoch-v1.ckpt']),
997 997
]
998 998
)
999 999
def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
@@ -1026,3 +1026,6 @@ def validation_epoch_end(self, outputs):
1026 1026
trainer.fit(model)
1027 1027
ckpt_files = os.listdir(tmpdir)
1028 1028
assert set(ckpt_files) == set(expected)
1029+
1030+
expected_epoch_in_files = sorted([pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files])
1031+
assert expected_epoch_in_files == sorted(list(range(max_epochs))[-save_top_k:])

0 commit comments

Comments
 (0)