Commit b6dd1a3

awaelchli and carmocca authored
Fix typing in pl.callbacks.lr_monitor (#10802)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent ba8e7cd commit b6dd1a3

6 files changed (+85 / -15 lines)


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.accelerators.gpu",
     "pytorch_lightning.callbacks.finetuning",
-    "pytorch_lightning.callbacks.lr_monitor",
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.base",
     "pytorch_lightning.callbacks.progress.progress",

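Removing `pytorch_lightning.callbacks.lr_monitor` from this list (presumably the per-module mypy override that suppresses errors for not-yet-typed modules) means the module is now type-checked along with the rest of the codebase.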
pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,7 @@ def configure_optimizer(self):
 
     """
 
-    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
+    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
         if logging_interval not in (None, "step", "epoch"):
             raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")
 
@@ -146,6 +146,7 @@ def _check_no_key(key: str) -> bool:
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
 
     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        assert trainer.logger is not None
         if not trainer.logger_connector.should_update_logs:
             return
 
@@ -157,6 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
             trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        assert trainer.logger is not None
         if self.logging_interval != "step":
             interval = "epoch" if self.logging_interval is None else "any"
             latest_stat = self._extract_stats(trainer, interval)
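The two added `assert trainer.logger is not None` lines exist for the type checker: `Trainer.logger` is an Optional attribute, and an assert is the usual way to narrow it before calling `.log_metrics(...)`. A minimal, self-contained sketch of the pattern follows; the `Logger` class and `log_lr` function are hypothetical stand-ins, not Lightning APIs.

from typing import Optional


class Logger:
    def log_metrics(self, metrics: dict) -> None:
        print(metrics)


def log_lr(logger: Optional[Logger]) -> None:
    # Without the assert, mypy reports that item "None" of "Optional[Logger]"
    # has no attribute "log_metrics".
    assert logger is not None  # narrows Optional[Logger] to Logger
    logger.log_metrics({"lr-SGD": 0.01})


log_lr(Logger())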

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,6 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
@@ -32,7 +31,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT
 
 TBroadcast = TypeVar("TBroadcast")
 
@@ -52,7 +51,7 @@ def __init__(
         self.checkpoint_io = checkpoint_io
         self.precision_plugin = precision_plugin
         self.optimizers: List[Optimizer] = []
-        self.lr_schedulers: List[_LRScheduler] = []
+        self.lr_schedulers: List[LRSchedulerConfig] = []
         self.optimizer_frequencies: List[int] = []
         if is_overridden("post_dispatch", self, parent=Strategy):
             rank_zero_deprecation(

pytorch_lightning/trainer/optimizers.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import LRSchedulerConfig
 
 
 class TrainerOptimizersMixin(ABC):
@@ -122,7 +123,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
     @staticmethod
     def _configure_schedulers(
         schedulers: list, monitor: Optional[str], is_manual_optimization: bool
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LRSchedulerConfig]:
         """Convert each scheduler into dict structure with relevant information."""
         lr_schedulers = []
         default_config = _get_default_scheduler_config()

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -107,7 +107,7 @@
     _PATH,
     _PREDICT_OUTPUT,
     EVAL_DATALOADERS,
-    LRSchedulerTypeUnion,
+    LRSchedulerConfig,
     STEP_OUTPUT,
     TRAIN_DATALOADERS,
 )
@@ -1839,11 +1839,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
         self.strategy.optimizers = new_optims
 
     @property
-    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
+    def lr_schedulers(self) -> List[LRSchedulerConfig]:
         return self.strategy.lr_schedulers
 
     @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
+    def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None:
         self.strategy.lr_schedulers = new_schedulers
 
     @property
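The annotation change here reflects what `Trainer.lr_schedulers` actually returns: the scheduler configuration dictionaries stored on the strategy/plugin (see `LRSchedulerConfig` below), not bare scheduler objects.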

pytorch_lightning/utilities/types.py

Lines changed: 75 additions & 6 deletions
@@ -14,15 +14,16 @@
 """
 Convention:
  - Do not include any `_TYPE` suffix
- - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
+ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import Metric
+from typing_extensions import TypedDict
 
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -43,7 +44,75 @@
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class _LRScheduler:
+    optimizer: Optimizer
+
+    def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+    def get_last_lr(self) -> List[float]:
+        ...
+
+    def get_lr(self) -> float:
+        ...
+
+    def step(self, epoch: Optional[int] = ...) -> None:
+        ...
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class ReduceLROnPlateau:
+    in_cooldown: bool
+    optimizer: Optimizer
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        mode: str = ...,
+        factor: float = ...,
+        patience: int = ...,
+        verbose: bool = ...,
+        threshold: float = ...,
+        threshold_mode: str = ...,
+        cooldown: int = ...,
+        min_lr: float = ...,
+        eps: float = ...,
+    ) -> None:
+        ...
+
+    def step(self, metrics: Any, epoch: Optional[int] = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+
+
+class LRSchedulerConfig(TypedDict):
+    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    name: Optional[str]
+    interval: str
+    frequency: int
+    reduce_on_plateau: bool
+    monitor: Optional[str]
+    strict: bool
+    opt_idx: Optional[int]
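For illustration, a minimal sketch of a dictionary matching the new `LRSchedulerConfig` shape; the model, optimizer, and scheduler choices below are arbitrary assumptions, and a TypedDict adds no runtime validation, it only informs static type checkers.

import torch
from torch.optim.lr_scheduler import StepLR

from pytorch_lightning.utilities.types import LRSchedulerConfig

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Every key is required because the TypedDict is "total" by default.
config: LRSchedulerConfig = {
    "scheduler": StepLR(optimizer, step_size=10),
    "name": None,                # custom metric name; None lets LearningRateMonitor pick one
    "interval": "epoch",         # step the scheduler once per epoch
    "frequency": 1,
    "reduce_on_plateau": False,
    "monitor": None,             # only relevant for ReduceLROnPlateau schedulers
    "strict": True,
    "opt_idx": None,
}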
