Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ module = [
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.strategies.tpu_spawn",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.supporters",
Expand Down
60 changes: 30 additions & 30 deletions src/pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def teardown(self, stage: Optional[str] = None) -> None:
for callback in self.callbacks:
callback.teardown(self, self.lightning_module, stage=stage)

def on_init_start(self):
def on_init_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_init_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -98,7 +98,7 @@ def on_init_start(self):
for callback in self.callbacks:
callback.on_init_start(self)

def on_init_end(self):
def on_init_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_init_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -111,7 +111,7 @@ def on_init_end(self):
for callback in self.callbacks:
callback.on_init_end(self)

def on_fit_start(self):
def on_fit_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_fit_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -124,7 +124,7 @@ def on_fit_start(self):
for callback in self.callbacks:
callback.on_fit_start(self, self.lightning_module)

def on_fit_end(self):
def on_fit_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_fit_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -137,7 +137,7 @@ def on_fit_end(self):
for callback in self.callbacks:
callback.on_fit_end(self, self.lightning_module)

def on_sanity_check_start(self):
def on_sanity_check_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_sanity_check_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -150,7 +150,7 @@ def on_sanity_check_start(self):
for callback in self.callbacks:
callback.on_sanity_check_start(self, self.lightning_module)

def on_sanity_check_end(self):
def on_sanity_check_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_sanity_check_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -163,7 +163,7 @@ def on_sanity_check_end(self):
for callback in self.callbacks:
callback.on_sanity_check_end(self, self.lightning_module)

def on_train_epoch_start(self):
def on_train_epoch_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_train_epoch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -176,7 +176,7 @@ def on_train_epoch_start(self):
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.lightning_module)

def on_train_epoch_end(self):
def on_train_epoch_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_train_epoch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -189,7 +189,7 @@ def on_train_epoch_end(self):
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.lightning_module)

def on_validation_epoch_start(self):
def on_validation_epoch_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_validation_epoch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -202,7 +202,7 @@ def on_validation_epoch_start(self):
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.lightning_module)

def on_validation_epoch_end(self):
def on_validation_epoch_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_validation_epoch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -215,7 +215,7 @@ def on_validation_epoch_end(self):
for callback in self.callbacks:
callback.on_validation_epoch_end(self, self.lightning_module)

def on_test_epoch_start(self):
def on_test_epoch_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_test_epoch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -228,7 +228,7 @@ def on_test_epoch_start(self):
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.lightning_module)

def on_test_epoch_end(self):
def on_test_epoch_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_test_epoch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand Down Expand Up @@ -267,7 +267,7 @@ def on_predict_epoch_end(self, outputs: List[Any]) -> None:
for callback in self.callbacks:
callback.on_predict_epoch_end(self, self.lightning_module, outputs)

def on_epoch_start(self):
def on_epoch_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_epoch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -280,7 +280,7 @@ def on_epoch_start(self):
for callback in self.callbacks:
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
def on_epoch_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_epoch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -293,7 +293,7 @@ def on_epoch_end(self):
for callback in self.callbacks:
callback.on_epoch_end(self, self.lightning_module)

def on_train_start(self):
def on_train_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_train_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -306,7 +306,7 @@ def on_train_start(self):
for callback in self.callbacks:
callback.on_train_start(self, self.lightning_module)

def on_train_end(self):
def on_train_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_train_end` was deprecated in v1.6 and will be removed in v1.8.
Expand Down Expand Up @@ -345,7 +345,7 @@ def on_pretrain_routine_end(self) -> None:
for callback in self.callbacks:
callback.on_pretrain_routine_end(self, self.lightning_module)

def on_batch_start(self):
def on_batch_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_batch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -358,7 +358,7 @@ def on_batch_start(self):
for callback in self.callbacks:
callback.on_batch_start(self, self.lightning_module)

def on_batch_end(self):
def on_batch_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_batch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -371,7 +371,7 @@ def on_batch_end(self):
for callback in self.callbacks:
callback.on_batch_end(self, self.lightning_module)

def on_train_batch_start(self, batch, batch_idx):
def on_train_batch_start(self, batch, batch_idx: int) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -384,7 +384,7 @@ def on_train_batch_start(self, batch, batch_idx):
for callback in self.callbacks:
callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx)

def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx):
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx: int) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -397,7 +397,7 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx):
for callback in self.callbacks:
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx)

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_validation_batch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -410,7 +410,7 @@ def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
for callback in self.callbacks:
callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx):
def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx: int, dataloader_idx: int) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_validation_batch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -423,7 +423,7 @@ def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, datalo
for callback in self.callbacks:
callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
def on_test_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -436,7 +436,7 @@ def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
for callback in self.callbacks:
callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx):
def on_test_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx: int, dataloader_idx: int) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_test_batch_end` was deprecated in v1.6 and will be removed in v1.8.
Expand Down Expand Up @@ -475,7 +475,7 @@ def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int,
for callback in self.callbacks:
callback.on_predict_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_validation_start(self):
def on_validation_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_validation_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -488,7 +488,7 @@ def on_validation_start(self):
for callback in self.callbacks:
callback.on_validation_start(self, self.lightning_module)

def on_validation_end(self):
def on_validation_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_validation_end` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -501,7 +501,7 @@ def on_validation_end(self):
for callback in self.callbacks:
callback.on_validation_end(self, self.lightning_module)

def on_test_start(self):
def on_test_start(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_test_start` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -514,7 +514,7 @@ def on_test_start(self):
for callback in self.callbacks:
callback.on_test_start(self, self.lightning_module)

def on_test_end(self):
def on_test_end(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_test_end` was deprecated in v1.6 and will be removed in v1.8.
Expand Down Expand Up @@ -630,7 +630,7 @@ def on_before_backward(self, loss: Tensor) -> None:
for callback in self.callbacks:
callback.on_before_backward(self, self.lightning_module, loss)

def on_after_backward(self):
def on_after_backward(self) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_after_backward` was deprecated in v1.6 and will be removed in v1.8.
Expand All @@ -643,7 +643,7 @@ def on_after_backward(self):
for callback in self.callbacks:
callback.on_after_backward(self, self.lightning_module)

def on_before_optimizer_step(self, optimizer, optimizer_idx):
def on_before_optimizer_step(self, optimizer, optimizer_idx) -> None:
r"""
.. deprecated:: v1.6
`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8.
Expand Down