@@ -80,10 +80,10 @@ class ModelCheckpoint(Callback):
             the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
             if ``save_top_k == -1``, all models are saved.
-            Please note that the monitors are checked every `period` epochs.
+            Please note that the monitors are checked every ``period`` epochs.
             if ``save_top_k >= 2`` and the callback is called multiple
             times inside an epoch, the name of the saved file will be
-            appended with a version count starting with `v0`.
+            appended with a version count starting with ``v1``.
         mode: one of {auto, min, max}.
             If ``save_top_k != 0``, the decision
             to overwrite the current save file is made
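The versioning behavior documented above can be pictured with concrete filenames. Illustrative only, assuming a ``filename`` template that renders to ``epoch=3`` and the default ``CHECKPOINT_JOIN_CHAR = "-"``:

    epoch=3.ckpt       # first checkpoint written at this point
    epoch=3-v1.ckpt    # name collision: the version suffix now starts at v1 (previously v0)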
@@ -105,6 +105,17 @@ class ModelCheckpoint(Callback):
         .. warning::
            This argument has been deprecated in v1.1 and will be removed in v1.3
 
+    Note:
+        For extra customization, ModelCheckpoint includes the following attributes:
+
+        - ``CHECKPOINT_JOIN_CHAR = "-"``
+        - ``CHECKPOINT_NAME_LAST = "last"``
+        - ``FILE_EXTENSION = ".ckpt"``
+        - ``STARTING_VERSION = 1``
+
+        For example, you can change the default last checkpoint name by doing
+        ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
+
     Example::
 
         >>> from pytorch_lightning import Trainer
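A minimal usage sketch of the customization attributes documented in the new Note (assumptions: the v1.1-era ``dirpath``/``save_last`` constructor arguments; this snippet is not part of the commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Override the class-level defaults on a single instance before training starts.
    checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_last=True)
    checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"  # renders as e.g. "epoch=4-last"
    checkpoint_callback.FILE_EXTENSION = ".pt"  # write .pt files instead of .ckpt

    trainer = Trainer(callbacks=[checkpoint_callback])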
@@ -128,11 +139,13 @@ class ModelCheckpoint(Callback):
         model = ...
         trainer.fit(model)
         checkpoint_callback.best_model_path
+
     """
 
     CHECKPOINT_JOIN_CHAR = "-"
     CHECKPOINT_NAME_LAST = "last"
     FILE_EXTENSION = ".ckpt"
+    STARTING_VERSION = 1
 
     def __init__(
         self,
@@ -485,28 +498,24 @@ def _validate_monitor_key(self, trainer):
 
     def _get_metric_interpolated_filepath_name(
         self,
-        ckpt_name_metrics: Dict[str, Any],
+        monitor_candidates: Dict[str, Any],
         epoch: int,
         step: int,
         del_filepath: Optional[str] = None
     ) -> str:
-        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
-
-        version_cnt = 0
+        filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
+        version = self.STARTING_VERSION
         while self._fs.exists(filepath) and filepath != del_filepath:
-            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
-            version_cnt += 1
-
+            filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
+            version += 1
         return filepath
 
     def _monitor_candidates(self, trainer):
-        ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
-        ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
-        ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
-        ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
-        return ckpt_name_metrics
+        monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
+        monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
+        return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+    def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
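The effect of ``STARTING_VERSION`` in ``_get_metric_interpolated_filepath_name``, restated as a standalone sketch (a simplified stand-in for ``format_checkpoint_name`` and the ``self._fs`` existence check, not the library code):

    def interpolated_name(base: str, existing: set, starting_version: int = 1) -> str:
        # Mimics the loop above: probe base.ckpt, then base-v1.ckpt, base-v2.ckpt, ...
        filepath = f"{base}.ckpt"
        version = starting_version
        while filepath in existing:
            filepath = f"{base}-v{version}.ckpt"
            version += 1
        return filepath

    print(interpolated_name("epoch=3", {"epoch=3.ckpt"}))  # -> epoch=3-v1.ckpt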
@@ -517,13 +526,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
                 self.CHECKPOINT_NAME_LAST,
                 trainer.current_epoch,
                 trainer.global_step,
-                ckpt_name_metrics,
-                prefix=self.prefix
+                monitor_candidates,
+                prefix=self.prefix,
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
         else:
             last_filepath = self._get_metric_interpolated_filepath_name(
-                ckpt_name_metrics, trainer.current_epoch, trainer.global_step
+                monitor_candidates, trainer.current_epoch, trainer.global_step
             )
 
         accelerator_backend = trainer.accelerator_backend
@@ -534,10 +543,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
         else:
             self._save_model(last_filepath, trainer, pl_module)
         if (
-                self.last_model_path
-                and self.last_model_path != last_filepath
-                and (self.save_top_k != -1 or self.save_last)
-                and trainer.is_global_zero
+            self.last_model_path
+            and self.last_model_path != last_filepath
+            and (self.save_top_k != -1 or self.save_last)
+            and trainer.is_global_zero
         ):
             self._del_model(self.last_model_path)
         self.last_model_path = last_filepath