1616import os
1717import uuid
1818from functools import wraps
19- from typing import Any , Dict , Optional , Sequence
19+ from typing import Any , Callable , cast , Dict , List , Optional , Sequence , TYPE_CHECKING , Union
2020
2121import numpy as np
2222import torch
2727from pytorch_lightning .core .optimizer import _init_optimizers_and_lr_schedulers , _set_scheduler_opt_idx
2828from pytorch_lightning .loggers .logger import DummyLogger
2929from pytorch_lightning .utilities .exceptions import MisconfigurationException
30+ from pytorch_lightning .utilities .imports import _RequirementAvailable
3031from pytorch_lightning .utilities .parsing import lightning_hasattr , lightning_setattr
3132from pytorch_lightning .utilities .rank_zero import rank_zero_warn
32- from pytorch_lightning .utilities .types import LRSchedulerConfig
33+ from pytorch_lightning .utilities .types import LRSchedulerConfig , STEP_OUTPUT
3334
3435# check if ipywidgets is installed before importing tqdm.auto
3536# to ensure it won't fail and a progress bar is displayed
3839else :
3940 from tqdm import tqdm
4041
42+ _MATPLOTLIB_AVAILABLE : bool = _RequirementAvailable ("matplotlib" ) # type: ignore[assignment]
43+ if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING :
44+ import matplotlib .pyplot as plt
45+
4146log = logging .getLogger (__name__ )
4247
4348
@@ -95,16 +100,16 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
95100 self .lr_max = lr_max
96101 self .num_training = num_training
97102
98- self .results = {}
103+ self .results : Dict [ str , Any ] = {}
99104 self ._total_batch_idx = 0 # for debug purpose
100105
101- def _exchange_scheduler (self , trainer : "pl.Trainer" , model : "pl.LightningModule" ):
106+ def _exchange_scheduler (self , trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> Callable [[ "pl.Trainer" ], None ] :
102107 """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
103108 optimizer together with a new scheduler that takes care of the learning rate search."""
104109 setup_optimizers = trainer .strategy .setup_optimizers
105110
106111 @wraps (setup_optimizers )
107- def func (trainer ) :
112+ def func (trainer : "pl.Trainer" ) -> None :
108113 # Decide the structure of the output from _init_optimizers_and_lr_schedulers
109114 optimizers , _ , _ = _init_optimizers_and_lr_schedulers (trainer .lightning_module )
110115
@@ -123,6 +128,7 @@ def func(trainer):
123128
124129 args = (optimizer , self .lr_max , self .num_training )
125130 scheduler = _LinearLR (* args ) if self .mode == "linear" else _ExponentialLR (* args )
131+ scheduler = cast (pl .utilities .types ._LRScheduler , scheduler )
126132
127133 trainer .strategy .optimizers = [optimizer ]
128134 trainer .strategy .lr_scheduler_configs = [LRSchedulerConfig (scheduler , interval = "step" , opt_idx = 0 )]
@@ -131,13 +137,18 @@ def func(trainer):
131137
132138 return func
133139
134- def plot (self , suggest : bool = False , show : bool = False ):
140+ def plot (self , suggest : bool = False , show : bool = False ) -> Optional [ "plt.Figure" ] :
135141 """Plot results from lr_find run
136142 Args:
137143 suggest: if True, will mark suggested lr to use with a red point
138144
139145 show: if True, will show figure
140146 """
147+ if not _MATPLOTLIB_AVAILABLE :
148+ raise MisconfigurationException (
149+ "To use the `plot` method, you must have Matplotlib installed."
150+ " Install it by running `pip install -U matplotlib`."
151+ )
141152 import matplotlib .pyplot as plt
142153
143154 lrs = self .results ["lr" ]
@@ -162,7 +173,7 @@ def plot(self, suggest: bool = False, show: bool = False):
162173
163174 return fig
164175
165- def suggestion (self , skip_begin : int = 10 , skip_end : int = 1 ):
176+ def suggestion (self , skip_begin : int = 10 , skip_end : int = 1 ) -> Optional [ float ] :
166177 """This will propose a suggestion for choice of initial learning rate as the point with the steepest
167178 negative gradient.
168179
@@ -196,7 +207,7 @@ def lr_find(
196207 """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
197208 if trainer .fast_dev_run :
198209 rank_zero_warn ("Skipping learning rate finder since fast_dev_run is enabled." )
199- return
210+ return None
200211
201212 # Determine lr attr
202213 if update_attr :
@@ -218,7 +229,7 @@ def lr_find(
218229 trainer .progress_bar_callback .disable ()
219230
220231 # Configure optimizer and scheduler
221- trainer .strategy .setup_optimizers = lr_finder ._exchange_scheduler (trainer , model )
232+ trainer .strategy .setup_optimizers = lr_finder ._exchange_scheduler (trainer , model ) # type: ignore[assignment]
222233
223234 # Fit, lr & loss logged in callback
224235 trainer .tuner ._run (model )
@@ -304,24 +315,28 @@ def __init__(
304315 self .num_training = num_training
305316 self .early_stop_threshold = early_stop_threshold
306317 self .beta = beta
307- self .losses = []
308- self .lrs = []
318+ self .losses : List [ float ] = []
319+ self .lrs : List [ float ] = []
309320 self .avg_loss = 0.0
310321 self .best_loss = 0.0
311322 self .progress_bar_refresh_rate = progress_bar_refresh_rate
312323 self .progress_bar = None
313324
314- def on_train_batch_start (self , trainer , pl_module , batch , batch_idx ):
325+ def on_train_batch_start (
326+ self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int
327+ ) -> None :
315328 """Called before each training batch, logs the lr that will be used."""
316329 if (trainer .fit_loop .batch_idx + 1 ) % trainer .accumulate_grad_batches != 0 :
317330 return
318331
319332 if self .progress_bar_refresh_rate and self .progress_bar is None :
320333 self .progress_bar = tqdm (desc = "Finding best initial lr" , total = self .num_training )
321334
322- self .lrs .append (trainer .lr_scheduler_configs [0 ].scheduler .lr [0 ])
335+ self .lrs .append (trainer .lr_scheduler_configs [0 ].scheduler .lr [0 ]) # type: ignore[union-attr]
323336
324- def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
337+ def on_train_batch_end (
338+ self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , outputs : STEP_OUTPUT , batch : Any , batch_idx : int
339+ ) -> None :
325340 """Called when the training batch ends, logs the calculated loss."""
326341 if (trainer .fit_loop .batch_idx + 1 ) % trainer .accumulate_grad_batches != 0 :
327342 return
@@ -372,7 +387,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
372387 self .num_iter = num_iter
373388 super ().__init__ (optimizer , last_epoch )
374389
375- def get_lr (self ):
390+ def get_lr (self ) -> List [ float ]: # type: ignore[override]
376391 curr_iter = self .last_epoch + 1
377392 r = curr_iter / self .num_iter
378393
@@ -384,7 +399,7 @@ def get_lr(self):
384399 return val
385400
386401 @property
387- def lr (self ):
402+ def lr (self ) -> Union [ float , List [ float ]] :
388403 return self ._lr
389404
390405
@@ -410,7 +425,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
410425 self .num_iter = num_iter
411426 super ().__init__ (optimizer , last_epoch )
412427
413- def get_lr (self ):
428+ def get_lr (self ) -> List [ float ]: # type: ignore[override]
414429 curr_iter = self .last_epoch + 1
415430 r = curr_iter / self .num_iter
416431
@@ -422,5 +437,5 @@ def get_lr(self):
422437 return val
423438
424439 @property
425- def lr (self ):
440+ def lr (self ) -> Union [ float , List [ float ]] :
426441 return self ._lr
0 commit comments