@@ -240,17 +240,14 @@ def save_checkpoint(self, trainer, pl_module):
         # what can be monitored
         monitor_candidates = self._monitor_candidates(trainer)

-        # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
-
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
+            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
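The two save modes above are independent: Mode 1 fires only when save_top_k is truthy, Mode 2 whenever monitor is unset or save_last is requested. A minimal usage sketch, assuming the standard ModelCheckpoint constructor arguments of this Lightning version:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Mode 1 keeps the three best checkpoints ranked by val_loss;
    # Mode 2 additionally refreshes a "last" checkpoint on every save.
    ckpt_cb = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        save_last=True,
    )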
@@ -444,6 +441,7 @@ def format_checkpoint_name(
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
+
         ckpt_name = f"{filename}{self.FILE_EXTENSION}"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

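For reference, with Lightning's defaults of CHECKPOINT_JOIN_CHAR = "-" and FILE_EXTENSION = ".ckpt", the version suffix above composes like this (the filename and version values are illustrative):

    filename, ver = "epoch=2-val_loss=0.35", 1
    filename = "-".join((filename, f"v{ver}"))
    ckpt_name = f"{filename}.ckpt"  # -> "epoch=2-val_loss=0.35-v1.ckpt"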
@@ -515,13 +513,20 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)

-    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+    def _get_metric_interpolated_filepath_name(
+        self,
+        ckpt_name_metrics: Dict[str, Any],
+        epoch: int,
+        step: int,
+        del_filepath: Optional[str] = None
+    ) -> str:
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+
         version_cnt = 0
-        while self._fs.exists(filepath):
+        while self._fs.exists(filepath) and filepath != del_filepath:
             filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
-            # this epoch called before
             version_cnt += 1
+
         return filepath

     def _monitor_candidates(self, trainer):
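The new del_filepath parameter is the heart of this refactor: the versioning loop now treats the checkpoint that is about to be deleted as free, so its replacement can reuse the exact same filename instead of acquiring a spurious "-v0" suffix. A standalone sketch of the loop, assuming a plain os.path.exists check in place of the callback's self._fs filesystem abstraction:

    import os
    from typing import Optional

    def interpolated_filepath(base: str, del_filepath: Optional[str] = None) -> str:
        # Append -v{n} until the name is unused, but treat the path
        # scheduled for deletion as if it were already gone.
        filepath = base
        version_cnt = 0
        while os.path.exists(filepath) and filepath != del_filepath:
            root, ext = os.path.splitext(base)
            filepath = f"{root}-v{version_cnt}{ext}"
            version_cnt += 1
        return filepath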
@@ -531,13 +536,11 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics

-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return

-        last_filepath = filepath
-
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
@@ -548,6 +551,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
                 prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
+        else:
+            last_filepath = self._get_metric_interpolated_filepath_name(
+                ckpt_name_metrics, trainer.current_epoch, trainer.global_step
+            )

         accelerator_backend = trainer.accelerator_backend

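With the filepath argument gone, the non-save_last branch now computes its own interpolated name on demand rather than receiving one from save_checkpoint. A hypothetical stand-in for the two naming outcomes (the helper name and the fallback pattern below are illustrative, not Lightning API):

    import os

    def last_checkpoint_path(dirpath: str, save_last: bool, epoch: int, step: int) -> str:
        # save_last=True pins the fixed "last" name; otherwise (monitor is
        # None) the checkpoint gets a metric-interpolated name.
        if save_last:
            return os.path.join(dirpath, "last.ckpt")
        return os.path.join(dirpath, f"epoch={epoch}-step={step}.ckpt")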
@@ -568,7 +575,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
         if self.monitor is None:
             self.best_model_path = self.last_model_path

-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
+    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         current = metrics.get(self.monitor)
         epoch = metrics.get("epoch")
         step = metrics.get("step")
@@ -577,7 +584,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
             current = torch.tensor(current, device=pl_module.device)

         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
+            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
         elif self.verbose:
             rank_zero_info(
                 f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
@@ -588,25 +595,26 @@ def _is_valid_monitor_key(self, metrics):

     def _update_best_and_save(
         self,
-        filepath: str,
         current: torch.Tensor,
         epoch: int,
         step: int,
         trainer,
         pl_module,
+        ckpt_name_metrics
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

-        del_list = []
+        del_filepath = None
         if len(self.best_k_models) == k and k > 0:
-            delpath = self.kth_best_model_path
-            self.best_k_models.pop(self.kth_best_model_path)
-            del_list.append(delpath)
+            del_filepath = self.kth_best_model_path
+            self.best_k_models.pop(del_filepath)

         # do not save nan, replace with +/- inf
         if torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

+        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
+
         # save the current score
         self.current_score = current
         self.best_k_models[filepath] = current
@@ -630,9 +638,8 @@ def _update_best_and_save(
             )
         self._save_model(filepath, trainer, pl_module)

-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if del_filepath is not None and filepath != del_filepath:
+            self._del_model(del_filepath)

     def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
         """
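Putting the pieces together, the new bookkeeping evicts at most one entry per save and deletes the old file only when the new checkpoint did not reuse its path. An illustrative walk-through with a plain dict (scores and names are made up), reusing the interpolated_filepath sketch from earlier:

    best_k_models = {"epoch=1-val_loss=0.40.ckpt": 0.40, "epoch=2-val_loss=0.35.ckpt": 0.35}
    k, new_score = 2, 0.30

    del_filepath = None
    if len(best_k_models) == k:
        # evict the worst entry (mode="min": the highest loss)
        del_filepath = max(best_k_models, key=best_k_models.get)
        best_k_models.pop(del_filepath)

    filepath = interpolated_filepath("epoch=3-val_loss=0.30.ckpt", del_filepath)
    best_k_models[filepath] = new_score

    if del_filepath is not None and filepath != del_filepath:
        print(f"delete {del_filepath}")  # stand-in for self._del_model(del_filepath)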