
Commit 5271ed9

Authored by krishnakalyan3, pre-commit-ci[bot], otaj, awaelchli, and rohitgr7
Fix mypy errors attributed to pytorch_lightning.trainer.connectors.callback_connector.py (#13750)
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 355fda3 commit 5271ed9

File tree

2 files changed (+15, -12 lines)


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ module = [
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",

src/pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 15 additions & 11 deletions
@@ -17,6 +17,7 @@
 from datetime import timedelta
 from typing import Dict, List, Optional, Sequence, Union
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import (
     Callback,
     Checkpoint,
@@ -37,7 +38,7 @@
 
 
 class CallbackConnector:
-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
 
     def on_trainer_init(
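Note: annotating trainer as the string "pl.Trainer" (together with the import pytorch_lightning as pl added above) is the standard way to type a reference to a class that would otherwise cause a circular import; the string is only evaluated by the type checker. A minimal sketch of the same idea, using hypothetical names (engine.py / Engine / Connector, not PL APIs) and the stricter TYPE_CHECKING variant:

    # connector.py -- sketch of a forward-reference annotation that avoids a
    # circular import; Engine and Connector are hypothetical names.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only during type checking, never at runtime.
        from engine import Engine


    class Connector:
        def __init__(self, engine: "Engine") -> None:
            self.engine = engine

This commit simply imports the top-level package at runtime and uses the string annotation, which achieves the same effect without the guard.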
@@ -50,7 +51,7 @@ def on_trainer_init(
         enable_model_summary: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
-    ):
+    ) -> None:
         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         if weights_save_path:
@@ -95,16 +96,18 @@ def on_trainer_init(
     def _configure_accumulated_gradients(
         self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
     ) -> None:
-        grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)]
+        grad_accum_callbacks: List[GradientAccumulationScheduler] = [
+            cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
+        ]
 
-        if grad_accum_callback:
+        if grad_accum_callbacks:
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
                     "You have set both `accumulate_grad_batches` and passed an instance of "
                     "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
                     "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
                 )
-            grad_accum_callback = grad_accum_callback[0]
+            grad_accum_callback = grad_accum_callbacks[0]
         else:
             if accumulate_grad_batches is None:
                 accumulate_grad_batches = 1
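Note: the rename from grad_accum_callback to grad_accum_callbacks matters as much as the annotation: re-binding one variable first to a list and then to a single element is a type error under mypy, and the explicit List[GradientAccumulationScheduler] records the narrowing performed by the isinstance filter. A small self-contained sketch of the pattern, with hypothetical Base/Special classes standing in for Callback/GradientAccumulationScheduler:

    # Sketch: isinstance-filtered comprehension with an explicit element type.
    from typing import List


    class Base: ...
    class Special(Base): ...


    items: List[Base] = [Base(), Special(), Base()]

    # A new, explicitly typed name avoids re-binding one variable to two
    # different types (list vs. single element) later on.
    specials: List[Special] = [x for x in items if isinstance(x, Special)]
    if specials:
        first_special = specials[0]  # typed as Special, not Base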
@@ -148,6 +151,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
 
+        model_summary: ModelSummary
         if progress_bar_callback is not None and is_progress_bar_rich:
             model_summary = RichModelSummary()
         else:
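Note: declaring model_summary: ModelSummary before the branch tells mypy the variable's common type up front; without it, the first assignment (RichModelSummary()) pins the variable to the subclass and the plain ModelSummary() in the else branch is reported as an incompatible assignment. A minimal sketch with hypothetical Summary/RichSummary classes:

    # Sketch: pre-declare a variable's type so both branches may assign
    # different subclasses without a mypy error.
    class Summary: ...
    class RichSummary(Summary): ...


    def pick_summary(use_rich: bool) -> Summary:
        summary: Summary  # declared up front, assigned in both branches
        if use_rich:
            summary = RichSummary()
        else:
            summary = Summary()
        return summary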
@@ -188,15 +192,15 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
         timer = Timer(duration=max_time, interval="step")
         self.trainer.callbacks.append(timer)
 
-    def _configure_fault_tolerance_callbacks(self):
+    def _configure_fault_tolerance_callbacks(self) -> None:
         from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
 
         if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks):
             raise RuntimeError("There should be only one fault-tolerance checkpoint callback.")
         # don't use `log_dir` to minimize the chances of failure
         self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))
 
-    def _attach_model_logging_functions(self):
+    def _attach_model_logging_functions(self) -> None:
         lightning_module = self.trainer.lightning_module
         for callback in self.trainer.callbacks:
             callback.log = lightning_module.log
@@ -243,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
             A new list in which the last elements are Checkpoint if there were any present in the
             input.
         """
-        checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
+        checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
         not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
         return not_checkpoints + checkpoints
 
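Note: typing checkpoints as List[Callback] (rather than the narrower List[Checkpoint]) keeps the concatenation not_checkpoints + checkpoints well-typed, since list + list wants matching element types. A small sketch of the same reordering idea, with hypothetical Hook/SaveHook classes standing in for Callback/Checkpoint:

    # Sketch: move "checkpoint-like" hooks to the end of the list.
    from typing import List


    class Hook: ...
    class SaveHook(Hook): ...


    def reorder(hooks: List[Hook]) -> List[Hook]:
        # Annotated as List[Hook], not List[SaveHook], so the concatenation
        # below keeps a single element type.
        savers: List[Hook] = [h for h in hooks if isinstance(h, SaveHook)]
        others = [h for h in hooks if not isinstance(h, SaveHook)]
        return others + savers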

@@ -263,12 +267,12 @@ def _configure_external_callbacks() -> List[Callback]:
     else:
         from pkg_resources import iter_entry_points
 
-        factories = iter_entry_points("pytorch_lightning.callbacks_factory")
+        factories = iter_entry_points("pytorch_lightning.callbacks_factory")  # type: ignore[assignment]
 
-    external_callbacks = []
+    external_callbacks: List[Callback] = []
     for factory in factories:
         callback_factory = factory.load()
-        callbacks_list: List[Callback] = callback_factory()
+        callbacks_list: Union[List[Callback], Callback] = callback_factory()
         callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
         _log.info(
             f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
