From 01ddc4719671e1243387a5f98032151637d574c8 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Wed, 20 Jul 2022 01:34:16 -0400
Subject: [PATCH 01/13] fix function return type

---
 .../trainer/connectors/callback_connector.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 83881905beeb1..4f04ea812e5da 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -16,6 +16,8 @@
 import os
 from datetime import timedelta
 from typing import Dict, List, Optional, Sequence, Union
+import pytorch_lightning as pl
+
 
 from pytorch_lightning.callbacks import (
     Callback,
@@ -37,7 +39,7 @@ class CallbackConnector:
-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
 
     def on_trainer_init(
@@ -50,7 +52,7 @@ def on_trainer_init(
         enable_model_summary: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
-    ):
+    ) -> None:
         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         if weights_save_path:
@@ -188,7 +190,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
             timer = Timer(duration=max_time, interval="step")
         self.trainer.callbacks.append(timer)
 
-    def _configure_fault_tolerance_callbacks(self):
+    def _configure_fault_tolerance_callbacks(self) -> None:
         from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
 
         if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks):
@@ -196,7 +198,7 @@ def _configure_fault_tolerance_callbacks(self):
         # don't use `log_dir` to minimize the chances of failure
         self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))
 
-    def _attach_model_logging_functions(self):
+    def _attach_model_logging_functions(self) -> None:
         lightning_module = self.trainer.lightning_module
         for callback in self.trainer.callbacks:
             callback.log = lightning_module.log

From 9d08c69a4161b53d2f6ba81e1906ea9ca998f242 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Wed, 20 Jul 2022 01:35:28 -0400
Subject: [PATCH 02/13] add toml file

---
 pyproject.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 989e63122f640..b990079bdf144 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,7 +77,6 @@ module = [
     "pytorch_lightning.strategies.strategy",
     "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
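Patch 01 is the core of the series: it gives the connector's hooks explicit `-> None` return types and annotates `trainer` with the string form `"pl.Trainer"`, so the annotation can name a class from a package imported only as `import pytorch_lightning as pl`, sidestepping the circular import between the trainer and its connectors. Patch 02 then deletes the module from the per-module ignore list in `pyproject.toml`, which is what makes mypy actually check these annotations. Below is a minimal self-contained sketch of the forward-reference pattern, written with the `TYPE_CHECKING` variant of the same idea; the module and class names are hypothetical, not Lightning's.

    # Sketch only: `myframework.trainer` is a made-up module path.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only while type checking, never at runtime, so no
        # circular import occurs even if trainer.py imports this module.
        from myframework.trainer import Trainer

    class CallbackConnector:
        def __init__(self, trainer: "Trainer") -> None:
            # The quoted annotation defers name resolution to the checker.
            self.trainer = trainer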
From 58c62c528e2322339cc231320bec8156a1474c25 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 20 Jul 2022 05:38:23 +0000
Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/trainer/connectors/callback_connector.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 4f04ea812e5da..8430555dd970e 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -16,9 +16,8 @@
 import os
 from datetime import timedelta
 from typing import Dict, List, Optional, Sequence, Union
-import pytorch_lightning as pl
-
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import (
     Callback,
     Checkpoint,

From cf48e92e6984d0bdfc05f273ff5f39bfc3e4e699 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Mon, 25 Jul 2022 08:59:18 -0400
Subject: [PATCH 04/13] precommit

---
 .../trainer/connectors/callback_connector.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 4f04ea812e5da..778d1a47dc4d1 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -16,9 +16,8 @@
 import os
 from datetime import timedelta
 from typing import Dict, List, Optional, Sequence, Union
-import pytorch_lightning as pl
-
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import (
     Callback,
     Checkpoint,
@@ -97,7 +96,9 @@ def on_trainer_init(
     def _configure_accumulated_gradients(
         self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
     ) -> None:
-        grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)]
+        grad_accum_callback: List[GradientAccumulationScheduler] = [
+            cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
+        ]
 
         if grad_accum_callback:
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
@@ -151,7 +152,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
 
         if progress_bar_callback is not None and is_progress_bar_rich:
-            model_summary = RichModelSummary()
+            model_summary: ModelSummary = RichModelSummary()
         else:
             model_summary = ModelSummary()
         self.trainer.callbacks.append(model_summary)
@@ -245,8 +246,8 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
             A new list in which the last elements are Checkpoint if there were any present in the
             input.
         """
-        checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
-        not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
+        checkpoints = [c for c in callbacks if isinstance(c, Checkpoint and Callback)]
+        not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint and Callback)]
         return not_checkpoints + checkpoints
 
@@ -267,7 +268,7 @@ def _configure_external_callbacks() -> List[Callback]:
         factories = iter_entry_points("pytorch_lightning.callbacks_factory")
 
-    external_callbacks = []
+    external_callbacks: List[Callback] = []
     for factory in factories:
         callback_factory = factory.load()
         callbacks_list: List[Callback] = callback_factory()
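One change in patch 04 deserves a closer look, since patches 06 and 07 spend two commits undoing it: `isinstance(c, Checkpoint and Callback)` does not express "is both a Checkpoint and a Callback". `Checkpoint and Callback` is ordinary boolean short-circuiting over two truthy class objects, so it evaluates to `Callback`, and the filter silently classifies every callback as a checkpoint. A standalone demonstration with toy stand-ins for Lightning's classes:

    class Callback: ...
    class Checkpoint(Callback): ...
    class ProgressBar(Callback): ...

    # Classes are truthy, so `A and B` simply evaluates to B:
    print(Checkpoint and Callback)                  # <class '__main__.Callback'>

    pb = ProgressBar()
    print(isinstance(pb, Checkpoint and Callback))  # True -- really isinstance(pb, Callback)
    print(isinstance(pb, Checkpoint))               # False -- the intended check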
""" - checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)] - not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] + checkpoints = [c for c in callbacks if isinstance(c, Checkpoint and Callback)] + not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint and Callback)] return not_checkpoints + checkpoints @@ -267,7 +268,7 @@ def _configure_external_callbacks() -> List[Callback]: factories = iter_entry_points("pytorch_lightning.callbacks_factory") - external_callbacks = [] + external_callbacks: List[Callback] = [] for factory in factories: callback_factory = factory.load() callbacks_list: List[Callback] = callback_factory() From 0ff7c8ef5bc1c4e853869f912036e33b81229097 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Fri, 29 Jul 2022 11:04:10 +0300 Subject: [PATCH 05/13] Update src/pytorch_lightning/trainer/connectors/callback_connector.py Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> --- src/pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 778d1a47dc4d1..65db166189a2a 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -271,7 +271,7 @@ def _configure_external_callbacks() -> List[Callback]: external_callbacks: List[Callback] = [] for factory in factories: callback_factory = factory.load() - callbacks_list: List[Callback] = callback_factory() + callbacks_list: Union[List[Callback], Callback] = callback_factory() callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list _log.info( f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" From 59a5fae43867df6ba30f89956d5c38287f2fdb7c Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Fri, 29 Jul 2022 11:04:23 +0300 Subject: [PATCH 06/13] Update src/pytorch_lightning/trainer/connectors/callback_connector.py Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> --- src/pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 65db166189a2a..daa53c5856eee 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -247,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: input. 
""" checkpoints = [c for c in callbacks if isinstance(c, Checkpoint and Callback)] - not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint and Callback)] + not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints From 9e0368c8b2c4dae5077f92df93decccd655bab60 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Fri, 29 Jul 2022 11:04:29 +0300 Subject: [PATCH 07/13] Update src/pytorch_lightning/trainer/connectors/callback_connector.py Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> --- src/pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index daa53c5856eee..2ce013d4860f7 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -246,7 +246,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: A new list in which the last elements are Checkpoint if there were any present in the input. """ - checkpoints = [c for c in callbacks if isinstance(c, Checkpoint and Callback)] + checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)] not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints From 481c0932959e790e89e9cb481712466137130e05 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Fri, 29 Jul 2022 04:23:08 -0400 Subject: [PATCH 08/13] suggestion commit var change --- .../trainer/connectors/callback_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 2ce013d4860f7..25fceccce3ed0 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -151,10 +151,10 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None: progress_bar_callback = self.trainer.progress_bar_callback is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar) + model_summary: ModelSummary = ModelSummary() if progress_bar_callback is not None and is_progress_bar_rich: - model_summary: ModelSummary = RichModelSummary() - else: - model_summary = ModelSummary() + model_summary = RichModelSummary() + self.trainer.callbacks.append(model_summary) def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: From 750b04f2b591bb354d6ee60d00cdb2333d2930b2 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Fri, 29 Jul 2022 05:11:28 -0400 Subject: [PATCH 09/13] review call back variable slit change --- .../trainer/connectors/callback_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 25fceccce3ed0..a2375309b2eb9 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -96,18 +96,18 @@ def on_trainer_init( def _configure_accumulated_gradients( self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None ) -> None: - grad_accum_callback: List[GradientAccumulationScheduler] = [ + 
From 7ed941958c45c16edad8695285a18e30c34e3b69 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Sat, 30 Jul 2022 03:37:27 -0400
Subject: [PATCH 10/13] nit

---
 .../trainer/connectors/callback_connector.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index a2375309b2eb9..2e66d886d0ad0 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -101,13 +101,14 @@ def _configure_accumulated_gradients(
         ]
 
         if grad_accum_callbacks:
+            grad_accum_callback: Callback
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
                     "You have set both `accumulate_grad_batches` and passed an instance of "
                     "`GradientAccumulationScheduler` inside callbacks. Either remove `accumulate_grad_batches` "
                     "from trainer or remove `GradientAccumulationScheduler` from callbacks list."
                 )
-            grad_accum_callback: Callback = grad_accum_callbacks[0]
+            grad_accum_callback = grad_accum_callbacks[0]
         else:
             if accumulate_grad_batches is None:
                 accumulate_grad_batches = 1
@@ -151,10 +152,11 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
 
-        model_summary: ModelSummary = ModelSummary()
+        model_summary: ModelSummary
         if progress_bar_callback is not None and is_progress_bar_rich:
             model_summary = RichModelSummary()
-
+        else:
+            model_summary = ModelSummary()
         self.trainer.callbacks.append(model_summary)
 
     def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
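Patches 10 and 11 settle on the bare-annotation idiom from PEP 526: a line like `grad_accum_callback: Callback` declares the variable's type for the checker without binding a value, so every later branch can assign a plain expression and still be checked against the declared type. A compact illustration with toy types, not Lightning's:

    from typing import List

    class Callback: ...
    class GradientAccumulationScheduler(Callback): ...

    def pick(callbacks: List[Callback]) -> Callback:
        chosen: Callback  # annotation only: declares a type, binds nothing
        schedulers = [cb for cb in callbacks if isinstance(cb, GradientAccumulationScheduler)]
        if schedulers:
            chosen = schedulers[0]
        else:
            chosen = Callback()  # each assignment is checked against Callback
        return chosen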
From ed669e37f9585d57590fd8b994a502a04fe77d35 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Tue, 2 Aug 2022 10:10:01 +0200
Subject: [PATCH 11/13] Update
 src/pytorch_lightning/trainer/connectors/callback_connector.py

Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
---
 src/pytorch_lightning/trainer/connectors/callback_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 2e66d886d0ad0..3bbc18e57ed3c 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -100,8 +100,8 @@ def _configure_accumulated_gradients(
             cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
         ]
 
+        grad_accum_callback: Callback
         if grad_accum_callbacks:
-            grad_accum_callback: Callback
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
                     "You have set both `accumulate_grad_batches` and passed an instance of "

From e721f23d9a9b9a40363ef9c828dc183accd00568 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Wed, 3 Aug 2022 17:33:27 -0400
Subject: [PATCH 12/13] Apply suggestions from code review

Co-authored-by: Rohit Gupta
---
 .../trainer/connectors/callback_connector.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index 3bbc18e57ed3c..ef1aea6197295 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -96,11 +96,10 @@ def on_trainer_init(
     def _configure_accumulated_gradients(
         self, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None
     ) -> None:
-        grad_accum_callbacks: List[Callback] = [
+        grad_accum_callbacks: List[GradientAccumulationScheduler] = [
             cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)
         ]
 
-        grad_accum_callback: Callback
         if grad_accum_callbacks:
             if accumulate_grad_batches is not None:
                 raise MisconfigurationException(
@@ -247,7 +246,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
             A new list in which the last elements are Checkpoint if there were any present in the
             input.
         """
-        checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
+        checkpoints: List[Checkpoint] = [c for c in callbacks if isinstance(c, Checkpoint)]
         not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
         return not_checkpoints + checkpoints
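Patch 12 can tighten the comprehension's annotation back to `List[GradientAccumulationScheduler]` because mypy narrows the loop variable inside an `if isinstance(...)` comprehension filter, so the inferred element type is already the subclass. Sketched with the same toy types:

    from typing import List

    class Callback: ...
    class GradientAccumulationScheduler(Callback): ...

    callbacks: List[Callback] = [GradientAccumulationScheduler(), Callback()]

    # mypy narrows `cb` to GradientAccumulationScheduler inside the filter,
    # so this annotation type-checks without a cast:
    schedulers: List[GradientAccumulationScheduler] = [
        cb for cb in callbacks if isinstance(cb, GradientAccumulationScheduler)
    ]
    print(len(schedulers))  # 1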
""" - checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)] + checkpoints: List[Checkpoint] = [c for c in callbacks if isinstance(c, Checkpoint)] not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints From 00fa20a086c4d10270c7a647eb143956f8ac8ab0 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 4 Aug 2022 14:00:30 +0530 Subject: [PATCH 13/13] Apply suggestions from code review --- src/pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index ef1aea6197295..bb7f912420256 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -247,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: A new list in which the last elements are Checkpoint if there were any present in the input. """ - checkpoints: List[Checkpoint] = [c for c in callbacks if isinstance(c, Checkpoint)] + checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)] not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints