diff --git a/pyproject.toml b/pyproject.toml
index d9e877c45e817..0e627f1071cb4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.accelerators.gpu",
     "pytorch_lightning.callbacks.finetuning",
-    "pytorch_lightning.callbacks.lr_monitor",
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.base",
     "pytorch_lightning.callbacks.progress.progress",
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index dd86692ac9aec..9d4e70fb880ce 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -86,7 +86,7 @@ def configure_optimizer(self):
 
     """
 
-    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
+    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
         if logging_interval not in (None, "step", "epoch"):
             raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")
 
@@ -146,6 +146,7 @@ def _check_no_key(key: str) -> bool:
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
 
     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        assert trainer.logger is not None
         if not trainer.logger_connector.should_update_logs:
             return
 
@@ -157,6 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        assert trainer.logger is not None
         if self.logging_interval != "step":
             interval = "epoch" if self.logging_interval is None else "any"
             latest_stat = self._extract_stats(trainer, interval)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 77611233e79d8..dfd829ad1037d 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -19,7 +19,6 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
@@ -32,7 +31,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT
 
 TBroadcast = TypeVar("TBroadcast")
 
@@ -52,7 +51,7 @@ def __init__(
         self.checkpoint_io = checkpoint_io
         self.precision_plugin = precision_plugin
         self.optimizers: List[Optimizer] = []
-        self.lr_schedulers: List[_LRScheduler] = []
+        self.lr_schedulers: List[LRSchedulerConfig] = []
         self.optimizer_frequencies: List[int] = []
         if is_overridden("post_dispatch", self, parent=Strategy):
             rank_zero_deprecation(
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 32031d23ea3b3..6635dd8bb24e6 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import LRSchedulerConfig
 
 
 class TrainerOptimizersMixin(ABC):
@@ -122,7 +123,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
     @staticmethod
     def _configure_schedulers(
         schedulers: list, monitor: Optional[str], is_manual_optimization: bool
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LRSchedulerConfig]:
         """Convert each scheduler into dict structure with relevant information."""
         lr_schedulers = []
         default_config = _get_default_scheduler_config()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8f586afa6bd03..9523e4a15a229 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -107,7 +107,7 @@
     _PATH,
     _PREDICT_OUTPUT,
     EVAL_DATALOADERS,
-    LRSchedulerTypeUnion,
+    LRSchedulerConfig,
     STEP_OUTPUT,
     TRAIN_DATALOADERS,
 )
@@ -1839,11 +1839,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
         self.strategy.optimizers = new_optims
 
     @property
-    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
+    def lr_schedulers(self) -> List[LRSchedulerConfig]:
         return self.strategy.lr_schedulers
 
     @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
+    def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None:
         self.strategy.lr_schedulers = new_schedulers
 
     @property
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index 86c9028a5ff1f..44a3b88d530d6 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -14,15 +14,16 @@
 """
 Convention:
  - Do not include any `_TYPE` suffix
- - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
+ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import Metric
+from typing_extensions import TypedDict
 
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -43,7 +44,75 @@
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class _LRScheduler:
+    optimizer: Optimizer
+
+    def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+    def get_last_lr(self) -> List[float]:
+        ...
+
+    def get_lr(self) -> float:
+        ...
+
+    def step(self, epoch: Optional[int] = ...) -> None:
+        ...
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class ReduceLROnPlateau:
+    in_cooldown: bool
+    optimizer: Optimizer
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        mode: str = ...,
+        factor: float = ...,
+        patience: int = ...,
+        verbose: bool = ...,
+        threshold: float = ...,
+        threshold_mode: str = ...,
+        cooldown: int = ...,
+        min_lr: float = ...,
+        eps: float = ...,
+    ) -> None:
+        ...
+
+    def step(self, metrics: Any, epoch: Optional[int] = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+
+
+class LRSchedulerConfig(TypedDict):
+    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    name: Optional[str]
+    interval: str
+    frequency: int
+    reduce_on_plateau: bool
+monitor: Optional[str]
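
Illustrative usage, not part of the changeset: a minimal sketch of building one entry of the kind now annotated as LRSchedulerConfig, using only the keys declared in the TypedDict above. The model, optimizer, scheduler, and default values below are assumptions for the example; at runtime the TypedDict is an ordinary dict, so a real torch scheduler satisfies the "scheduler" key even though the annotation refers to the local stub classes.

# Sketch only: hand-built LRSchedulerConfig entry mirroring the keys declared
# in pytorch_lightning/utilities/types.py above. Concrete objects and values
# here are illustrative assumptions, not taken from the diff.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from pytorch_lightning.utilities.types import LRSchedulerConfig

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

config: LRSchedulerConfig = {
    "scheduler": StepLR(optimizer, step_size=10),  # any torch scheduler works at runtime
    "name": None,                # display name used when logging the learning rate
    "interval": "epoch",         # "step" or "epoch"
    "frequency": 1,              # how often scheduler.step() is called per interval
    "reduce_on_plateau": False,  # True when the scheduler is a ReduceLROnPlateau
    "monitor": None,             # metric monitored by ReduceLROnPlateau schedulers
    "strict": True,              # raise (rather than warn) if the monitored metric is missing
    "opt_idx": None,             # index of the optimizer this scheduler is paired with
}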