2525from datetime import timedelta
2626from functools import partial
2727from pathlib import Path
28- from typing import Any , Callable , Dict , Generator , Iterable , List , Optional , Tuple , Type , Union
28+ from typing import Any , Callable , Dict , Generator , Iterable , List , Optional , Type , Union
2929from weakref import proxy
3030
3131import torch
7777from pytorch_lightning .trainer .connectors .checkpoint_connector import CheckpointConnector
7878from pytorch_lightning .trainer .connectors .data_connector import DataConnector
7979from pytorch_lightning .trainer .connectors .logger_connector import LoggerConnector
80- from pytorch_lightning .trainer .connectors .logger_connector .result import _ResultCollection
80+ from pytorch_lightning .trainer .connectors .logger_connector .result import _OUT_DICT , _ResultCollection
8181from pytorch_lightning .trainer .connectors .signal_connector import SignalConnector
8282from pytorch_lightning .trainer .data_loading import TrainerDataLoadingMixin
8383from pytorch_lightning .trainer .optimizers import TrainerOptimizersMixin
@@ -545,6 +545,7 @@ def __init__(
545545 self ._logger_connector .on_trainer_init (logger , log_every_n_steps , move_metrics_to_cpu )
546546
547547 # init debugging flags
548+ self .val_check_batch : Union [int , float ]
548549 self .val_check_interval : Union [int , float ]
549550 self .num_sanity_val_steps : Union [int , float ]
550551 self .limit_train_batches : Union [int , float ]
@@ -741,7 +742,7 @@ def _fit_impl(
741742 # TODO: ckpt_path only in v2.0
742743 ckpt_path = ckpt_path or self .resume_from_checkpoint
743744 self ._ckpt_path = self .__set_ckpt_path (
744- ckpt_path , model_provided = True , model_connected = self .lightning_module is not None
745+ ckpt_path , model_provided = True , model_connected = self .lightning_module is not None # type: ignore
745746 )
746747 results = self ._run (model , ckpt_path = self .ckpt_path )
747748
@@ -985,7 +986,7 @@ def _predict_impl(
985986 self .state .status = TrainerStatus .RUNNING
986987 self .predicting = True
987988
988- self .predict_loop .return_predictions = return_predictions
989+ self .predict_loop .return_predictions = return_predictions # type: ignore
989990
990991 # if a datamodule comes in as the second arg, then fix it for the user
991992 if isinstance (dataloaders , LightningDataModule ):
@@ -1395,7 +1396,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
13951396
13961397 if model_provided and ckpt_path is None :
13971398 # use passed model to function without loading weights
1398- return
1399+ return None
13991400
14001401 if model_connected and ckpt_path is None :
14011402 ckpt_path = "best"
@@ -1449,8 +1450,8 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
14491450 f'.{ fn } (ckpt_path="last") is set, but there is no fault tolerant'
14501451 " or last checkpoint available. No checkpoint will be loaded."
14511452 )
1452- return
1453- ckpt_path = max (candidates_ts .keys (), key = partial (operator .getitem , candidates_ts ))
1453+ return None
1454+ ckpt_path = max (candidates_ts .keys (), key = partial (operator .getitem , candidates_ts )) # type: ignore
14541455
14551456 if not ckpt_path :
14561457 raise MisconfigurationException (
@@ -1664,7 +1665,7 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
16641665 prev_fx_name = pl_module ._current_fx_name
16651666 pl_module ._current_fx_name = "on_load_checkpoint"
16661667
1667- callback_states : Dict [Union [Type , str ], Dict ] = checkpoint .get ("callbacks" )
1668+ callback_states : Optional [ Dict [Union [Type , str ], Dict ] ] = checkpoint .get ("callbacks" )
16681669
16691670 if callback_states is None :
16701671 return
@@ -1692,7 +1693,7 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
16921693
16931694 def _call_callbacks_load_state_dict (self , checkpoint : Dict [str , Any ]) -> None :
16941695 """Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
1695- callback_states : Dict [Union [Type , str ], Dict ] = checkpoint .get ("callbacks" )
1696+ callback_states : Optional [ Dict [Union [Type , str ], Dict ] ] = checkpoint .get ("callbacks" )
16961697
16971698 if callback_states is None :
16981699 return
@@ -1745,6 +1746,7 @@ def __init_profiler(self, profiler: Optional[Union[Profiler, str]]) -> None:
17451746 )
17461747 profiler_class = PROFILERS [profiler ]
17471748 profiler = profiler_class ()
1749+ assert isinstance (profiler , Profiler )
17481750 self .profiler : Profiler = profiler or PassThroughProfiler ()
17491751
17501752 def __setup_profiler (self ) -> None :
@@ -2126,8 +2128,9 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:
21262128 return self .device_ids if isinstance (self .accelerator , CUDAAccelerator ) else None
21272129
21282130 @property
2129- def lightning_module (self ) -> "pl.LightningModule" :
2131+ def lightning_module (self ) -> "pl.LightningModule" : # type: ignore
21302132 # TODO: this is actually an optional return
2133+ assert self .strategy .lightning_module is not None
21312134 return self .strategy .lightning_module
21322135
21332136 @property
@@ -2219,12 +2222,12 @@ def model(self, model: torch.nn.Module) -> None:
22192222
22202223 @property
22212224 def log_dir (self ) -> Optional [str ]:
2222- assert self .logger is not None
22232225 if len (self .loggers ) == 1 :
2224- if isinstance (self .logger , TensorBoardLogger ):
2225- dirpath = self .logger .log_dir
2226- else :
2226+ assert self .logger is not None
2227+ if not isinstance (self .logger , TensorBoardLogger ):
22272228 dirpath = self .logger .save_dir
2229+ else :
2230+ dirpath = self .logger .log_dir
22282231 else :
22292232 dirpath = self .default_root_dir
22302233
@@ -2709,7 +2712,7 @@ def logger(self, logger: Optional[Logger]) -> None:
27092712 if not logger :
27102713 self .loggers = []
27112714 elif isinstance (logger , LoggerCollection ):
2712- self .loggers = list ( logger )
2715+ self .loggers = [ x for x in logger ]
27132716 else :
27142717 self .loggers = [logger ]
27152718
@@ -2722,17 +2725,17 @@ def loggers(self, loggers: Optional[List[Logger]]) -> None:
27222725 self ._loggers = loggers if loggers else []
27232726
27242727 @property
2725- def callback_metrics (self ) -> Dict [ str , Tensor ] :
2728+ def callback_metrics (self ) -> Dict :
27262729 # TODO: the true typing return can include dictionaries as defined in
27272730 # `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
27282731 return self ._logger_connector .callback_metrics
27292732
27302733 @property
2731- def logged_metrics (self ) -> dict :
2734+ def logged_metrics (self ) -> _OUT_DICT :
27322735 return self ._logger_connector .logged_metrics
27332736
27342737 @property
2735- def progress_bar_metrics (self ) -> dict :
2738+ def progress_bar_metrics (self ) -> Dict :
27362739 return self ._logger_connector .progress_bar_metrics
27372740
27382741 @property
@@ -2748,7 +2751,7 @@ def _exit_gracefully_on_signal(self) -> None:
27482751
27492752 def _should_terminate_gracefully (self ) -> bool :
27502753 value = torch .tensor (int (self ._terminate_gracefully ), device = self .strategy .root_device )
2751- return self .strategy .reduce (value , reduce_op = "sum" ) > 0
2754+ return bool ( self .strategy .reduce (value , reduce_op = "sum" ) > 0 )
27522755
27532756 """
27542757 Other
0 commit comments