From c554eabe5fdbea0c014b9cbe41137641cc48bcaa Mon Sep 17 00:00:00 2001 From: Gautier Dagan Date: Thu, 30 Jun 2022 17:30:36 +0100 Subject: [PATCH 1/3] fix: fix mypy typing errors in lightning/trainer/optimizers.py --- pyproject.toml | 1 - src/pytorch_lightning/trainer/optimizers.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ddc903d6af9d7..dc9db77d6dabd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,6 @@ module = [ "pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.data_connector", "pytorch_lightning.trainer.data_loading", - "pytorch_lightning.trainer.optimizers", "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", diff --git a/src/pytorch_lightning/trainer/optimizers.py b/src/pytorch_lightning/trainer/optimizers.py index 1cb3430f2e488..91254dfba986e 100644 --- a/src/pytorch_lightning/trainer/optimizers.py +++ b/src/pytorch_lightning/trainer/optimizers.py @@ -28,7 +28,7 @@ class TrainerOptimizersMixin(ABC): The `TrainerOptimizersMixin` was deprecated in v1.6 and will be removed in v1.8. """ - def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]: + def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]: r""" .. deprecated:: v1.6 `TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8. @@ -39,7 +39,7 @@ def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, pl_module = self.lightning_module or model return _init_optimizers_and_lr_schedulers(pl_module) - def convert_to_lightning_optimizers(self): + def convert_to_lightning_optimizers(self)->None: r""" .. deprecated:: v1.6 `TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in v1.8. @@ -59,6 +59,6 @@ def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer: break return optimizer # type: ignore [return-value] - self.strategy._cached_lightning_optimizers = { # type: ignore [assignment] + self.strategy._cached_lightning_optimizers = { idx: _convert_to_lightning_optimizer(opt) for idx, opt in enumerate(self.optimizers) } From 38656465b40cf0de84d24bce90bc0fdc72fcfec2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Jun 2022 16:37:43 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/trainer/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/optimizers.py b/src/pytorch_lightning/trainer/optimizers.py index 91254dfba986e..22df86ffae016 100644 --- a/src/pytorch_lightning/trainer/optimizers.py +++ b/src/pytorch_lightning/trainer/optimizers.py @@ -39,7 +39,7 @@ def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List pl_module = self.lightning_module or model return _init_optimizers_and_lr_schedulers(pl_module) - def convert_to_lightning_optimizers(self)->None: + def convert_to_lightning_optimizers(self) -> None: r""" .. deprecated:: v1.6 `TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in v1.8. From c379ec7ed014f51820bddf5afd35781df10bc9db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 30 Jun 2022 18:44:11 +0200 Subject: [PATCH 3/3] Use assert, the input type is correct --- src/pytorch_lightning/trainer/optimizers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/optimizers.py b/src/pytorch_lightning/trainer/optimizers.py index 22df86ffae016..8e25fb5ac60f7 100644 --- a/src/pytorch_lightning/trainer/optimizers.py +++ b/src/pytorch_lightning/trainer/optimizers.py @@ -28,7 +28,7 @@ class TrainerOptimizersMixin(ABC): The `TrainerOptimizersMixin` was deprecated in v1.6 and will be removed in v1.8. """ - def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]: + def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]: r""" .. deprecated:: v1.6 `TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8. @@ -37,6 +37,7 @@ def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List "`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8." ) pl_module = self.lightning_module or model + assert isinstance(pl_module, pl.LightningModule) return _init_optimizers_and_lr_schedulers(pl_module) def convert_to_lightning_optimizers(self) -> None: