diff --git a/pyproject.toml b/pyproject.toml
index 05eba62c50402..d69fb21066555 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,6 @@ module = [
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 83881905beeb1..bb7f912420256 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -17,6 +17,7 @@
 from datetime import timedelta
 from typing import Dict, List, Optional, Sequence, Union
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import (
     Callback,
     Checkpoint,
@@ -37,7 +38,7 @@
 
 
 class CallbackConnector:
-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
 
     def on_trainer_init(
@@ -50,7 +51,7 @@ def on_trainer_init(
         enable_model_summary: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
-    ):
+    ) -> None:
         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         if weights_save_path:
@@ -95,16 +96,18 @@ def on_trainer_init(
     def _configure_accumulated_gradients(
         self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
     ) -> None:
-        grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)]
+        grad_accum_callbacks: List[GradientAccumulationScheduler] = [
+            cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
+        ]
 
-        if grad_accum_callback:
+        if grad_accum_callbacks:
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
                     "You have set both `accumulate_grad_batches` and passed an instance of "
                     "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
                     "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
                 )
-            grad_accum_callback = grad_accum_callback[0]
+            grad_accum_callback = grad_accum_callbacks[0]
         else:
             if accumulate_grad_batches is None:
                 accumulate_grad_batches = 1
@@ -148,6 +151,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
 
+        model_summary: ModelSummary
         if progress_bar_callback is not None and is_progress_bar_rich:
             model_summary = RichModelSummary()
         else:
@@ -188,7 +192,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
         timer = Timer(duration=max_time, interval="step")
         self.trainer.callbacks.append(timer)
 
-    def _configure_fault_tolerance_callbacks(self):
+    def _configure_fault_tolerance_callbacks(self) -> None:
         from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
 
         if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks):
@@ -196,7 +200,7 @@ def _configure_fault_tolerance_callbacks(self):
         # don't use `log_dir` to minimize the chances of failure
         self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))
 
-    def _attach_model_logging_functions(self):
+    def _attach_model_logging_functions(self) -> None:
         lightning_module = self.trainer.lightning_module
         for callback in self.trainer.callbacks:
             callback.log = lightning_module.log
@@ -243,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
     Return:
         A new list in which the last elements are Checkpoint if there were any present in the input.
     """
-    checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
+    checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
     not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
     return not_checkpoints + checkpoints
 
@@ -263,12 +267,12 @@ def _configure_external_callbacks() -> List[Callback]:
     else:
         from pkg_resources import iter_entry_points
 
-        factories = iter_entry_points("pytorch_lightning.callbacks_factory")
+        factories = iter_entry_points("pytorch_lightning.callbacks_factory")  # type: ignore[assignment]
 
-    external_callbacks = []
+    external_callbacks: List[Callback] = []
     for factory in factories:
         callback_factory = factory.load()
-        callbacks_list: List[Callback] = callback_factory()
+        callbacks_list: Union[List[Callback], Callback] = callback_factory()
         callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
         _log.info(
             f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"