 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

 log = logging.getLogger(__name__)
@@ -231,13 +231,14 @@ def __init__(
         self._save_on_train_epoch_end = save_on_train_epoch_end
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
-        self.current_score = None
-        self.best_k_models = {}
+        self.current_score: Optional[Tensor] = None
+        self.best_k_models: Dict[str, Tensor] = {}
         self.kth_best_model_path = ""
-        self.best_model_score = None
+        self.best_model_score: Optional[Tensor] = None
         self.best_model_path = ""
         self.last_model_path = ""

+        self.kth_value: Tensor
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
@@ -256,6 +257,7 @@ def state_key(self) -> str:

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self.__resolve_ckpt_dir(trainer)
+        assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
@@ -362,7 +364,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)

-    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if self.save_top_k == 0:
             return

@@ -395,7 +397,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn

         return (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
             or trainer.state.fn != TrainerFn.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
@@ -493,15 +495,15 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
         should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

         # If using multiple devices, make sure all processes are unanimous on the decision.
-        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
+        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

         return should_update_best_and_save

     @classmethod
     def _format_checkpoint_name(
         cls,
         filename: Optional[str],
-        metrics: Dict[str, _METRIC],
+        metrics: Dict[str, Tensor],
         prefix: str = "",
         auto_insert_metric_name: bool = True,
     ) -> str:
@@ -522,7 +524,7 @@ def _format_checkpoint_name(
                 filename = filename.replace(group, f"{{0[{name}]")

                 if name not in metrics:
-                    metrics[name] = 0
+                    metrics[name] = torch.tensor(0)
             filename = filename.format(metrics)

         if prefix:
@@ -531,7 +533,7 @@ def _format_checkpoint_name(
         return filename

     def format_checkpoint_name(
-        self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
+        self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.

@@ -591,6 +593,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
             ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
         elif trainer.loggers:
             if len(trainer.loggers) == 1:
+                assert trainer.logger is not None
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
             else:
                 save_dir = trainer.default_root_dir
@@ -613,7 +616,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

     def _get_metric_interpolated_filepath_name(
-        self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
+        self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
     ) -> str:
         filepath = self.format_checkpoint_name(monitor_candidates)

@@ -624,7 +627,7 @@ def _get_metric_interpolated_filepath_name(

         return filepath

-    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
+    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
         monitor_candidates = deepcopy(trainer.callback_metrics)
         # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
         # or does not exist we overwrite it as it's likely an error
@@ -634,7 +637,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
         monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
         return monitor_candidates

-    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if not self.save_last:
             return

@@ -651,16 +654,18 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         if previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)

-    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
+        assert self.monitor
         current = monitor_candidates.get(self.monitor)
         if self.check_monitor_top_k(trainer, current):
+            assert current is not None
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
             rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

-    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath
@@ -669,7 +674,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
             trainer.strategy.remove_checkpoint(previous)

     def _update_best_and_save(
-        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
+        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
     ) -> None:
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

@@ -691,11 +696,11 @@ def _update_best_and_save(
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
-            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]

         _op = min if self.mode == "min" else max
-        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]

         if self.verbose:
@@ -715,6 +720,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
         if filepath is None:
+            assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)
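For context on the typing changes above: callback metrics reaching `ModelCheckpoint` are already tensors, so the `_METRIC` union can be dropped and the missing-key fallback in `_format_checkpoint_name` becomes `torch.tensor(0)` instead of a plain `0`. A minimal standalone sketch (not part of this diff; the template string and metric names are made up) showing that zero-dim tensors format exactly like plain numbers in the filename template:

```python
import torch
from torch import Tensor
from typing import Dict

# Hypothetical monitor candidates: all values are 0-dim tensors,
# matching the new Dict[str, Tensor] annotations.
metrics: Dict[str, Tensor] = {"epoch": torch.tensor(3), "val_loss": torch.tensor(0.1234)}

# Mirror of the fallback above: a key referenced by the template but absent
# from the metrics gets a zero tensor rather than the previous plain 0.
for name in ("epoch", "val_loss", "missing"):
    metrics.setdefault(name, torch.tensor(0))

# 0-dim tensors support format specs, so the usual template still renders.
template = "epoch={epoch:d}-val_loss={val_loss:.4f}-missing={missing:d}"
print(template.format(**metrics))  # -> epoch=3-val_loss=0.1234-missing=0
```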