1717from datetime import timedelta
1818from typing import Dict , List , Optional , Sequence , Union
1919
20+ import pytorch_lightning as pl
2021from pytorch_lightning .callbacks import (
2122 Callback ,
2223 Checkpoint ,
3738
3839
3940class CallbackConnector :
40- def __init__ (self , trainer ):
41+ def __init__ (self , trainer : "pl.Trainer" ):
4142 self .trainer = trainer
4243
4344 def on_trainer_init (
@@ -50,7 +51,7 @@ def on_trainer_init(
5051 enable_model_summary : bool ,
5152 max_time : Optional [Union [str , timedelta , Dict [str , int ]]] = None ,
5253 accumulate_grad_batches : Optional [Union [int , Dict [int , int ]]] = None ,
53- ):
54+ ) -> None :
5455 # init folder paths for checkpoint + weights save callbacks
5556 self .trainer ._default_root_dir = default_root_dir or os .getcwd ()
5657 if weights_save_path :
@@ -95,16 +96,18 @@ def on_trainer_init(
9596 def _configure_accumulated_gradients (
9697 self , accumulate_grad_batches : Optional [Union [int , Dict [int , int ]]] = None
9798 ) -> None :
98- grad_accum_callback = [cb for cb in self .trainer .callbacks if isinstance (cb , GradientAccumulationScheduler )]
99+ grad_accum_callbacks : List [GradientAccumulationScheduler ] = [
100+ cb for cb in self .trainer .callbacks if isinstance (cb , GradientAccumulationScheduler )
101+ ]
99102
100- if grad_accum_callback :
103+ if grad_accum_callbacks :
101104 if accumulate_grad_batches is not None :
102105 raise MisconfigurationException (
103106 "You have set both `accumulate_grad_batches` and passed an instance of "
104107 "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
105108 "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
106109 )
107- grad_accum_callback = grad_accum_callback [0 ]
110+ grad_accum_callback = grad_accum_callbacks [0 ]
108111 else :
109112 if accumulate_grad_batches is None :
110113 accumulate_grad_batches = 1
@@ -148,6 +151,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
148151 progress_bar_callback = self .trainer .progress_bar_callback
149152 is_progress_bar_rich = isinstance (progress_bar_callback , RichProgressBar )
150153
154+ model_summary : ModelSummary
151155 if progress_bar_callback is not None and is_progress_bar_rich :
152156 model_summary = RichModelSummary ()
153157 else :
@@ -188,15 +192,15 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
188192 timer = Timer (duration = max_time , interval = "step" )
189193 self .trainer .callbacks .append (timer )
190194
191- def _configure_fault_tolerance_callbacks (self ):
195+ def _configure_fault_tolerance_callbacks (self ) -> None :
192196 from pytorch_lightning .callbacks .fault_tolerance import _FaultToleranceCheckpoint
193197
194198 if any (isinstance (cb , _FaultToleranceCheckpoint ) for cb in self .trainer .callbacks ):
195199 raise RuntimeError ("There should be only one fault-tolerance checkpoint callback." )
196200 # don't use `log_dir` to minimize the chances of failure
197201 self .trainer .callbacks .append (_FaultToleranceCheckpoint (dirpath = self .trainer .default_root_dir ))
198202
199- def _attach_model_logging_functions (self ):
203+ def _attach_model_logging_functions (self ) -> None :
200204 lightning_module = self .trainer .lightning_module
201205 for callback in self .trainer .callbacks :
202206 callback .log = lightning_module .log
@@ -243,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
243247 A new list in which the last elements are Checkpoint if there were any present in the
244248 input.
245249 """
246- checkpoints = [c for c in callbacks if isinstance (c , Checkpoint )]
250+ checkpoints : List [ Callback ] = [c for c in callbacks if isinstance (c , Checkpoint )]
247251 not_checkpoints = [c for c in callbacks if not isinstance (c , Checkpoint )]
248252 return not_checkpoints + checkpoints
249253
@@ -263,12 +267,12 @@ def _configure_external_callbacks() -> List[Callback]:
263267 else :
264268 from pkg_resources import iter_entry_points
265269
266- factories = iter_entry_points ("pytorch_lightning.callbacks_factory" )
270+ factories = iter_entry_points ("pytorch_lightning.callbacks_factory" ) # type: ignore[assignment]
267271
268- external_callbacks = []
272+ external_callbacks : List [ Callback ] = []
269273 for factory in factories :
270274 callback_factory = factory .load ()
271- callbacks_list : List [Callback ] = callback_factory ()
275+ callbacks_list : Union [ List [Callback ], Callback ] = callback_factory ()
272276 callbacks_list = [callbacks_list ] if isinstance (callbacks_list , Callback ) else callbacks_list
273277 _log .info (
274278 f"Adding { len (callbacks_list )} callbacks from entry point '{ factory .name } ':"
0 commit comments