@@ -504,8 +504,24 @@ def _get_metric_interpolated_filepath_name(
504504 ) -> str :
505505 filepath = self .format_checkpoint_name (epoch , step , ckpt_name_metrics )
506506
507- version_cnt = 0
508- while self ._fs .exists (filepath ) and filepath != del_filepath :
507+ version_cnt = 1
508+ old_ckpt_ver_0 = self .format_checkpoint_name (epoch , step , ckpt_name_metrics , ver = 0 )
509+ while (
510+ self ._fs .exists (filepath )
511+ or (self ._fs .exists (old_ckpt_ver_0 ) and version_cnt == 1 )
512+ ):
513+ if del_filepath == filepath :
514+ return filepath
515+
516+ if del_filepath == old_ckpt_ver_0 :
517+ return old_ckpt_ver_0
518+
519+ if self ._fs .exists (filepath ):
520+ self ._fs .rename (filepath , old_ckpt_ver_0 )
521+ old_ckpt_score = self .best_k_models [filepath ]
522+ self .best_k_models .pop (filepath )
523+ self .best_k_models [old_ckpt_ver_0 ] = old_ckpt_score
524+
509525 filepath = self .format_checkpoint_name (epoch , step , ckpt_name_metrics , ver = version_cnt )
510526 version_cnt += 1
511527
@@ -523,10 +539,6 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
523539 if not should_save_last :
524540 return
525541
526- last_filepath = self ._get_metric_interpolated_filepath_name (
527- ckpt_name_metrics , trainer .current_epoch , trainer .global_step
528- )
529-
530542 # when user ALSO asked for the 'last.ckpt' change the name
531543 if self .save_last :
532544 last_filepath = self ._format_checkpoint_name (
@@ -537,6 +549,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
537549 prefix = self .prefix
538550 )
539551 last_filepath = os .path .join (self .dirpath , f"{ last_filepath } .ckpt" )
552+ else :
553+ last_filepath = self ._get_metric_interpolated_filepath_name (
554+ ckpt_name_metrics , trainer .current_epoch , trainer .global_step
555+ )
540556
541557 self ._save_model (last_filepath , trainer , pl_module )
542558 if (
0 commit comments