@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
         ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
         ... )
 
-    By default, filename is ``None`` and will be set to ``'{epoch}'``.
+    By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
 
 
     Example::
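With the old default, two checkpoints saved within the same epoch (for example with a small ``val_check_interval``) would collide on the name ``epoch=N.ckpt``; including the global step makes every save unique. A minimal standalone sketch, not the library code, of how the new default template expands:

    # How '{epoch}-{step}' becomes a filename like 'epoch=2-step=1050.ckpt';
    # the real logic lives in _format_checkpoint_name further down this diff.
    template = "{epoch}-{step}"
    for key in ("epoch", "step"):
        # each '{name}' group is rewritten to 'name={name}' before formatting
        template = template.replace("{" + key + "}", key + "={" + key + "}")
    print(template.format(epoch=2, step=1050) + ".ckpt")  # epoch=2-step=1050.ckpt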
@@ -223,16 +223,16 @@ def save_checkpoint(self, trainer, pl_module):
         monitor_candidates = self._monitor_candidates(trainer)
 
         # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
+            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
 
         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
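The filename is now interpolated once, from the candidate metrics plus the current ``epoch`` and ``global_step``, and the same ``filepath`` is handed to both save modes. A rough sketch of that control flow, with a hypothetical ``record`` list standing in for the real saving helpers:

    # Rough control-flow sketch of save_checkpoint; purely illustrative.
    def save_checkpoint_flow(epoch, global_step, save_top_k, save_last):
        filepath = f"epoch={epoch}-step={global_step}.ckpt"  # interpolated once
        record = []
        if save_top_k:                         # Mode 1: top-k (or all) checkpoints
            record.append(("top_k", filepath))
        if save_last:                          # Mode 2: the rolling 'last' checkpoint
            record.append(("last", "last.ckpt"))
        return record

    print(save_checkpoint_flow(2, 1050, save_top_k=1, save_last=True))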
@@ -361,16 +361,17 @@ def _format_checkpoint_name(
         cls,
         filename: Optional[str],
         epoch: int,
+        step: int,
         metrics: Dict[str, Any],
         prefix: str = "",
     ) -> str:
         if not filename:
             # filename is not set, use default name
-            filename = "{epoch}"
+            filename = "{epoch}-{step}"
         # check and parse user passed keys in the string
         groups = re.findall(r"(\{.*?)[:\}]", filename)
         if len(groups) >= 0:
-            metrics["epoch"] = epoch
+            metrics.update({"epoch": epoch, "step": step})
             for group in groups:
                 name = group[1:]
                 filename = filename.replace(group, name + "={" + name)
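The regex step above is easier to see on a concrete template: ``re.findall(r"(\{.*?)[:\}]", ...)`` captures each ``{name`` prefix up to a ``:`` or ``}``, and the loop rewrites it to ``name={name`` so that the final formatting step (a plain ``str.format`` here) emits ``name=value`` pairs. The same steps, run standalone:

    import re

    # Standalone run of the parsing shown above on a concrete template.
    filename = "{epoch:03d}-{step}"
    groups = re.findall(r"(\{.*?)[:\}]", filename)  # ['{epoch', '{step']
    metrics = {"val_loss": 0.25}
    metrics.update({"epoch": 5, "step": 1050})
    for group in groups:
        name = group[1:]
        filename = filename.replace(group, name + "={" + name)
    # -> 'epoch={epoch:03d}-step={step}', then formatting fills the values in
    print(filename.format(**metrics))  # epoch=005-step=1050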
@@ -380,32 +381,32 @@ def _format_checkpoint_name(
         return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])
 
     def format_checkpoint_name(
-        self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+        self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.
 
         Example::
 
             >>> tmpdir = os.path.dirname(__file__)
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
             'epoch=0.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
             'epoch=005.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+            >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
             'epoch=2-val_loss=0.12.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
             'missing=0.ckpt'
-            >>> ckpt = ModelCheckpoint(filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
-            'epoch=0.ckpt'
+            >>> ckpt = ModelCheckpoint(filename='{step}')
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
+            'step=0.ckpt'
 
         """
         filename = self._format_checkpoint_name(
-            self.filename, epoch, metrics, prefix=self.prefix
+            self.filename, epoch, step, metrics, prefix=self.prefix
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -480,13 +481,11 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)
 
-    def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
-        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
-            filepath = self.format_checkpoint_name(
-                epoch, ckpt_name_metrics, ver=version_cnt
-            )
+            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
         return filepath
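If the interpolated name already exists on disk, the ``while`` loop above retries with ``ver=0, 1, ...``, which ``format_checkpoint_name`` joins on as a ``v{n}`` suffix. A filesystem-free sketch of the same collision handling, assuming the default join character ``-``:

    # Collision-handling sketch; 'existing' stands in for self._fs.exists().
    existing = {"epoch=2-step=1050.ckpt", "epoch=2-step=1050-v0.ckpt"}

    def interpolated_name(epoch, step, ver=None):
        name = f"epoch={epoch}-step={step}"
        if ver is not None:
            name = "-".join((name, f"v{ver}"))  # CHECKPOINT_JOIN_CHAR default
        return name + ".ckpt"

    filepath = interpolated_name(2, 1050)
    version_cnt = 0
    while filepath in existing:
        filepath = interpolated_name(2, 1050, ver=version_cnt)
        version_cnt += 1
    print(filepath)  # epoch=2-step=1050-v1.ckpt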
@@ -495,9 +494,10 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
         ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
         ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
+        ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics
 
-    def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
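Injecting ``step`` and ``epoch`` into the candidates dict is what lets the downstream methods recover them via ``metrics.get(...)`` now that the explicit ``epoch`` parameter is gone. Because this ``update`` runs last, the trainer's values also win over any user-logged metric that happens to share those names:

    # Merge-order sketch: the trainer-supplied values are applied last.
    logged = {"val_loss": 0.25, "step": 7}         # hypothetical user-logged metrics
    candidates = dict(logged)
    candidates.update({"step": 1050, "epoch": 2})  # trainer.global_step / current_epoch
    print(candidates)  # {'val_loss': 0.25, 'step': 1050, 'epoch': 2}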
@@ -507,7 +507,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+                self.CHECKPOINT_NAME_LAST,
+                trainer.current_epoch,
+                trainer.global_step,
+                ckpt_name_metrics,
+                prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
@@ -524,17 +528,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         if self.monitor is None:
             self.best_model_path = self.last_model_path
 
-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
+    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
         current = metrics.get(self.monitor)
+        epoch = metrics.get("epoch")
+        step = metrics.get("step")
 
         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
+            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
+                f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
             )
 
     def _is_valid_monitor_key(self, metrics):
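With the parameter dropped, ``epoch`` and ``step`` are read back out of the same candidates dict built in ``_monitor_candidates`` above. A compact sketch of the lookup-and-coerce sequence, with ``val_loss`` as a hypothetical monitor key:

    import torch

    metrics = {"val_loss": 0.25, "epoch": 2, "step": 1050}
    current = metrics.get("val_loss")  # the monitored value
    epoch = metrics.get("epoch")
    step = metrics.get("step")
    if not isinstance(current, torch.Tensor) and current is not None:
        # the callback also places the tensor on pl_module.device
        current = torch.tensor(current)
    print(current, epoch, step)  # tensor(0.2500) 2 1050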
@@ -545,11 +551,11 @@ def _update_best_and_save(
         filepath: str,
         current: torch.Tensor,
         epoch: int,
+        step: int,
         trainer,
         pl_module,
     ):
-
-        k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
+        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
         del_list = []
         if len(self.best_k_models) == k and k > 0:
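The ``k`` change is a behavioral fix, not a rename: with ``save_top_k == -1`` (keep everything), ``epoch + 1`` only leaves room for one checkpoint per epoch, so a second save within the same epoch could hit the eviction branch just below; sizing off ``len(self.best_k_models) + 1`` always leaves room for the incoming checkpoint. In numbers:

    # Contrast of the two rules during a second save in epoch 0 (save_top_k=-1).
    best_k_models = {"epoch=0-step=50.ckpt": 0.30}  # one checkpoint already kept
    epoch = 0
    old_k = epoch + 1               # 1 -> dict already "full", would evict
    new_k = len(best_k_models) + 1  # 2 -> room for the new checkpoint
    print(old_k, new_k)  # 1 2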
@@ -576,9 +582,8 @@ def _update_best_and_save(
 
         if self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} reached"
-                f" {current:0.5f} (best {self.best_model_score:0.5f}),"
-                f" saving model to {filepath} as top {k}"
+                f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
+                f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
         self._save_model(filepath, trainer, pl_module)