@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
         ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
         ... )
 
-    By default, filename is ``None`` and will be set to ``'{epoch}'``.
+    By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
 
 
     Example::
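
A quick illustration of the new default naming, as a minimal sketch assuming the patched signatures introduced below (the path and values are illustrative):

    from pytorch_lightning.callbacks import ModelCheckpoint

    ckpt = ModelCheckpoint(dirpath='my/path/')  # filename=None -> '{epoch}-{step}'
    # A save at epoch 1 / global step 863 now produces
    #   epoch=1-step=863.ckpt     (previously: epoch=1.ckpt)
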
@@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module):
         monitor_candidates = self._monitor_candidates(trainer)
 
         # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
+            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
 
         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
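
Note: the `epoch` and `global_step` locals used in this hunk are not defined here; presumably they are read off the trainer earlier in `save_checkpoint`, roughly:

    # Assumed context from earlier in save_checkpoint (not part of this hunk):
    epoch = trainer.current_epoch
    global_step = trainer.global_step
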
@@ -360,16 +360,17 @@ def _format_checkpoint_name(
         cls,
         filename: Optional[str],
         epoch: int,
+        step: int,
         metrics: Dict[str, Any],
         prefix: str = "",
     ) -> str:
         if not filename:
             # filename is not set, use default name
-            filename = "{epoch}"
+            filename = "{epoch}-{step}"
         # check and parse user passed keys in the string
         groups = re.findall(r"(\{.*?)[:\}]", filename)
         if len(groups) >= 0:
-            metrics["epoch"] = epoch
+            metrics.update({"epoch": epoch, 'step': step})
             for group in groups:
                 name = group[1:]
                 filename = filename.replace(group, name + "={" + name)
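
For reference, a worked example of the template expansion this helper performs; it uses only the standard library, so it can be run standalone:

    import re

    filename = "{epoch}-{step}"
    groups = re.findall(r"(\{.*?)[:\}]", filename)
    print(groups)                                # ['{epoch', '{step']
    for group in groups:
        name = group[1:]
        # '{epoch' becomes 'epoch={epoch', yielding 'epoch={epoch}-step={step}'
        filename = filename.replace(group, name + "={" + name)
    print(filename.format(epoch=3, step=42))     # epoch=3-step=42
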
@@ -379,32 +380,32 @@ def _format_checkpoint_name(
         return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])
 
     def format_checkpoint_name(
-        self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+        self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.
 
         Example::
 
             >>> tmpdir = os.path.dirname(__file__)
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
             'epoch=0.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
             'epoch=005.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+            >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
             'epoch=2-val_loss=0.12.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
             'missing=0.ckpt'
-            >>> ckpt = ModelCheckpoint(filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
-            'epoch=0.ckpt'
+            >>> ckpt = ModelCheckpoint(filename='{step}')
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
+            'step=0.ckpt'
 
         """
         filename = self._format_checkpoint_name(
-            self.filename, epoch, metrics, prefix=self.prefix
+            self.filename, epoch, step, metrics, prefix=self.prefix
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)
 
-    def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
-        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
-            filepath = self.format_checkpoint_name(
-                epoch, ckpt_name_metrics, ver=version_cnt
-            )
+            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
         return filepath
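
The collision handling is unchanged in spirit: if the interpolated name already exists on disk, a version suffix is appended until the path is unique. Assuming `CHECKPOINT_JOIN_CHAR` is `'-'`, repeated saves that resolve to the same name would come out as:

    # epoch=2-step=10.ckpt       first write
    # epoch=2-step=10-v0.ckpt    name taken, ver=0
    # epoch=2-step=10-v1.ckpt    taken again, ver=1, and so on
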
@@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
         ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
         ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
+        ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics
 
-    def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
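
Because this `update` runs after the logger-connector merges, the trainer's values win if a user happened to log a metric literally named "step" or "epoch". A minimal illustration of the merge order:

    candidates = {"val_loss": 0.31, "step": 7}       # e.g. user logged a 'step' metric
    candidates.update({"step": 863, "epoch": 1})     # trainer values take precedence
    print(candidates)   # {'val_loss': 0.31, 'step': 863, 'epoch': 1}
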
@@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+                self.CHECKPOINT_NAME_LAST,
+                trainer.current_epoch,
+                trainer.global_step,
+                ckpt_name_metrics,
+                prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
@@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         if self.monitor is None:
             self.best_model_path = self.last_model_path
 
-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
+    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
         current = metrics.get(self.monitor)
+        epoch = metrics.get("epoch")
+        step = metrics.get("step")
 
         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
+            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
+                f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
             )
 
     def _is_valid_monitor_key(self, metrics):
539545 def _is_valid_monitor_key (self , metrics ):
@@ -544,11 +550,11 @@ def _update_best_and_save(
         filepath: str,
         current: torch.Tensor,
         epoch: int,
+        step: int,
         trainer,
         pl_module,
     ):
-
-        k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
+        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
         del_list = []
         if len(self.best_k_models) == k and k > 0:
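
The `k` change matters once checkpoints can be written more than once per epoch. With `save_top_k == -1` (keep everything) and three saves inside epoch 0:

    # old: k = epoch + 1 = 1             -> len(best_k_models) hits k after the first
    #                                       save, so later saves evict earlier ones
    # new: k = len(best_k_models) + 1    -> k always exceeds the current count,
    #                                       so nothing is ever deleted
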
@@ -575,9 +581,8 @@ def _update_best_and_save(
 
         if self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} reached"
-                f" {current:0.5f} (best {self.best_model_score:0.5f}),"
-                f" saving model to {filepath} as top {k}"
+                f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
+                f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
         self._save_model(filepath, trainer, pl_module)