@@ -189,7 +189,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         checkpoints can be saved at the end of the val loop
         """
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer)

     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
@@ -204,12 +204,18 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]

-    def save_checkpoint(self, trainer, pl_module):
+    def save_checkpoint(self, trainer, unused: Optional = None):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
         to handle correct behaviour in distributed training, i.e., saving only on rank 0.
         """
+        if unused is not None:
+            rank_zero_warn(
+                "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter"
+                " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
+            )
+
         epoch = trainer.current_epoch
         global_step = trainer.global_step

@@ -218,7 +224,6 @@ def save_checkpoint(self, trainer, pl_module):
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
             or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
-            or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or self._last_global_step_saved == global_step  # already saved at the last step
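Note on the hunk above: old call sites that still pass `pl_module` keep working, but now trigger a `DeprecationWarning` through the unused second parameter. A minimal standalone sketch of that shim pattern (the `Checkpointer` class and message text below are illustrative, not the Lightning code):

```python
import warnings
from typing import Any, Optional


class Checkpointer:
    """Toy stand-in to illustrate the deprecation shim used above."""

    def save_checkpoint(self, trainer: Any, unused: Optional[Any] = None) -> None:
        if unused is not None:
            # Old callers passed `pl_module` as the second argument; accept it,
            # but warn so call sites can be updated before the parameter is removed.
            warnings.warn(
                "`save_checkpoint(trainer, pl_module)` is deprecated; pass only `trainer`.",
                DeprecationWarning,
            )
        print(f"saving checkpoint for {trainer!r}")


Checkpointer().save_checkpoint("trainer")            # new call style, no warning
Checkpointer().save_checkpoint("trainer", "module")  # old call style, warns
```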
@@ -236,28 +241,33 @@ def save_checkpoint(self, trainer, pl_module):

         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
-        # Mode 1: save all checkpoints OR only the top k
-        if self.save_top_k:
-            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
-
-        # Mode 2: save the last checkpoint
+        # Mode 1: save the top k checkpoints
+        self._save_top_k_checkpoint(trainer, monitor_candidates)
+        # Mode 2: save monitor=None checkpoints
+        self._save_none_monitor_checkpoint(trainer, monitor_candidates)
+        # Mode 3: save last checkpoints
         self._save_last_checkpoint(trainer, monitor_candidates)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
-            if self.save_top_k not in [None, -1, 0]:
+            if self.save_top_k not in (None, -1, 0):
                 raise MisconfigurationException(
                     f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                     ' configuration. No quantity for top_k to track.'
                 )
             if self.save_last:
                 rank_zero_warn(
-                    'ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration.'
+                    'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
                     ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
                 )
+            if self.save_top_k == -1 and self.save_last:
+                rank_zero_info(
+                    'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
+                    ' will duplicate the last checkpoint saved.'
+                )

     def __init_ckpt_dir(self, dirpath, filename, save_top_k):

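With the mode dispatch now unconditional, each `_save_*` method guards itself, and `__validate_init_configuration` distinguishes the `monitor=None` combinations up front. A hedged sketch of how the resulting configurations look from the user side (standard `ModelCheckpoint` constructor arguments; the exact exception and warning wording may differ between versions):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Mode 2 only: nothing monitored, keep just the most recent checkpoint.
latest_only = ModelCheckpoint(monitor=None, save_top_k=None)

# Mode 1 + Mode 3: rank by `val_loss`, keep the best 3, and also write `last.ckpt`.
best_and_last = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3, save_last=True)

# Rejected by __validate_init_configuration: top-k requested with nothing to rank by.
try:
    ModelCheckpoint(monitor=None, save_top_k=5)
except Exception as err:  # MisconfigurationException in Lightning
    print(err)
```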
@@ -293,7 +303,16 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")

-    def _save_model(self, filepath: str, trainer):
+    def _save_model(self, trainer, filepath: str):
+        if trainer.training_type_plugin.rpc_enabled:
+            # RPCPlugin manages saving all model states
+            # TODO: the rpc plugin should wrap trainer.save_checkpoint
+            # instead of us having to do it here manually
+            trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath)
+        else:
+            self._do_save(trainer, filepath)
+
+    def _do_save(self, trainer, filepath: str):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

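The refactor above splits `_save_model` into a dispatcher and a `_do_save` writer so the writer can be handed to the RPC plugin as a callable. A stripped-down sketch of that shape with hypothetical stand-in classes (`Plugin` and `Trainer` here are not the Lightning types):

```python
from typing import Callable


class Plugin:
    rpc_enabled = False

    def rpc_save_model(self, trainer, save_fn: Callable, filepath: str) -> None:
        # a real RPC plugin would gather remote state before calling back
        save_fn(trainer, filepath)


class Trainer:
    training_type_plugin = Plugin()


def do_save(trainer: Trainer, filepath: str) -> None:
    print(f"writing checkpoint to {filepath}")


def save_model(trainer: Trainer, filepath: str) -> None:
    plugin = trainer.training_type_plugin
    if plugin.rpc_enabled:
        plugin.rpc_save_model(trainer, do_save, filepath)  # plugin decides when to write
    else:
        do_save(trainer, filepath)


save_model(Trainer(), "epoch=0.ckpt")
```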
@@ -307,7 +326,7 @@ def _save_model(self, filepath: str, trainer):
         else:
             raise ValueError(".save_function() not set")

-    def check_monitor_top_k(self, current) -> bool:
+    def check_monitor_top_k(self, current: torch.Tensor) -> bool:
         if current is None:
             return False

@@ -462,17 +481,17 @@ def _validate_monitor_key(self, trainer):

     def _get_metric_interpolated_filepath_name(
         self,
-        ckpt_name_metrics: Dict[str, Any],
+        monitor_candidates: Dict[str, Any],
         epoch: int,
         step: int,
         trainer,
         del_filepath: Optional[str] = None,
     ) -> str:
-        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+        filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)

         version_cnt = self.STARTING_VERSION
         while self.file_exists(filepath, trainer) and filepath != del_filepath:
-            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version_cnt)
             version_cnt += 1

         return filepath
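Aside from the rename to `monitor_candidates`, the versioning loop is unchanged: the interpolated name gets a version suffix until it no longer collides with an existing file, except for the file that is about to be replaced. A simplified sketch of that loop against a plain set of names; the `-v{n}` suffix and the starting index are assumptions about the formatter, not copied from it:

```python
from typing import Optional, Set


def interpolated_filepath(base: str, existing: Set[str], del_filepath: Optional[str] = None) -> str:
    """Return `base.ckpt`, or a `-v<n>` variant, skipping names already on disk."""
    filepath = f"{base}.ckpt"
    version = 0  # assumed STARTING_VERSION
    while filepath in existing and filepath != del_filepath:
        filepath = f"{base}-v{version}.ckpt"
        version += 1
    return filepath


print(interpolated_filepath("epoch=3", {"epoch=3.ckpt"}))                   # epoch=3-v0.ckpt
print(interpolated_filepath("epoch=3", {"epoch=3.ckpt"}, "epoch=3.ckpt"))   # epoch=3.ckpt (about to be replaced)
```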
@@ -482,47 +501,32 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates

-    def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
-        should_save_last = self.monitor is None or self.save_last
-        if not should_save_last:
+    def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if not self.save_last:
             return

-        # when user ALSO asked for the 'last.ckpt' change the name
-        if self.save_last:
-            last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST,
-                trainer.current_epoch,
-                trainer.global_step,
-                ckpt_name_metrics,
-            )
-            last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
-        else:
-            last_filepath = self._get_metric_interpolated_filepath_name(
-                ckpt_name_metrics,
-                trainer.current_epoch,
-                trainer.global_step,
-                trainer,
-            )
+        filepath = self._format_checkpoint_name(
+            self.CHECKPOINT_NAME_LAST,
+            trainer.current_epoch,
+            trainer.global_step,
+            monitor_candidates,
+        )
+        filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")

-        if trainer.training_type_plugin.rpc_enabled:
-            # RPCPlugin manages saving all model states
-            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
-        else:
-            self._save_model(last_filepath, trainer)
-        if (
-            self.last_model_path and self.last_model_path != last_filepath
-            and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
-        ):
+        self._save_model(trainer, filepath)
+
+        if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
             self._del_model(self.last_model_path)
-        self.last_model_path = last_filepath

-        if self.monitor is None:
-            self.best_model_path = self.last_model_path
+        self.last_model_path = filepath
+
+    def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is None or self.save_top_k == 0:
+            return

-    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
-        current = metrics.get(self.monitor)
-        epoch = metrics.get("epoch")
-        step = metrics.get("step")
+        current = monitor_candidates.get(self.monitor)
+        epoch = monitor_candidates.get("epoch")
+        step = monitor_candidates.get("step")

         # when `val_loss` is being logged and no ModelCheckpoint is being provided
         # `val_loss` will be selected for monitor and need to be reduced to
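`_save_last_checkpoint` is now responsible only for `save_last=True`: it always writes to the fixed `CHECKPOINT_NAME_LAST` name and removes the previous "last" file, while the `monitor=None` branch moves into `_save_none_monitor_checkpoint` (added in the next hunk). A hedged sketch of the resulting directory layout for a typical configuration (the filenames in the comments are illustrative of the default template, not guaranteed output):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    save_last=True,
)
# After a few epochs the directory would contain something like:
#   checkpoints/epoch=3-step=400.ckpt   <- Mode 1: a top-k checkpoint
#   checkpoints/epoch=5-step=600.ckpt   <- Mode 1: a top-k checkpoint
#   checkpoints/last.ckpt               <- Mode 3: always the most recent save
```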
@@ -533,15 +537,37 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
             current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

         if self.check_monitor_top_k(current):
-            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
-        elif self.monitor is not None and self.verbose:
+            self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
+        elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")

+    def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is not None or self.save_top_k == 0:
+            return
+
+        filepath = self._get_metric_interpolated_filepath_name(
+            monitor_candidates,
+            trainer.current_epoch,
+            trainer.global_step,
+            trainer,
+        )
+        self._save_model(trainer, filepath)
+
+        if (
+            self.save_top_k is None
+            and self.best_model_path
+            and self.best_model_path != filepath
+            and trainer.is_global_zero
+        ):
+            self._del_model(self.best_model_path)
+
+        self.best_model_path = filepath
+
     def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0

     def _update_best_and_save(
-        self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, ckpt_name_metrics
+        self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any]
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

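The new `_save_none_monitor_checkpoint` takes over the `monitor=None` behaviour that used to live in `_save_last_checkpoint`: every save becomes `best_model_path`, and with `save_top_k=None` the previous file is deleted so only the latest checkpoint survives, whereas `save_top_k=-1` keeps them all and `save_top_k=0` saves nothing. A short configuration sketch using the standard constructor arguments:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Only the most recent checkpoint is kept; each save removes the previous one.
keep_latest = ModelCheckpoint(monitor=None, save_top_k=None)

# Every checkpoint is kept; adding save_last=True duplicates the newest one as
# `last.ckpt`, which the rank_zero_info added to __validate_init_configuration points out.
keep_all = ModelCheckpoint(monitor=None, save_top_k=-1, save_last=True)
```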
@@ -554,7 +580,7 @@ def _update_best_and_save(
         if isinstance(current, torch.Tensor) and torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

-        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, step, trainer, del_filepath)

         # save the current score
         self.current_score = current
@@ -575,7 +601,7 @@ def _update_best_and_save(
             f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
             f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
         )
-        self._save_model(filepath, trainer)
+        self._save_model(trainer, filepath)

         if del_filepath is not None and filepath != del_filepath:
             self._del_model(del_filepath)
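The argument reorder above is the only change inside `_update_best_and_save`; the top-k bookkeeping around it stays as before: record the new score, drop the worst entry once more than k models are tracked, and delete its file after the new checkpoint is written. A simplified, self-contained sketch of that bookkeeping (it ignores NaN handling, distributed reduction, and the fact that the real code picks the file to delete before formatting the new name):

```python
from typing import Dict


def update_top_k(best_k_models: Dict[str, float], filepath: str, score: float, k: int, mode: str = "min") -> str:
    """Insert `filepath` with `score`; return the path to delete, or '' if none."""
    best_k_models[filepath] = score
    if k == -1 or len(best_k_models) <= k:
        return ""
    # evict the worst entry: highest score in "min" mode, lowest in "max" mode
    worst = max(best_k_models, key=best_k_models.get) if mode == "min" else min(best_k_models, key=best_k_models.get)
    del best_k_models[worst]
    return worst


models: Dict[str, float] = {}
for path, loss in [("e1.ckpt", 0.9), ("e2.ckpt", 0.5), ("e3.ckpt", 0.7)]:
    evicted = update_top_k(models, path, loss, k=2, mode="min")
    print(f"saved {path}, evicted {evicted or 'nothing'}")
# -> keeps the two lowest-loss checkpoints
```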