From 3617f92b68035b09092ee4f4e9dc88381db35b8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 05:13:26 +0100 Subject: [PATCH 01/17] fix typing for lr_monitor --- pyproject.toml | 1 - pytorch_lightning/accelerators/accelerator.py | 4 +- pytorch_lightning/callbacks/lr_monitor.py | 2 +- pytorch_lightning/trainer/optimizers.py | 3 +- pytorch_lightning/trainer/trainer.py | 5 ++- pytorch_lightning/utilities/types.py | 44 +++++++++++++++++-- 6 files changed, 48 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f219d8f509d37..561add15b3d93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,6 @@ module = [ "pytorch_lightning.accelerators.accelerator", "pytorch_lightning.accelerators.gpu", "pytorch_lightning.callbacks.finetuning", - "pytorch_lightning.callbacks.lr_monitor", "pytorch_lightning.callbacks.model_checkpoint", "pytorch_lightning.callbacks.prediction_writer", "pytorch_lightning.callbacks.progress.base", diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index eb3886b209503..6d49b698a1fdc 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT, LR_SCHEDULER_CONFIG class Accelerator: @@ -63,7 +63,7 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] - self.lr_schedulers: List = [] + self.lr_schedulers: List[LR_SCHEDULER_CONFIG] = [] self.optimizer_frequencies: List = [] def setup_environment(self) -> None: diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index d72f42d8f8616..947909cdadfb1 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -86,7 +86,7 @@ def configure_optimizer(self): """ - def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False): + def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None: if logging_interval not in (None, "step", "epoch"): raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.") diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 9de9e83614f57..a91d98dd3ab2f 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -23,6 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LR_SCHEDULER_CONFIG class TrainerOptimizersMixin(ABC): @@ -123,7 +124,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer): @staticmethod def _configure_schedulers( schedulers: list, monitor: Optional[str], is_manual_optimization: bool - ) -> List[Dict[str, Any]]: + ) -> List[LR_SCHEDULER_CONFIG]: """Convert each scheduler into dict structure with relevant information.""" lr_schedulers = [] default_config = 
_get_default_scheduler_config() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b417d40484028..b70e23a54034b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -95,6 +95,7 @@ EVAL_DATALOADERS, LRSchedulerTypeUnion, TRAIN_DATALOADERS, + LR_SCHEDULER_CONFIG, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning @@ -1656,11 +1657,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: self.accelerator.optimizers = new_optims @property - def lr_schedulers(self) -> List[LRSchedulerTypeUnion]: + def lr_schedulers(self) -> List[LR_SCHEDULER_CONFIG]: return self.accelerator.lr_schedulers @lr_schedulers.setter - def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None: + def lr_schedulers(self, new_schedulers: List[LR_SCHEDULER_CONFIG]) -> None: self.accelerator.lr_schedulers = new_schedulers @property diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 86c9028a5ff1f..0bd1b200aa78e 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -14,13 +14,26 @@ """ Convention: - Do not include any `_TYPE` suffix - - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) + - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union +from typing import ( + Any, + Dict, + Iterator, + List, + Mapping, + Sequence, + Type, + Union, + TypedDict, + Optional, + runtime_checkable, + Protocol, +) import torch -from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau +from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric @@ -43,7 +56,30 @@ Dict[str, Sequence[DataLoader]], ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -# todo: improve LRSchedulerType naming/typing + + +@runtime_checkable +class _LRScheduler(Protocol): + optimizer: Optimizer + + +@runtime_checkable +class ReduceLROnPlateau(Protocol): + optimizer: Optimizer + + +# todo: improve LRSchedulerType naming/typing ??? 
LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau] LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]] + + +class LR_SCHEDULER_CONFIG(TypedDict): + scheduler: LRSchedulerTypeUnion + name: Optional[str] + interval: str + frequency: int + reduce_on_plateau: bool + monitor: Optional[str] + strict: bool + opt_idx: Optional[int] From 214e0efa6e3e3ba89f44fc3f7f76304a44a0008c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Nov 2021 04:15:56 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/utilities/types.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 6d49b698a1fdc..fe0f1addc43a6 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum -from pytorch_lightning.utilities.types import STEP_OUTPUT, LR_SCHEDULER_CONFIG +from pytorch_lightning.utilities.types import LR_SCHEDULER_CONFIG, STEP_OUTPUT class Accelerator: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b70e23a54034b..e2112a564c763 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -93,9 +93,9 @@ _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, + LR_SCHEDULER_CONFIG, LRSchedulerTypeUnion, TRAIN_DATALOADERS, - LR_SCHEDULER_CONFIG, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 0bd1b200aa78e..acec602228745 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,13 +23,13 @@ Iterator, List, Mapping, + Optional, + Protocol, + runtime_checkable, Sequence, Type, - Union, TypedDict, - Optional, - runtime_checkable, - Protocol, + Union, ) import torch From a04c719413ca0be7414858cc07efcabdb235f808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 05:26:34 +0100 Subject: [PATCH 03/17] python 3.6 support --- pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index acec602228745..34a648b3686fb 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -24,7 +24,6 @@ List, Mapping, Optional, - Protocol, runtime_checkable, Sequence, Type, @@ -36,6 +35,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric +from typing_extensions import Protocol _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] From 4abf8cae6fbfceb53821988878a0ad69ce4ab94a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Nov 2021 04:28:39 +0000 Subject: [PATCH 04/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- pytorch_lightning/utilities/types.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 34a648b3686fb..a82e3ed8abcfe 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,19 +17,7 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import ( - Any, - Dict, - Iterator, - List, - Mapping, - Optional, - runtime_checkable, - Sequence, - Type, - TypedDict, - Union, -) +from typing import Any, Dict, Iterator, List, Mapping, Optional, runtime_checkable, Sequence, Type, TypedDict, Union import torch from torch.optim import Optimizer From 72224303c2144b85fbc7e1c8fa1f1b3ff66758b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:53:36 +0100 Subject: [PATCH 05/17] copy configs from torch and rename --- .../training_type/training_type_plugin.py | 5 +- pytorch_lightning/trainer/optimizers.py | 4 +- pytorch_lightning/trainer/trainer.py | 6 +- pytorch_lightning/utilities/types.py | 57 +++++++++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 14e9e59f7dc01..7fb4e1e3c352e 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -19,7 +19,6 @@ from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -30,7 +29,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.types import _PATH +from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig TBroadcast = TypeVar("TBroadcast") @@ -47,7 +46,7 @@ def __init__( self._checkpoint_io = checkpoint_io self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() self.optimizers: List[Optimizer] = [] - self.lr_schedulers: List[_LRScheduler] = [] + self.lr_schedulers: List[LRSchedulerConfig] = [] self.optimizer_frequencies: List[int] = [] @property diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index a4929aee11908..8ee1017cf0a76 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import LR_SCHEDULER_CONFIG +from pytorch_lightning.utilities.types import LRSchedulerConfig class TrainerOptimizersMixin(ABC): @@ -124,7 +124,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer): @staticmethod def _configure_schedulers( schedulers: list, monitor: Optional[str], is_manual_optimization: bool - ) -> List[LR_SCHEDULER_CONFIG]: + ) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information.""" lr_schedulers = [] 
default_config = _get_default_scheduler_config() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f96d69ba1fb9..d015970366868 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -102,7 +102,7 @@ _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, - LR_SCHEDULER_CONFIG, + LRSchedulerConfig, LRSchedulerTypeUnion, TRAIN_DATALOADERS, ) @@ -1721,11 +1721,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: self.training_type_plugin.optimizers = new_optims @property - def lr_schedulers(self) -> List[LR_SCHEDULER_CONFIG]: + def lr_schedulers(self) -> List[LRSchedulerConfig]: return self.training_type_plugin.lr_schedulers @lr_schedulers.setter - def lr_schedulers(self, new_schedulers: List[LR_SCHEDULER_CONFIG]) -> None: + def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None: self.training_type_plugin.lr_schedulers = new_schedulers @property diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index a82e3ed8abcfe..4c376b05618d0 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -46,23 +46,68 @@ EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -@runtime_checkable -class _LRScheduler(Protocol): +# Copied from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +class _LRScheduler: optimizer: Optimizer + def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: + ... -@runtime_checkable -class ReduceLROnPlateau(Protocol): + def state_dict(self) -> dict: + ... + + def load_state_dict(self, state_dict: dict) -> None: + ... + + def get_last_lr(self) -> List[float]: + ... + + def get_lr(self) -> float: + ... + + def step(self, epoch: Optional[int] = ...) -> None: + ... + + +# Copied from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +class ReduceLROnPlateau: + in_cooldown: bool optimizer: Optimizer + def __init__( + self, + optimizer: Optimizer, + mode: str = ..., + factor: float = ..., + patience: int = ..., + verbose: bool = ..., + threshold: float = ..., + threshold_mode: str = ..., + cooldown: int = ..., + min_lr: float = ..., + eps: float = ..., + ) -> None: + ... + + def step(self, metrics: Any, epoch: Optional[int] = ...) -> None: + ... + + def state_dict(self) -> dict: + ... + + def load_state_dict(self, state_dict: dict): + ... + -# todo: improve LRSchedulerType naming/typing ??? +# todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau] LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]] -class LR_SCHEDULER_CONFIG(TypedDict): +class LRSchedulerConfig(TypedDict): scheduler: LRSchedulerTypeUnion name: Optional[str] interval: str From e4775a26aceadef46912e09df59ab93ede83baff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:54:59 +0100 Subject: [PATCH 06/17] add type --- pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 4c376b05618d0..0d78737d2ae94 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -97,7 +97,7 @@ def step(self, metrics: Any, epoch: Optional[int] = ...) -> None: def state_dict(self) -> dict: ... 
- def load_state_dict(self, state_dict: dict): + def load_state_dict(self, state_dict: dict) -> None: ... From 7b59eb48da6d79af1e4fa63c15b685177fa249bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:55:54 +0100 Subject: [PATCH 07/17] remove notebooks --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index a2fb6468112b7..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From 89fed768f97966f8dd98aced7e09d21b2d801f7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 07:56:11 +0100 Subject: [PATCH 08/17] reset _notebooks --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0c325829101d5 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59 From e83e2f8eb5347eb257e6e301fffa414bfbef674a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:00:00 +0100 Subject: [PATCH 09/17] unused imports --- pytorch_lightning/trainer/trainer.py | 1 - pytorch_lightning/utilities/types.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d015970366868..2a0e77f770edf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -103,7 +103,6 @@ _PREDICT_OUTPUT, EVAL_DATALOADERS, LRSchedulerConfig, - LRSchedulerTypeUnion, TRAIN_DATALOADERS, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 0d78737d2ae94..1abc9b2ed4a00 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,13 +17,12 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Optional, runtime_checkable, Sequence, Type, TypedDict, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, TypedDict, Union import torch from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric -from typing_extensions import Protocol _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] From d71017defe481b6458a4fc055b83028dc99ebb7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:08:17 +0100 Subject: [PATCH 10/17] fix import --- pytorch_lightning/utilities/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 1abc9b2ed4a00..e1af694b34389 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,12 +17,13 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, TypedDict, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union import torch from torch.optim import Optimizer from torch.utils.data import DataLoader from 
torchmetrics import Metric +from typing_extensions import TypedDict _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] From 33544ef366bdca7f8dc1d2006c719c550f44529b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:26:25 +0100 Subject: [PATCH 11/17] try to fix types for cli --- pytorch_lightning/utilities/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index e1af694b34389..899314763ecbe 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -21,6 +21,7 @@ import torch from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as _TorchLRScheduler, ReduceLROnPlateau as TorchReduceLROnPlateau from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict @@ -48,7 +49,7 @@ # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class _LRScheduler: +class _LRScheduler(_TorchLRScheduler): optimizer: Optimizer def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: @@ -72,7 +73,7 @@ def step(self, epoch: Optional[int] = ...) -> None: # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau: +class ReduceLROnPlateau(TorchReduceLROnPlateau): in_cooldown: bool optimizer: Optimizer From 5b4af6b8c783305571e778596e151030d1dfd3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Dec 2021 08:26:25 +0100 Subject: [PATCH 12/17] Revert "try to fix types for cli" This reverts commit 33544ef366bdca7f8dc1d2006c719c550f44529b. --- pytorch_lightning/utilities/types.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 899314763ecbe..e1af694b34389 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -21,7 +21,6 @@ import torch from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as _TorchLRScheduler, ReduceLROnPlateau as TorchReduceLROnPlateau from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict @@ -49,7 +48,7 @@ # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class _LRScheduler(_TorchLRScheduler): +class _LRScheduler: optimizer: Optimizer def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: @@ -73,7 +72,7 @@ def step(self, epoch: Optional[int] = ...) 
-> None: # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau(TorchReduceLROnPlateau): +class ReduceLROnPlateau: in_cooldown: bool optimizer: Optimizer From 52b4d93b8fda0d6191206f6fa209ddf6b8761fa3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 6 Dec 2021 02:33:49 +0100 Subject: [PATCH 13/17] Fix CLI --- pytorch_lightning/utilities/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index e1af694b34389..904fbaec13a0f 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -102,9 +102,9 @@ def load_state_dict(self, state_dict: dict) -> None: # todo: improve LRSchedulerType naming/typing -LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]] +LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] class LRSchedulerConfig(TypedDict): From 73c4fd1eb0a61fc4591639f1292d3701292b1675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 03:31:37 +0100 Subject: [PATCH 14/17] update types --- pytorch_lightning/utilities/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 904fbaec13a0f..44a3b88d530d6 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -108,7 +108,7 @@ def load_state_dict(self, state_dict: dict) -> None: class LRSchedulerConfig(TypedDict): - scheduler: LRSchedulerTypeUnion + scheduler: Union[_LRScheduler, ReduceLROnPlateau] name: Optional[str] interval: str frequency: int From 732e3be4e9a30b02a113b8ae62185059b2a1dc65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 03:31:46 +0100 Subject: [PATCH 15/17] assertions --- pytorch_lightning/callbacks/lr_monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index e69023a0e61f0..9d4e70fb880ce 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -146,6 +146,7 @@ def _check_no_key(key: str) -> bool: self.last_momentum_values = {name + "-momentum": None for name in names_flatten} def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + assert trainer.logger is not None if not trainer.logger_connector.should_update_logs: return @@ -157,6 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) trainer.logger.log_metrics(latest_stat, step=trainer.global_step) def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + assert trainer.logger is not None if self.logging_interval != "step": interval = "epoch" if self.logging_interval is None else "any" latest_stat = self._extract_stats(trainer, interval) From 17df0dcc188ac31995e7632c34abcc5fc9ccbcb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: 
Tue, 21 Dec 2021 02:33:02 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index eaca4b53d3a13..3a9a9aad7c8e2 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT, LRSchedulerConfig +from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT TBroadcast = TypeVar("TBroadcast") From e3b65b56f98f98c8a702f27570fb25f6431226fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 22 Dec 2021 00:11:02 +0100 Subject: [PATCH 17/17] Make scheduler config a dataclass --- pytorch_lightning/callbacks/lr_monitor.py | 6 ++--- .../callbacks/stochastic_weight_avg.py | 11 ++++---- .../plugins/training_type/deepspeed.py | 11 ++++---- pytorch_lightning/trainer/optimizers.py | 27 +++++-------------- pytorch_lightning/tuner/lr_finder.py | 6 ++--- pytorch_lightning/utilities/types.py | 18 +++++++------ 6 files changed, 33 insertions(+), 46 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 9d4e70fb880ce..08a5a4d6da9d6 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -112,7 +112,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No def _check_no_key(key: str) -> bool: if trainer.lr_schedulers: - return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers) + return any(key not in sch.scheduler.optimizer.defaults for sch in trainer.lr_schedulers) return any(key not in optimizer.defaults for optimizer in trainer.optimizers) @@ -177,8 +177,8 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa self._remap_keys(scheduler_hparam_keys) for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers): - if interval in [scheduler["interval"], "any"]: - opt = scheduler["scheduler"].optimizer + if interval in [scheduler.interval, "any"]: + opt = scheduler.scheduler.optimizer current_stat = self._get_lr_momentum_stat(opt, name) latest_stat.update(current_stat) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index bde9c1b5c2407..8d913a2469bce 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,9 +24,9 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -182,16 +182,15 
@@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, ) - default_scheduler_cfg = _get_default_scheduler_config() - assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1 - default_scheduler_cfg["scheduler"] = self._swa_scheduler + default_scheduler_cfg = LRSchedulerConfig(scheduler=self._swa_scheduler) + assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 if trainer.lr_schedulers: scheduler_cfg = trainer.lr_schedulers[0] - if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1: + if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") rank_zero_info( - f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`" + f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" f" for `{self._swa_scheduler.__class__.__name__}`" ) trainer.lr_schedulers[0] = default_scheduler_cfg diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index c442cee3823ec..78be8824b371f 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -32,7 +32,6 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -42,7 +41,7 @@ from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT +from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT, LRSchedulerConfig from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache warning_cache = WarningCache() @@ -456,7 +455,7 @@ def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTup ) return ( optimizers[0], - schedulers[0] if schedulers else _get_default_scheduler_config(), + schedulers[0] if schedulers else LRSchedulerConfig(scheduler=None), # TODO: fix type optimizer_frequencies[0] if optimizer_frequencies else None, ) @@ -466,7 +465,7 @@ def zero_stage_3(self) -> bool: def _initialize_deepspeed_train(self, model): if "optimizer" in self.config: - optimizer, lr_scheduler = None, _get_default_scheduler_config() + optimizer, lr_scheduler = None, LRSchedulerConfig() else: rank_zero_info( "You have not specified an optimizer or scheduler within the DeepSpeed config." 
@@ -474,7 +473,7 @@ def _initialize_deepspeed_train(self, model): ) optimizer, lr_scheduler, _ = self._init_optimizers() - scheduler = lr_scheduler["scheduler"] + scheduler = lr_scheduler.scheduler model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler) self._set_deepspeed_activation_checkpointing() @@ -485,7 +484,7 @@ def _initialize_deepspeed_train(self, model): if deepspeed_scheduler is not None: # disable deepspeed lr scheduling as lightning manages scheduling model.lr_scheduler = None - lr_scheduler["scheduler"] = deepspeed_scheduler + lr_scheduler.scheduler = deepspeed_scheduler self.lightning_module.trainer.lr_schedulers = [lr_scheduler] self.model = model diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 6635dd8bb24e6..ef91fcca38cf5 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC +from dataclasses import fields from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -126,7 +127,6 @@ def _configure_schedulers( ) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information.""" lr_schedulers = [] - default_config = _get_default_scheduler_config() for scheduler in schedulers: if is_manual_optimization: if isinstance(scheduler, dict): @@ -141,13 +141,13 @@ def _configure_schedulers( ) scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys} - lr_schedulers.append({**default_config, **scheduler}) + lr_schedulers.append(LRSchedulerConfig(**scheduler)) else: - lr_schedulers.append({**default_config, "scheduler": scheduler}) + lr_schedulers.append(LRSchedulerConfig(scheduler=scheduler)) else: if isinstance(scheduler, dict): # check provided keys - extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] + extra_keys = [k for k in scheduler.keys() if k not in fields(LRSchedulerConfig)] if extra_keys: rank_zero_warn( f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning @@ -177,7 +177,7 @@ def _configure_schedulers( " Are you sure you didn't mean 'interval': 'step'?", category=RuntimeWarning, ) - lr_schedulers.append({**default_config, **scheduler}) + lr_schedulers.append(LRSchedulerConfig(**scheduler)) elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): if monitor is None: raise MisconfigurationException( @@ -186,10 +186,10 @@ def _configure_schedulers( ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' ) lr_schedulers.append( - {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} + LRSchedulerConfig(scheduler=scheduler, reduce_on_plateau=True, monitor=monitor) ) elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): - lr_schedulers.append({**default_config, "scheduler": scheduler}) + lr_schedulers.append(LRSchedulerConfig(scheduler=scheduler)) else: raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') return lr_schedulers @@ -236,16 +236,3 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers): raise MisconfigurationException( "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." 
) - - -def _get_default_scheduler_config() -> Dict[str, Any]: - return { - "scheduler": None, - "name": None, # no custom name - "interval": "epoch", # after epoch is over - "frequency": 1, # every epoch/batch - "reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler - "monitor": None, # value to monitor for ReduceLROnPlateau - "strict": True, # enforce that the monitor exists for ReduceLROnPlateau - "opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified - } diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index ba4c737ed049a..bcaefc110d2b6 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -25,7 +25,6 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers.base import DummyLogger -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -33,6 +32,8 @@ # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed +from pytorch_lightning.utilities.types import LRSchedulerConfig + if importlib.util.find_spec("ipywidgets") is not None: from tqdm.auto import tqdm else: @@ -123,8 +124,7 @@ def func(model): args = (optimizer, self.lr_max, self.num_training) scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) - sched_config = _get_default_scheduler_config() - sched_config.update({"scheduler": scheduler, "interval": "step"}) + sched_config = LRSchedulerConfig(scheduler=scheduler, interval="step") return [optimizer], [sched_config], [] diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 44a3b88d530d6..292e3e3fe80a1 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -16,6 +16,7 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union @@ -107,12 +108,13 @@ def load_state_dict(self, state_dict: dict) -> None: LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] -class LRSchedulerConfig(TypedDict): +@dataclass +class LRSchedulerConfig: scheduler: Union[_LRScheduler, ReduceLROnPlateau] - name: Optional[str] - interval: str - frequency: int - reduce_on_plateau: bool - monitor: Optional[str] - strict: bool - opt_idx: Optional[int] + name: Optional[str] = None # no custom name + interval: str = "epoch" # after epoch is over + frequency: int = 1 # every epoch/batch + reduce_on_plateau: bool = False # most often not ReduceLROnPlateau scheduler + monitor: Optional[str] = None # value to monitor for ReduceLROnPlateau + strict: bool = True # enforce that the monitor exists for ReduceLROnPlateau + opt_idx: Optional[int] = None # necessary to store opt_idx when optimizer frequencies are specified
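
Taken together, the series replaces the untyped scheduler-config dictionaries previously built by `_get_default_scheduler_config()` with a typed structure: first a `TypedDict` (`LR_SCHEDULER_CONFIG`), and in the final commit a `LRSchedulerConfig` dataclass whose defaults live on the type itself and whose entries are read via attribute access (`cfg.interval` instead of `cfg["interval"]`). The sketch below is a minimal, self-contained illustration of that shape, not the library code itself; the `to_config` helper is hypothetical and only mirrors, in simplified form, the dict handling done in `_configure_schedulers`.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR


@dataclass
class LRSchedulerConfig:
    # Mirrors the defaults previously returned by `_get_default_scheduler_config()`.
    scheduler: Any
    name: Optional[str] = None        # no custom name
    interval: str = "epoch"           # step the scheduler after each epoch
    frequency: int = 1                # every epoch/batch
    reduce_on_plateau: bool = False   # True only for ReduceLROnPlateau-style schedulers
    monitor: Optional[str] = None     # metric to monitor for ReduceLROnPlateau
    strict: bool = True               # enforce that the monitored metric exists
    opt_idx: Optional[int] = None     # set when optimizer frequencies are specified


def to_config(scheduler_or_dict: Any) -> LRSchedulerConfig:
    """Hypothetical helper: convert a `configure_optimizers`-style entry into a config."""
    if isinstance(scheduler_or_dict, dict):
        allowed = {f.name for f in fields(LRSchedulerConfig)}
        extra = set(scheduler_or_dict) - allowed
        if extra:
            raise ValueError(f"Unsupported lr scheduler keys: {sorted(extra)}")
        return LRSchedulerConfig(**scheduler_or_dict)
    return LRSchedulerConfig(scheduler=scheduler_or_dict)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=0.1)

    # Dict form, as a user might return it from `configure_optimizers`.
    cfg = to_config({"scheduler": StepLR(optimizer, step_size=10), "interval": "step"})

    # Attribute access replaces the previous `cfg["interval"]`-style lookups.
    print(cfg.interval, cfg.frequency, type(cfg.scheduler).__name__)
```

One detail worth noting when validating user-supplied dicts against the dataclass: `dataclasses.fields()` returns `Field` objects rather than strings, so key checks should compare against the field *names* (as in the sketch above) rather than against `fields(LRSchedulerConfig)` directly.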