From 3617f92b68035b09092ee4f4e9dc88381db35b8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 05:13:26 +0100 Subject: [PATCH 01/16] fix typing for lr_monitor --- pyproject.toml | 1 - pytorch_lightning/accelerators/accelerator.py | 4 +- pytorch_lightning/callbacks/lr_monitor.py | 2 +- pytorch_lightning/trainer/optimizers.py | 3 +- pytorch_lightning/trainer/trainer.py | 5 ++- pytorch_lightning/utilities/types.py | 44 +++++++++++++++++-- 6 files changed, 48 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f219d8f509d37..561add15b3d93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,6 @@ module = [ "pytorch_lightning.accelerators.accelerator", "pytorch_lightning.accelerators.gpu", "pytorch_lightning.callbacks.finetuning", - "pytorch_lightning.callbacks.lr_monitor", "pytorch_lightning.callbacks.model_checkpoint", "pytorch_lightning.callbacks.prediction_writer", "pytorch_lightning.callbacks.progress.base", diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index eb3886b209503..6d49b698a1fdc 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT, LR_SCHEDULER_CONFIG class Accelerator: @@ -63,7 +63,7 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] - self.lr_schedulers: List = [] + self.lr_schedulers: List[LR_SCHEDULER_CONFIG] = [] self.optimizer_frequencies: List = [] def setup_environment(self) -> None: diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index d72f42d8f8616..947909cdadfb1 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -86,7 +86,7 @@ def configure_optimizer(self): """ - def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False): + def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None: if logging_interval not in (None, "step", "epoch"): raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.") diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 9de9e83614f57..a91d98dd3ab2f 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -23,6 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LR_SCHEDULER_CONFIG class TrainerOptimizersMixin(ABC): @@ -123,7 +124,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer): @staticmethod def _configure_schedulers( schedulers: list, monitor: Optional[str], is_manual_optimization: bool - ) -> List[Dict[str, Any]]: + ) -> List[LR_SCHEDULER_CONFIG]: """Convert each scheduler into dict structure with relevant information.""" lr_schedulers = [] default_config = 
_get_default_scheduler_config()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b417d40484028..b70e23a54034b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -95,6 +95,7 @@
     EVAL_DATALOADERS,
     LRSchedulerTypeUnion,
     TRAIN_DATALOADERS,
+    LR_SCHEDULER_CONFIG,
 )
 from pytorch_lightning.utilities.warnings import PossibleUserWarning

@@ -1656,11 +1657,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
         self.accelerator.optimizers = new_optims

     @property
-    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
+    def lr_schedulers(self) -> List[LR_SCHEDULER_CONFIG]:
         return self.accelerator.lr_schedulers

     @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
+    def lr_schedulers(self, new_schedulers: List[LR_SCHEDULER_CONFIG]) -> None:
         self.accelerator.lr_schedulers = new_schedulers

     @property
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index 86c9028a5ff1f..0bd1b200aa78e 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -14,13 +14,26 @@
 """
 Convention:
  - Do not include any `_TYPE` suffix
- - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
+ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Sequence,
+    Type,
+    Union,
+    TypedDict,
+    Optional,
+    runtime_checkable,
+    Protocol,
+)

 import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import Metric

 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -43,7 +56,30 @@
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
-# todo: improve LRSchedulerType naming/typing
+
+
+@runtime_checkable
+class _LRScheduler(Protocol):
+    optimizer: Optimizer
+
+
+@runtime_checkable
+class ReduceLROnPlateau(Protocol):
+    optimizer: Optimizer
+
+
+# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau] LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]] + + +class LR_SCHEDULER_CONFIG(TypedDict): + scheduler: LRSchedulerTypeUnion + name: Optional[str] + interval: str + frequency: int + reduce_on_plateau: bool + monitor: Optional[str] + strict: bool + opt_idx: Optional[int] From 214e0efa6e3e3ba89f44fc3f7f76304a44a0008c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Nov 2021 04:15:56 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/utilities/types.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 6d49b698a1fdc..fe0f1addc43a6 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum -from pytorch_lightning.utilities.types import STEP_OUTPUT, LR_SCHEDULER_CONFIG +from pytorch_lightning.utilities.types import LR_SCHEDULER_CONFIG, STEP_OUTPUT class Accelerator: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b70e23a54034b..e2112a564c763 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -93,9 +93,9 @@ _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, + LR_SCHEDULER_CONFIG, LRSchedulerTypeUnion, TRAIN_DATALOADERS, - LR_SCHEDULER_CONFIG, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 0bd1b200aa78e..acec602228745 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,13 +23,13 @@ Iterator, List, Mapping, + Optional, + Protocol, + runtime_checkable, Sequence, Type, - Union, TypedDict, - Optional, - runtime_checkable, - Protocol, + Union, ) import torch From a04c719413ca0be7414858cc07efcabdb235f808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 05:26:34 +0100 Subject: [PATCH 03/16] python 3.6 support --- pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index acec602228745..34a648b3686fb 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -24,7 +24,6 @@ List, Mapping, Optional, - Protocol, runtime_checkable, Sequence, Type, @@ -36,6 +35,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric +from typing_extensions import Protocol _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] From 4abf8cae6fbfceb53821988878a0ad69ce4ab94a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Nov 2021 04:28:39 +0000 Subject: [PATCH 04/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- pytorch_lightning/utilities/types.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 34a648b3686fb..a82e3ed8abcfe 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,19 +17,7 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import ( - Any, - Dict, - Iterator, - List, - Mapping, - Optional, - runtime_checkable, - Sequence, - Type, - TypedDict, - Union, -) +from typing import Any, Dict, Iterator, List, Mapping, Optional, runtime_checkable, Sequence, Type, TypedDict, Union import torch from torch.optim import Optimizer From 72224303c2144b85fbc7e1c8fa1f1b3ff66758b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:53:36 +0100 Subject: [PATCH 05/16] copy configs from torch and rename --- .../training_type/training_type_plugin.py | 5 +- pytorch_lightning/trainer/optimizers.py | 4 +- pytorch_lightning/trainer/trainer.py | 6 +- pytorch_lightning/utilities/types.py | 57 +++++++++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 14e9e59f7dc01..7fb4e1e3c352e 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -19,7 +19,6 @@ from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -30,7 +29,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.types import _PATH +from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig TBroadcast = TypeVar("TBroadcast") @@ -47,7 +46,7 @@ def __init__( self._checkpoint_io = checkpoint_io self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() self.optimizers: List[Optimizer] = [] - self.lr_schedulers: List[_LRScheduler] = [] + self.lr_schedulers: List[LRSchedulerConfig] = [] self.optimizer_frequencies: List[int] = [] @property diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index a4929aee11908..8ee1017cf0a76 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import LR_SCHEDULER_CONFIG +from pytorch_lightning.utilities.types import LRSchedulerConfig class TrainerOptimizersMixin(ABC): @@ -124,7 +124,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer): @staticmethod def _configure_schedulers( schedulers: list, monitor: Optional[str], is_manual_optimization: bool - ) -> List[LR_SCHEDULER_CONFIG]: + ) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information.""" lr_schedulers = [] 
default_config = _get_default_scheduler_config() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f96d69ba1fb9..d015970366868 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -102,7 +102,7 @@ _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, - LR_SCHEDULER_CONFIG, + LRSchedulerConfig, LRSchedulerTypeUnion, TRAIN_DATALOADERS, ) @@ -1721,11 +1721,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: self.training_type_plugin.optimizers = new_optims @property - def lr_schedulers(self) -> List[LR_SCHEDULER_CONFIG]: + def lr_schedulers(self) -> List[LRSchedulerConfig]: return self.training_type_plugin.lr_schedulers @lr_schedulers.setter - def lr_schedulers(self, new_schedulers: List[LR_SCHEDULER_CONFIG]) -> None: + def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None: self.training_type_plugin.lr_schedulers = new_schedulers @property diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index a82e3ed8abcfe..4c376b05618d0 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -46,23 +46,68 @@ EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -@runtime_checkable -class _LRScheduler(Protocol): +# Copied from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +class _LRScheduler: optimizer: Optimizer + def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: + ... -@runtime_checkable -class ReduceLROnPlateau(Protocol): + def state_dict(self) -> dict: + ... + + def load_state_dict(self, state_dict: dict) -> None: + ... + + def get_last_lr(self) -> List[float]: + ... + + def get_lr(self) -> float: + ... + + def step(self, epoch: Optional[int] = ...) -> None: + ... + + +# Copied from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +class ReduceLROnPlateau: + in_cooldown: bool optimizer: Optimizer + def __init__( + self, + optimizer: Optimizer, + mode: str = ..., + factor: float = ..., + patience: int = ..., + verbose: bool = ..., + threshold: float = ..., + threshold_mode: str = ..., + cooldown: int = ..., + min_lr: float = ..., + eps: float = ..., + ) -> None: + ... + + def step(self, metrics: Any, epoch: Optional[int] = ...) -> None: + ... + + def state_dict(self) -> dict: + ... + + def load_state_dict(self, state_dict: dict): + ... + -# todo: improve LRSchedulerType naming/typing ??? +# todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau] LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]] -class LR_SCHEDULER_CONFIG(TypedDict): +class LRSchedulerConfig(TypedDict): scheduler: LRSchedulerTypeUnion name: Optional[str] interval: str From e4775a26aceadef46912e09df59ab93ede83baff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:54:59 +0100 Subject: [PATCH 06/16] add type --- pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 4c376b05618d0..0d78737d2ae94 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -97,7 +97,7 @@ def step(self, metrics: Any, epoch: Optional[int] = ...) -> None: def state_dict(self) -> dict: ... 
- def load_state_dict(self, state_dict: dict): + def load_state_dict(self, state_dict: dict) -> None: ... From 7b59eb48da6d79af1e4fa63c15b685177fa249bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:55:54 +0100 Subject: [PATCH 07/16] remove notebooks --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index a2fb6468112b7..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From 89fed768f97966f8dd98aced7e09d21b2d801f7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:56:11 +0100 Subject: [PATCH 08/16] reset _notebooks --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0c325829101d5 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59 From e83e2f8eb5347eb257e6e301fffa414bfbef674a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:00:00 +0100 Subject: [PATCH 09/16] unused imports --- pytorch_lightning/trainer/trainer.py | 1 - pytorch_lightning/utilities/types.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d015970366868..2a0e77f770edf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -103,7 +103,6 @@ _PREDICT_OUTPUT, EVAL_DATALOADERS, LRSchedulerConfig, - LRSchedulerTypeUnion, TRAIN_DATALOADERS, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 0d78737d2ae94..1abc9b2ed4a00 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,13 +17,12 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Optional, runtime_checkable, Sequence, Type, TypedDict, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, TypedDict, Union import torch from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric -from typing_extensions import Protocol _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] From d71017defe481b6458a4fc055b83028dc99ebb7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:08:17 +0100 Subject: [PATCH 10/16] fix import --- pytorch_lightning/utilities/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 1abc9b2ed4a00..e1af694b34389 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,12 +17,13 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, TypedDict, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union import torch from torch.optim import Optimizer from torch.utils.data import DataLoader from 
torchmetrics import Metric +from typing_extensions import TypedDict _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] From 33544ef366bdca7f8dc1d2006c719c550f44529b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:26:25 +0100 Subject: [PATCH 11/16] try to fix types for cli --- pytorch_lightning/utilities/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index e1af694b34389..899314763ecbe 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -21,6 +21,7 @@ import torch from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as _TorchLRScheduler, ReduceLROnPlateau as TorchReduceLROnPlateau from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict @@ -48,7 +49,7 @@ # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class _LRScheduler: +class _LRScheduler(_TorchLRScheduler): optimizer: Optimizer def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: @@ -72,7 +73,7 @@ def step(self, epoch: Optional[int] = ...) -> None: # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau: +class ReduceLROnPlateau(TorchReduceLROnPlateau): in_cooldown: bool optimizer: Optimizer From 5b4af6b8c783305571e778596e151030d1dfd3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:26:25 +0100 Subject: [PATCH 12/16] Revert "try to fix types for cli" This reverts commit 33544ef366bdca7f8dc1d2006c719c550f44529b. --- pytorch_lightning/utilities/types.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 899314763ecbe..e1af694b34389 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -21,7 +21,6 @@ import torch from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as _TorchLRScheduler, ReduceLROnPlateau as TorchReduceLROnPlateau from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict @@ -49,7 +48,7 @@ # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class _LRScheduler(_TorchLRScheduler): +class _LRScheduler: optimizer: Optimizer def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: @@ -73,7 +72,7 @@ def step(self, epoch: Optional[int] = ...) 
-> None: # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau(TorchReduceLROnPlateau): +class ReduceLROnPlateau: in_cooldown: bool optimizer: Optimizer From 52b4d93b8fda0d6191206f6fa209ddf6b8761fa3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Dec 2021 02:33:49 +0100 Subject: [PATCH 13/16] Fix CLI --- pytorch_lightning/utilities/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index e1af694b34389..904fbaec13a0f 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -102,9 +102,9 @@ def load_state_dict(self, state_dict: dict) -> None: # todo: improve LRSchedulerType naming/typing -LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]] +LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] class LRSchedulerConfig(TypedDict): From 73c4fd1eb0a61fc4591639f1292d3701292b1675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 03:31:37 +0100 Subject: [PATCH 14/16] update types --- pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 904fbaec13a0f..44a3b88d530d6 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -108,7 +108,7 @@ def load_state_dict(self, state_dict: dict) -> None: class LRSchedulerConfig(TypedDict): - scheduler: LRSchedulerTypeUnion + scheduler: Union[_LRScheduler, ReduceLROnPlateau] name: Optional[str] interval: str frequency: int From 732e3be4e9a30b02a113b8ae62185059b2a1dc65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 03:31:46 +0100 Subject: [PATCH 15/16] assertions --- pytorch_lightning/callbacks/lr_monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index e69023a0e61f0..9d4e70fb880ce 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -146,6 +146,7 @@ def _check_no_key(key: str) -> bool: self.last_momentum_values = {name + "-momentum": None for name in names_flatten} def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + assert trainer.logger is not None if not trainer.logger_connector.should_update_logs: return @@ -157,6 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) trainer.logger.log_metrics(latest_stat, step=trainer.global_step) def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + assert trainer.logger is not None if self.logging_interval != "step": interval = "epoch" if self.logging_interval is None else "any" latest_stat = self._extract_stats(trainer, interval) From 17df0dcc188ac31995e7632c34abcc5fc9ccbcb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: 
Tue, 21 Dec 2021 02:33:02 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index eaca4b53d3a13..3a9a9aad7c8e2 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT, LRSchedulerConfig +from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT TBroadcast = TypeVar("TBroadcast")
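
A note on the `runtime_checkable` protocols introduced in patches 01-03 (and replaced by the copied stubs in patch 05): a runtime-checkable structural type lets `isinstance` accept torch's schedulers without subclassing them. Below is a minimal, self-contained sketch of that mechanism; it is not part of the series, and the protocol name `_SupportsOptimizer` is made up for illustration.

    import torch
    from torch.optim import Optimizer
    from typing_extensions import Protocol, runtime_checkable


    @runtime_checkable
    class _SupportsOptimizer(Protocol):
        # Structural requirement: any object carrying an `optimizer` attribute matches.
        optimizer: Optimizer


    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

    # Both torch scheduler families expose `.optimizer`, so both pass the
    # structural isinstance check without inheriting from the protocol.
    assert isinstance(torch.optim.lr_scheduler.StepLR(opt, step_size=1), _SupportsOptimizer)
    assert isinstance(torch.optim.lr_scheduler.ReduceLROnPlateau(opt), _SupportsOptimizer)

One limitation, and a plausible reason the series moves to copied class stubs instead: `runtime_checkable` only verifies that the named members exist, not their signatures, so it cannot type `step` or `state_dict` precisely.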
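
On the shape the series converges to in patch 14: a minimal sketch of populating the final `LRSchedulerConfig` TypedDict. Only the keys and value types come from the patch; the optimizer/scheduler setup and the inline comments are illustrative assumptions.

    from typing import Optional, Union

    import torch
    from typing_extensions import TypedDict


    class LRSchedulerConfig(TypedDict):
        # Mirrors the TypedDict as it stands after patch 14, with torch's own
        # scheduler classes standing in for the copied stubs in types.py.
        scheduler: Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
        name: Optional[str]
        interval: str
        frequency: int
        reduce_on_plateau: bool
        monitor: Optional[str]
        strict: bool
        opt_idx: Optional[int]


    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

    config: LRSchedulerConfig = {
        "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
        "name": None,                # optional display name, e.g. for LearningRateMonitor
        "interval": "epoch",         # "epoch" or "step"
        "frequency": 1,
        "reduce_on_plateau": False,  # True only for ReduceLROnPlateau, which also needs `monitor`
        "monitor": None,
        "strict": True,
        "opt_idx": None,
    }

    # TypedDict constrains the static checker only; at runtime this is a plain dict.
    assert isinstance(config, dict)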
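
Finally, on the bare `assert trainer.logger is not None` lines added in patch 15: they exist for type narrowing under mypy rather than for runtime validation. A standalone illustration of the effect, with a stand-in `Logger` class rather than Lightning's:

    from typing import Dict, Optional


    class Logger:
        def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
            print(step, metrics)


    def log_stats(logger: Optional[Logger], stats: Dict[str, float]) -> None:
        # Without this assert, mypy rejects `logger.log_metrics` because
        # `logger` may be None; the assert narrows Optional[Logger] to Logger.
        assert logger is not None
        logger.log_metrics(stats, step=0)


    log_stats(Logger(), {"lr-SGD": 0.1})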