1 change: 0 additions & 1 deletion pyproject.toml
@@ -44,7 +44,6 @@ warn_no_return = "False"
module = [
"pytorch_lightning.accelerators.gpu",
"pytorch_lightning.callbacks.finetuning",
"pytorch_lightning.callbacks.lr_monitor",
"pytorch_lightning.callbacks.model_checkpoint",
"pytorch_lightning.callbacks.progress.base",
"pytorch_lightning.callbacks.progress.progress",
10 changes: 6 additions & 4 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -86,7 +86,7 @@ def configure_optimizer(self):

"""

def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
if logging_interval not in (None, "step", "epoch"):
raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")

@@ -112,7 +112,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No

def _check_no_key(key: str) -> bool:
if trainer.lr_schedulers:
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
return any(key not in sch.scheduler.optimizer.defaults for sch in trainer.lr_schedulers)

return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

@@ -146,6 +146,7 @@ def _check_no_key(key: str) -> bool:
self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if not trainer.logger_connector.should_update_logs:
return

@@ -157,6 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if self.logging_interval != "step":
interval = "epoch" if self.logging_interval is None else "any"
latest_stat = self._extract_stats(trainer, interval)
@@ -175,8 +177,8 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa
self._remap_keys(scheduler_hparam_keys)

for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers):
if interval in [scheduler["interval"], "any"]:
opt = scheduler["scheduler"].optimizer
if interval in [scheduler.interval, "any"]:
opt = scheduler.scheduler.optimizer
current_stat = self._get_lr_momentum_stat(opt, name)
latest_stat.update(current_stat)

11 changes: 5 additions & 6 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -24,9 +24,9 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]

@@ -182,16 +182,15 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
)
default_scheduler_cfg = _get_default_scheduler_config()
assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
default_scheduler_cfg["scheduler"] = self._swa_scheduler
default_scheduler_cfg = LRSchedulerConfig(scheduler=self._swa_scheduler)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

if trainer.lr_schedulers:
scheduler_cfg = trainer.lr_schedulers[0]
if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
rank_zero_info(
f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
f" for `{self._swa_scheduler.__class__.__name__}`"
)
trainer.lr_schedulers[0] = default_scheduler_cfg
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -32,7 +32,6 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -42,7 +41,7 @@
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT, LRSchedulerConfig
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
@@ -456,7 +455,7 @@ def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTup
)
return (
optimizers[0],
schedulers[0] if schedulers else _get_default_scheduler_config(),
schedulers[0] if schedulers else LRSchedulerConfig(scheduler=None), # TODO: fix type
optimizer_frequencies[0] if optimizer_frequencies else None,
)

@@ -466,15 +465,15 @@ def zero_stage_3(self) -> bool:

def _initialize_deepspeed_train(self, model):
if "optimizer" in self.config:
optimizer, lr_scheduler = None, _get_default_scheduler_config()
optimizer, lr_scheduler = None, LRSchedulerConfig()
else:
rank_zero_info(
"You have not specified an optimizer or scheduler within the DeepSpeed config."
" Using `configure_optimizers` to define optimizer and scheduler."
)
optimizer, lr_scheduler, _ = self._init_optimizers()

scheduler = lr_scheduler["scheduler"]
scheduler = lr_scheduler.scheduler
model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
self._set_deepspeed_activation_checkpointing()

@@ -485,7 +484,7 @@ def _initialize_deepspeed_train(self, model):
if deepspeed_scheduler is not None:
# disable deepspeed lr scheduling as lightning manages scheduling
model.lr_scheduler = None
lr_scheduler["scheduler"] = deepspeed_scheduler
lr_scheduler.scheduler = deepspeed_scheduler
self.lightning_module.trainer.lr_schedulers = [lr_scheduler]
self.model = model

@@ -19,7 +19,6 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

import pytorch_lightning as pl
@@ -32,7 +31,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")

@@ -52,7 +51,7 @@ def __init__(
self.checkpoint_io = checkpoint_io
self.precision_plugin = precision_plugin
self.optimizers: List[Optimizer] = []
self.lr_schedulers: List[_LRScheduler] = []
self.lr_schedulers: List[LRSchedulerConfig] = []
self.optimizer_frequencies: List[int] = []
if is_overridden("post_dispatch", self, parent=Strategy):
rank_zero_deprecation(
30 changes: 9 additions & 21 deletions pytorch_lightning/trainer/optimizers.py
@@ -13,6 +13,7 @@
# limitations under the License.

from abc import ABC
from dataclasses import fields
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
@@ -23,6 +24,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import LRSchedulerConfig


class TrainerOptimizersMixin(ABC):
@@ -122,10 +124,9 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
@staticmethod
def _configure_schedulers(
schedulers: list, monitor: Optional[str], is_manual_optimization: bool
) -> List[Dict[str, Any]]:
) -> List[LRSchedulerConfig]:
"""Convert each scheduler into dict structure with relevant information."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
if is_manual_optimization:
if isinstance(scheduler, dict):
@@ -140,13 +141,13 @@ def _configure_schedulers(
)

scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
lr_schedulers.append(LRSchedulerConfig(**scheduler))
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})
lr_schedulers.append(LRSchedulerConfig(scheduler=scheduler))
else:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
extra_keys = [k for k in scheduler.keys() if k not in fields(LRSchedulerConfig)]
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
@@ -176,7 +177,7 @@ def _configure_schedulers(
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
lr_schedulers.append(LRSchedulerConfig(**scheduler))
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
@@ -185,10 +186,10 @@
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
LRSchedulerConfig(scheduler=scheduler, reduce_on_plateau=True, monitor=monitor)
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
lr_schedulers.append(LRSchedulerConfig(scheduler=scheduler))
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers
@@ -235,16 +236,3 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)


def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
}
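
A minimal sketch (outside the diff) of how scheduler configuration now flows through `_configure_schedulers`: the defaults that previously lived in the removed `_get_default_scheduler_config()` dict become the field defaults of `LRSchedulerConfig`, and a user-supplied scheduler dict is unpacked into the dataclass. The model and optimizer below are hypothetical and only illustrate the conversion, assuming this branch's `pytorch_lightning.utilities.types`.

import torch
from pytorch_lightning.utilities.types import LRSchedulerConfig

# Hypothetical objects, for illustration only.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# What `configure_optimizers` might return as a scheduler dict.
user_dict = {"scheduler": scheduler, "interval": "step", "frequency": 2}

# Previously: {**default_config, **user_dict}. Now the dict is unpacked into the
# dataclass, so unspecified keys fall back to the field defaults
# (interval="epoch", frequency=1, reduce_on_plateau=False, ...).
config = LRSchedulerConfig(**user_dict)

assert config.interval == "step" and config.frequency == 2
assert config.monitor is None and config.reduce_on_plateau is False  # defaults preserved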
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -104,7 +104,7 @@
_PATH,
_PREDICT_OUTPUT,
EVAL_DATALOADERS,
LRSchedulerTypeUnion,
LRSchedulerConfig,
TRAIN_DATALOADERS,
)
from pytorch_lightning.utilities.warnings import PossibleUserWarning
@@ -1763,11 +1763,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
self.training_type_plugin.optimizers = new_optims

@property
def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
def lr_schedulers(self) -> List[LRSchedulerConfig]:
return self.training_type_plugin.lr_schedulers

@lr_schedulers.setter
def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None:
self.training_type_plugin.lr_schedulers = new_schedulers

@property
6 changes: 3 additions & 3 deletions pytorch_lightning/tuner/lr_finder.py
@@ -25,14 +25,15 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from pytorch_lightning.utilities.types import LRSchedulerConfig

if importlib.util.find_spec("ipywidgets") is not None:
from tqdm.auto import tqdm
else:
@@ -123,8 +124,7 @@ def func(model):

args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
sched_config = _get_default_scheduler_config()
sched_config.update({"scheduler": scheduler, "interval": "step"})
sched_config = LRSchedulerConfig(scheduler=scheduler, interval="step")

return [optimizer], [sched_config], []

83 changes: 77 additions & 6 deletions pytorch_lightning/utilities/types.py
@@ -14,15 +14,17 @@
"""
Convention:
- Do not include any `_TYPE` suffix
- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Metric
from typing_extensions import TypedDict

_NUMBER = Union[int, float]
_METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -43,7 +45,76 @@
Dict[str, Sequence[DataLoader]],
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]


# Copied from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class _LRScheduler:
optimizer: Optimizer

def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
...

def state_dict(self) -> dict:
...

def load_state_dict(self, state_dict: dict) -> None:
...

def get_last_lr(self) -> List[float]:
...

def get_lr(self) -> float:
...

def step(self, epoch: Optional[int] = ...) -> None:
...


# Copied from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class ReduceLROnPlateau:
in_cooldown: bool
optimizer: Optimizer

def __init__(
self,
optimizer: Optimizer,
mode: str = ...,
factor: float = ...,
patience: int = ...,
verbose: bool = ...,
threshold: float = ...,
threshold_mode: str = ...,
cooldown: int = ...,
min_lr: float = ...,
eps: float = ...,
) -> None:
...

def step(self, metrics: Any, epoch: Optional[int] = ...) -> None:
...

def state_dict(self) -> dict:
...

def load_state_dict(self, state_dict: dict) -> None:
...


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]


@dataclass
class LRSchedulerConfig:
scheduler: Union[_LRScheduler, ReduceLROnPlateau]
name: Optional[str] = None # no custom name
interval: str = "epoch" # after epoch is over
frequency: int = 1 # every epoch/batch
reduce_on_plateau: bool = False # most often not ReduceLROnPlateau scheduler
monitor: Optional[str] = None # value to monitor for ReduceLROnPlateau
strict: bool = True # enforce that the monitor exists for ReduceLROnPlateau
opt_idx: Optional[int] = None # necessary to store opt_idx when optimizer frequencies are specified
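
As a usage sketch (not part of the diff), downstream code that used to index these entries by key now reads the same information as attributes, since `trainer.lr_schedulers` becomes a `List[LRSchedulerConfig]`. The callback below is purely illustrative and assumes the standard Lightning hook signatures.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class SchedulerConfigPrinter(Callback):
    """Illustrative only: reports each scheduler config via the new attribute access."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        for config in trainer.lr_schedulers:
            # Before this change: config["scheduler"], config["interval"], config["frequency"], ...
            print(
                f"{config.scheduler.__class__.__name__}: interval={config.interval},"
                f" frequency={config.frequency}, monitor={config.monitor}"
            )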