1 change: 0 additions & 1 deletion pyproject.toml
@@ -68,7 +68,6 @@ module = [
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
26 changes: 15 additions & 11 deletions src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -17,6 +17,7 @@
 from datetime import timedelta
 from typing import Dict, List, Optional, Sequence, Union
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import (
     Callback,
     Checkpoint,
@@ -37,7 +38,7 @@
 
 
 class CallbackConnector:
-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
 
     def on_trainer_init(
Expand All @@ -50,7 +51,7 @@ def on_trainer_init(
enable_model_summary: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
):
) -> None:
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
if weights_save_path:
@@ -95,16 +96,18 @@ def on_trainer_init(
     def _configure_accumulated_gradients(
         self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
     ) -> None:
-        grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)]
+        grad_accum_callbacks: List[GradientAccumulationScheduler] = [
+            cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
+        ]
 
-        if grad_accum_callback:
+        if grad_accum_callbacks:
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
                     "You have set both `accumulate_grad_batches` and passed an instance of "
                     "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
                     "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
                 )
-            grad_accum_callback = grad_accum_callback[0]
+            grad_accum_callback = grad_accum_callbacks[0]
         else:
             if accumulate_grad_batches is None:
                 accumulate_grad_batches = 1
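
Reviewer context for the hunk above (not part of the diff): gradient accumulation can be requested either through the Trainer's accumulate_grad_batches argument or by passing a GradientAccumulationScheduler callback, and _configure_accumulated_gradients rejects the combination. A minimal sketch of the two paths, with illustrative scheduling values:

# Not part of the diff: the two configuration paths reconciled above.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Path 1: a constant accumulation factor passed straight to the Trainer.
trainer = pl.Trainer(accumulate_grad_batches=4)

# Path 2: an epoch-keyed schedule via the callback
# (accumulate 8 batches from epoch 0, then 4 from epoch 5 onward).
scheduler = GradientAccumulationScheduler(scheduling={0: 8, 5: 4})
trainer = pl.Trainer(callbacks=[scheduler])

# Combining both, e.g. pl.Trainer(accumulate_grad_batches=4, callbacks=[scheduler]),
# raises the MisconfigurationException constructed in _configure_accumulated_gradients.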
@@ -148,6 +151,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
 
+        model_summary: ModelSummary
         if progress_bar_callback is not None and is_progress_bar_rich:
             model_summary = RichModelSummary()
         else:
@@ -188,15 +192,15 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
         timer = Timer(duration=max_time, interval="step")
         self.trainer.callbacks.append(timer)
 
-    def _configure_fault_tolerance_callbacks(self):
+    def _configure_fault_tolerance_callbacks(self) -> None:
         from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
 
         if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks):
             raise RuntimeError("There should be only one fault-tolerance checkpoint callback.")
         # don't use `log_dir` to minimize the chances of failure
         self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))
 
-    def _attach_model_logging_functions(self):
+    def _attach_model_logging_functions(self) -> None:
         lightning_module = self.trainer.lightning_module
         for callback in self.trainer.callbacks:
             callback.log = lightning_module.log
@@ -243,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
            A new list in which the last elements are Checkpoint if there were any present in the
            input.
        """
-        checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
+        checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
         not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
         return not_checkpoints + checkpoints
 
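A note on the List[Callback] annotation above (not part of the diff): the filtered objects are all Checkpoint instances, but the broader element type is presumably used so that the concatenation with the non-checkpoint callbacks type-checks under list invariance. A small runnable sketch of the ordering guarantee, using arbitrary example callbacks:

# Not part of the diff: checkpoint-type callbacks are moved to the end while the
# relative order within each group is preserved. The callbacks below are examples.
from typing import List

from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ModelCheckpoint

callbacks: List[Callback] = [
    ModelCheckpoint(dirpath="ckpts_a"),
    EarlyStopping(monitor="val_loss"),
    ModelCheckpoint(dirpath="ckpts_b"),
]
checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
print([type(c).__name__ for c in not_checkpoints + checkpoints])
# -> ['EarlyStopping', 'ModelCheckpoint', 'ModelCheckpoint']
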
@@ -263,12 +267,12 @@ def _configure_external_callbacks() -> List[Callback]:
     else:
         from pkg_resources import iter_entry_points
 
-        factories = iter_entry_points("pytorch_lightning.callbacks_factory")
+        factories = iter_entry_points("pytorch_lightning.callbacks_factory")  # type: ignore[assignment]
 
-    external_callbacks = []
+    external_callbacks: List[Callback] = []
     for factory in factories:
         callback_factory = factory.load()
-        callbacks_list: List[Callback] = callback_factory()
+        callbacks_list: Union[List[Callback], Callback] = callback_factory()
         callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
         _log.info(
             f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
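
Context for _configure_external_callbacks (not part of the diff): third-party packages can register callbacks by exposing a zero-argument factory under the pytorch_lightning.callbacks_factory entry-point group; the connector loads each registered factory, calls it, and adds the result to the trainer's callbacks. A hypothetical factory, with illustrative package and callback names:

# Not part of the diff: a sketch of a third-party callbacks factory. The entry-point
# group name comes from the code above; the class and function names are made up.
from typing import List

from pytorch_lightning.callbacks import Callback, EarlyStopping


class PrintOnFitStart(Callback):
    # Hypothetical callback used only for illustration.
    def on_fit_start(self, trainer, pl_module) -> None:
        print("fit is starting")


def my_callbacks_factory() -> List[Callback]:
    # Returning a single Callback instead of a list also works, which is what the
    # Union[List[Callback], Callback] annotation added in this diff reflects.
    return [PrintOnFitStart(), EarlyStopping(monitor="val_loss")]

The provider package would advertise the factory in its own metadata, e.g. an entry point registered under the "pytorch_lightning.callbacks_factory" group pointing at something like my_package.callbacks:my_callbacks_factory.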