Commit b59f802

Authored by donlapark, akihironitta, and justusschock
fix mypy typing errors in pytorch_lightning/tuner/lr_finder.py (#13513)
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
1 parent 61c28cb commit b59f802

File tree: 2 files changed (+33, -19 lines)

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -85,7 +85,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.tuner.lr_finder",
     "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 33 additions & 18 deletions

@@ -16,7 +16,7 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING, Union

 import numpy as np
 import torch
@@ -27,9 +27,10 @@
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
 from pytorch_lightning.loggers.logger import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig
+from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT

 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -38,6 +39,10 @@
 else:
     from tqdm import tqdm

+_MATPLOTLIB_AVAILABLE: bool = _RequirementAvailable("matplotlib")  # type: ignore[assignment]
+if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING:
+    import matplotlib.pyplot as plt
+
 log = logging.getLogger(__name__)


@@ -95,16 +100,16 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.lr_max = lr_max
         self.num_training = num_training

-        self.results = {}
+        self.results: Dict[str, Any] = {}
         self._total_batch_idx = 0  # for debug purpose

-    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
+    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Callable[["pl.Trainer"], None]:
         """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
         optimizer together with a new scheduler that takes care of the learning rate search."""
         setup_optimizers = trainer.strategy.setup_optimizers

         @wraps(setup_optimizers)
-        def func(trainer):
+        def func(trainer: "pl.Trainer") -> None:
             # Decide the structure of the output from _init_optimizers_and_lr_schedulers
             optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)

@@ -123,6 +128,7 @@ def func(trainer):

             args = (optimizer, self.lr_max, self.num_training)
             scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
+            scheduler = cast(pl.utilities.types._LRScheduler, scheduler)

             trainer.strategy.optimizers = [optimizer]
             trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
@@ -131,13 +137,18 @@ def func(trainer):

         return func

-    def plot(self, suggest: bool = False, show: bool = False):
+    def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]:
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point

             show: if True, will show figure
         """
+        if not _MATPLOTLIB_AVAILABLE:
+            raise MisconfigurationException(
+                "To use the `plot` method, you must have Matplotlib installed."
+                " Install it by running `pip install -U matplotlib`."
+            )
         import matplotlib.pyplot as plt

         lrs = self.results["lr"]
@@ -162,7 +173,7 @@ def plot(self, suggest: bool = False, show: bool = False):

         return fig

-    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
         """This will propose a suggestion for choice of initial learning rate as the point with the steepest
         negative gradient.

@@ -196,7 +207,7 @@ def lr_find(
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
-        return
+        return None

     # Determine lr attr
     if update_attr:
@@ -218,7 +229,7 @@
     trainer.progress_bar_callback.disable()

     # Configure optimizer and scheduler
-    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore[assignment]

     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)
@@ -304,24 +315,28 @@ def __init__(
         self.num_training = num_training
         self.early_stop_threshold = early_stop_threshold
         self.beta = beta
-        self.losses = []
-        self.lrs = []
+        self.losses: List[float] = []
+        self.lrs: List[float] = []
         self.avg_loss = 0.0
         self.best_loss = 0.0
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+    def on_train_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
         """Called before each training batch, logs the lr that will be used."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)

-        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])
+        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore[union-attr]

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+    def on_train_batch_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+    ) -> None:
         """Called when the training batch ends, logs the calculated loss."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
@@ -372,7 +387,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)

-    def get_lr(self):
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

@@ -384,7 +399,7 @@ def get_lr(self):
         return val

     @property
-    def lr(self):
+    def lr(self) -> Union[float, List[float]]:
         return self._lr


@@ -410,7 +425,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)

-    def get_lr(self):
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

@@ -422,5 +437,5 @@ def get_lr(self):
         return val

     @property
-    def lr(self):
+    def lr(self) -> Union[float, List[float]]:
         return self._lr
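
For readers unfamiliar with the pattern these hunks introduce: matplotlib is only needed to type-check plot()'s return annotation, so the import is guarded and the module still imports when matplotlib is absent. A rough standalone sketch of the same idea, using only the standard library in place of Lightning's _RequirementAvailable helper and a plain "if TYPE_CHECKING:" guard instead of the combined check in the diff:

import importlib.util
from typing import TYPE_CHECKING, List, Optional

# Runtime availability check (a stand-in for Lightning's `_RequirementAvailable("matplotlib")`).
_MATPLOTLIB_AVAILABLE: bool = importlib.util.find_spec("matplotlib") is not None

if TYPE_CHECKING:
    # Seen only by type checkers, never executed, so `plt.Figure` can appear in
    # annotations even when matplotlib is not installed at runtime.
    import matplotlib.pyplot as plt


def plot_curve(lrs: List[float], losses: List[float]) -> Optional["plt.Figure"]:
    if not _MATPLOTLIB_AVAILABLE:
        raise RuntimeError("matplotlib is required for plotting")
    import matplotlib.pyplot as plt  # local import keeps the dependency optional

    fig, ax = plt.subplots()
    ax.plot(lrs, losses)
    return fig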

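Similarly, the cast(...) call and the "# type: ignore[override]" comments are the two usual ways of reconciling mypy with a loosely typed base class; a minimal sketch with invented class names (not repository code):

from typing import List, Union, cast


class _Base:
    def get_lr(self) -> float:  # the base class promises a scalar
        return 0.1


class _Sched(_Base):
    # The override returns a list, so mypy's override compatibility check must be
    # silenced, mirroring `get_lr()` in the diff above.
    def get_lr(self) -> List[float]:  # type: ignore[override]
        return [0.1, 0.01]


value: Union[float, List[float]] = _Sched().get_lr()
# `cast` is a no-op at runtime; it only tells mypy to treat the value as the
# narrower type, as `_exchange_scheduler` does for the scheduler instance.
lr_list = cast(List[float], value)
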