From 377eee6bc790ba72de3014a11411bfac1445aad7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 12 Jan 2022 16:36:18 +0100 Subject: [PATCH 01/19] Use a dataclass for the scheduler config --- pytorch_lightning/callbacks/lr_monitor.py | 8 +-- .../callbacks/stochastic_weight_avg.py | 11 ++- pytorch_lightning/core/lightning.py | 12 +++- pytorch_lightning/core/optimizer.py | 72 +++++++------------ .../loops/epoch/training_epoch_loop.py | 16 ++--- pytorch_lightning/strategies/deepspeed.py | 34 +++++---- .../connectors/checkpoint_connector.py | 8 +-- pytorch_lightning/trainer/trainer.py | 10 +-- pytorch_lightning/tuner/lr_finder.py | 14 ++-- pytorch_lightning/utilities/types.py | 27 ++++--- tests/trainer/optimization/test_optimizers.py | 25 +++---- 11 files changed, 115 insertions(+), 122 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 9d4e70fb880ce..b63b495a70879 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -112,7 +112,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No def _check_no_key(key: str) -> bool: if trainer.lr_schedulers: - return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers) + return any(key not in config.scheduler.optimizer.defaults for config in trainer.lr_schedulers) return any(key not in optimizer.defaults for optimizer in trainer.optimizers) @@ -176,9 +176,9 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False) self._remap_keys(scheduler_hparam_keys) - for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers): - if interval in [scheduler["interval"], "any"]: - opt = scheduler["scheduler"].optimizer + for name, config in zip(scheduler_hparam_keys, trainer.lr_schedulers): + if interval in [config.interval, "any"]: + opt = config.scheduler.optimizer current_stat = self._get_lr_momentum_stat(opt, name) latest_stat.update(current_stat) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 504d477d3ff63..4796dcc5cc078 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,9 +24,9 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.optimizer import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -182,16 +182,15 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, ) - default_scheduler_cfg = _get_default_scheduler_config() - assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1 - default_scheduler_cfg["scheduler"] = self._swa_scheduler + default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler) + assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 if trainer.lr_schedulers: scheduler_cfg = trainer.lr_schedulers[0] - if 
scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1: + if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") rank_zero_info( - f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`" + f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" f" for `{self._swa_scheduler.__class__.__name__}`" ) trainer.lr_schedulers[0] = default_scheduler_cfg diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 25472d5295cbe..74cf9d5385cc3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -53,7 +53,13 @@ from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.parsing import collect_init_args from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT +from pytorch_lightning.utilities.types import ( + _LRScheduler, + _METRIC_COLLECTION, + EPOCH_OUTPUT, + LRSchedulerTypeUnion, + STEP_OUTPUT, +) from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -158,7 +164,7 @@ def optimizers( # multiple opts return opts - def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: + def lr_schedulers(self) -> Optional[Union[_LRScheduler, List[_LRScheduler]]]: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. @@ -170,7 +176,7 @@ def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: return None # ignore other keys "interval", "frequency", etc. - lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers] + lr_schedulers = [config.scheduler for config in self.trainer.lr_schedulers] # single scheduler if len(lr_schedulers) == 1: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 3b0cdffff497e..c06b0d94cc5be 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -13,6 +13,7 @@ # limitations under the License. 
import weakref from contextlib import contextmanager +from dataclasses import fields from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union from weakref import proxy @@ -24,7 +25,7 @@ from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerTypeTuple +from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerConfig, LRSchedulerTypeTuple def do_nothing_closure() -> None: @@ -172,7 +173,7 @@ def closure_dis(): def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]: +) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" model.trainer._lightning_optimizers = None optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) @@ -184,10 +185,11 @@ def _init_optimizers_and_lr_schedulers( optim_conf = _MockOptimizer() optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf) - _configure_schedulers = ( - _configure_schedulers_automatic_opt if model.automatic_optimization else _configure_schedulers_manual_opt + lr_schedulers = ( + _configure_schedulers_automatic_opt(lr_schedulers, monitor) + if model.automatic_optimization + else _configure_schedulers_manual_opt(lr_schedulers) ) - lr_schedulers = _configure_schedulers(lr_schedulers, monitor) _set_scheduler_opt_idx(optimizers, lr_schedulers) _validate_scheduler_api(lr_schedulers, model) return optimizers, lr_schedulers, optimizer_frequencies @@ -257,18 +259,15 @@ def _configure_optimizers( return optimizers, lr_schedulers, optimizer_frequencies, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" lr_schedulers = [] - default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): # check provided keys - extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] + extra_keys = scheduler.keys() - {field.name for field in fields(LRSchedulerConfig)} if extra_keys: - rank_zero_warn( - f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning - ) + raise MisconfigurationException(f"Found unsupported keys in the lr scheduler dict: {extra_keys}") if "scheduler" not in scheduler: raise MisconfigurationException( 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' @@ -292,7 +291,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] " Are you sure you didn't mean 'interval': 'step'?", category=RuntimeWarning, ) - lr_schedulers.append({**default_config, **scheduler}) + scheduler = LRSchedulerConfig(**scheduler) elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): if monitor is None: raise MisconfigurationException( @@ -300,19 +299,16 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] " scheduler is used. 
For example:" ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' ) - lr_schedulers.append( - {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} - ) + scheduler = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor) else: - lr_schedulers.append({**default_config, "scheduler": scheduler}) - + scheduler = LRSchedulerConfig(scheduler) + lr_schedulers.append(scheduler) return lr_schedulers -def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: +def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" lr_schedulers = [] - default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"} @@ -325,19 +321,18 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - category=RuntimeWarning, ) - scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys} - lr_schedulers.append({**default_config, **scheduler}) + scheduler = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys}) else: - lr_schedulers.append({**default_config, "scheduler": scheduler}) - + scheduler = LRSchedulerConfig(scheduler) + lr_schedulers.append(scheduler) return lr_schedulers -def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None: - for scheduler_config in lr_schedulers: - scheduler = scheduler_config["scheduler"] +def _validate_scheduler_api(lr_schedulers: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: + for config in lr_schedulers: + scheduler = config.scheduler if not isinstance(scheduler, _SupportsStateDict): - raise TypeError( + raise ValueError( f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid." " It should have `state_dict` and `load_state_dict` methods defined." ) @@ -350,31 +345,18 @@ def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.Ligh ) -def _get_default_scheduler_config() -> Dict[str, Any]: - return { - "scheduler": None, - "name": None, # no custom name - "interval": "epoch", # after epoch is over - "frequency": 1, # every epoch/batch - "reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler - "monitor": None, # value to monitor for ReduceLROnPlateau - "strict": True, # enforce that the monitor exists for ReduceLROnPlateau - "opt_idx": None, # opt_idx assigned internally if not assigned by user - } - - -def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None: - for sch in lr_schedulers: +def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[LRSchedulerConfig]) -> None: + for config in lr_schedulers: for opt_idx, opt in enumerate(optimizers): - if sch["scheduler"].optimizer is opt: - if sch["opt_idx"] is not None and sch["opt_idx"] != opt_idx: + if config.scheduler.optimizer is opt: + if config.opt_idx is not None and config.opt_idx != opt_idx: raise MisconfigurationException( "`opt_idx` set inside scheduler config does not match with the index" " of the respective optimizer returned from `configure_optimizers`." 
) - sch["opt_idx"] = opt_idx + config.opt_idx = opt_idx break else: raise MisconfigurationException( diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 69432ee07dd0b..ce1999f817613 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -419,24 +419,24 @@ def _update_learning_rates( opt_indices = [] for lr_scheduler in self.trainer.lr_schedulers: - if lr_scheduler["opt_idx"] not in opt_indices: + if lr_scheduler.opt_idx not in opt_indices: continue - if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]: + if update_plateau_schedulers ^ lr_scheduler.reduce_on_plateau: continue current_idx = self.batch_idx if interval == "step" else self.trainer.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero - if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0: + if lr_scheduler.interval == interval and current_idx % lr_scheduler.frequency == 0: monitor_val = None - if lr_scheduler["reduce_on_plateau"]: + if lr_scheduler.reduce_on_plateau: # If instance of ReduceLROnPlateau, we need a monitor - monitor_key = lr_scheduler["monitor"] + monitor_key = lr_scheduler.monitor monitor_val = self._get_monitor_value(monitor_key) if monitor_val is None: - if lr_scheduler.get("strict", True): + if lr_scheduler.strict: avail_metrics = list(self.trainer.callback_metrics) raise MisconfigurationException( f"ReduceLROnPlateau conditioned on metric {monitor_key}" @@ -456,8 +456,8 @@ def _update_learning_rates( # update LR self.trainer._call_lightning_module_hook( "lr_scheduler_step", - lr_scheduler["scheduler"], - lr_scheduler["opt_idx"], + lr_scheduler.scheduler, + lr_scheduler.opt_idx, monitor_val, ) self.scheduler_progress.increment_completed() diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 7504eb37500b2..6709f11282a2a 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -26,7 +26,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -444,15 +444,15 @@ def init_deepspeed(self): else: self._initialize_deepspeed_inference(model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]: - optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) - if len(optimizers) > 1 or len(schedulers) > 1: + def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]: + optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) + if len(optimizers) > 1 or len(lr_schedulers) > 1: raise MisconfigurationException( "DeepSpeed currently only supports single optimizer, single optional scheduler." 
) return ( optimizers[0], - schedulers[0] if schedulers else _get_default_scheduler_config(), + lr_schedulers[0] if lr_schedulers else None, optimizer_frequencies[0] if optimizer_frequencies else None, ) @@ -461,28 +461,33 @@ def zero_stage_3(self) -> bool: return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3 def _initialize_deepspeed_train(self, model): + optimizer, scheduler = None, None if "optimizer" in self.config: rank_zero_info( "You have specified an optimizer and/or scheduler within the DeepSpeed config." " It is recommended to define it in `LightningModule.configure_optimizers`." ) - optimizer, lr_scheduler = None, _get_default_scheduler_config() + lr_scheduler = None else: optimizer, lr_scheduler, _ = self._init_optimizers() + if lr_scheduler is not None: + scheduler = lr_scheduler.scheduler - scheduler = lr_scheduler["scheduler"] model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler) self._set_deepspeed_activation_checkpointing() # although we set these here, deepspeed manages the specific optimizer logic - self.lightning_module.trainer.optimizers = [deepspeed_optimizer] + self.optimizers = [deepspeed_optimizer] deepspeed_scheduler = model.lr_scheduler if deepspeed_scheduler is not None: # disable deepspeed lr scheduling as lightning manages scheduling model.lr_scheduler = None - lr_scheduler["scheduler"] = deepspeed_scheduler - self.lightning_module.trainer.lr_schedulers = [lr_scheduler] + if lr_scheduler is None: + lr_scheduler = LRSchedulerConfig(deepspeed_scheduler) + else: + lr_scheduler.scheduler = deepspeed_scheduler + self.lr_schedulers = [lr_scheduler] self.model = model @contextlib.contextmanager @@ -523,11 +528,10 @@ def _initialize_deepspeed_inference(self, model): " Using `configure_optimizers` to define optimizer and scheduler." 
) optimizer, lr_scheduler, _ = self._init_optimizers() - scheduler = lr_scheduler["scheduler"] - inference_config = { - # todo: this is required for DeepSpeed throughput timers - "train_micro_batch_size_per_gpu": 1 - } + if lr_scheduler is not None: + scheduler = lr_scheduler.scheduler + # todo: this is required for DeepSpeed throughput timers + inference_config = {"train_micro_batch_size_per_gpu": 1} if "fp16" in self.config: inference_config.update({"fp16": self.config["fp16"]}) if self.zero_stage_3: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 125548471b529..5a6bc6b1796e7 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -305,8 +305,8 @@ def restore_lr_schedulers(self) -> None: # restore the lr schedulers lr_schedulers = self._loaded_checkpoint["lr_schedulers"] - for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): - scheduler["scheduler"].load_state_dict(lrs_state) + for config, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): + config.scheduler.load_state_dict(lrs_state) # ---------------------------------- # PRIVATE OPS @@ -368,8 +368,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump lr schedulers lr_schedulers = [] - for scheduler in self.trainer.lr_schedulers: - lr_schedulers.append(scheduler["scheduler"].state_dict()) + for config in self.trainer.lr_schedulers: + lr_schedulers.append(config.scheduler.state_dict()) checkpoint["lr_schedulers"] = lr_schedulers self.trainer.precision_plugin.on_save_checkpoint(checkpoint) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 39cadb7f9e7ef..13c637f2bbbf6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -515,11 +515,6 @@ def __init__( # hook self._call_callback_hooks("on_init_start") - # init optimizer + lr scheduler related flags - self.lr_schedulers = [] - self.optimizers = [] - self.optimizer_frequencies = [] - # init data flags self._data_connector.on_trainer_init( check_val_every_n_epoch, @@ -2021,6 +2016,7 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: self.strategy.optimizers = new_optims + # FIXME: bc compat @property def lr_schedulers(self) -> List[LRSchedulerConfig]: return self.strategy.lr_schedulers @@ -2030,11 +2026,11 @@ def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None: self.strategy.lr_schedulers = new_schedulers @property - def optimizer_frequencies(self) -> list: + def optimizer_frequencies(self) -> List[int]: return self.strategy.optimizer_frequencies @optimizer_frequencies.setter - def optimizer_frequencies(self, new_freqs: list) -> None: + def optimizer_frequencies(self, new_freqs: List[int]) -> None: self.strategy.optimizer_frequencies = new_freqs @property diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 7bf1bcf34ed96..58735d95960de 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -24,11 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.optimizer import ( - _get_default_scheduler_config, - _init_optimizers_and_lr_schedulers, - _set_scheduler_opt_idx, -) +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx from 
pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -37,6 +33,8 @@ # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed +from pytorch_lightning.utilities.types import LRSchedulerConfig + if importlib.util.find_spec("ipywidgets") is not None: from tqdm.auto import tqdm else: @@ -127,11 +125,9 @@ def func(trainer): args = (optimizer, self.lr_max, self.num_training) scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) - sched_config = _get_default_scheduler_config() - sched_config.update({"scheduler": scheduler, "interval": "step", "opt_idx": 0}) trainer.strategy.optimizers = [optimizer] - trainer.strategy.lr_schedulers = [sched_config] + trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] trainer.strategy.optimizer_frequencies = [] _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_schedulers) @@ -232,7 +228,7 @@ def lr_find( trainer.progress_bar_callback.disable() # Required for saving the model - trainer.optimizers, trainer.lr_schedulers = [], [] + trainer.strategy.optimizers, trainer.strategy.lr_schedulers = [], [] trainer.model = model # Dump model checkpoint diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 1d5cd272267d5..6e0ba9b126bc4 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -16,6 +16,7 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union @@ -23,7 +24,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric -from typing_extensions import Protocol, runtime_checkable, TypedDict +from typing_extensions import Protocol, runtime_checkable _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] @@ -94,12 +95,20 @@ def __init__( LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] -class LRSchedulerConfig(TypedDict): +@dataclass +class LRSchedulerConfig: scheduler: Union[_LRScheduler, ReduceLROnPlateau] - name: Optional[str] - interval: str - frequency: int - reduce_on_plateau: bool - monitor: Optional[str] - strict: bool - opt_idx: Optional[int] + # no custom name + name: Optional[str] = None + # after epoch is over + interval: str = "epoch" + # every epoch/batch + frequency: int = 1 + # most often not ReduceLROnPlateau scheduler + reduce_on_plateau: bool = False + # value to monitor for ReduceLROnPlateau + monitor: Optional[str] = None + # enforce that the monitor exists for ReduceLROnPlateau + strict: bool = True + # opt_idx assigned internally if not assigned by user + opt_idx: Optional[int] = None diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index e960eabcb9b62..2228caad24201 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -26,6 +26,7 @@ _init_optimizers_and_lr_schedulers, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig from 
tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -134,8 +135,8 @@ def configure_optimizers(self): assert trainer.state.finished, f"Training failed with {trainer.state}" lr_scheduler = trainer.lr_schedulers[0] - assert lr_scheduler == dict( - scheduler=lr_scheduler["scheduler"], + assert lr_scheduler == LRSchedulerConfig( + scheduler=lr_scheduler.scheduler, monitor="foo", interval="epoch", frequency=1, @@ -175,7 +176,7 @@ def test_optimizer_return_options(tmpdir): assert opt == [opt_a, opt_b] assert len(lr_sched) == len(freq) == 0 - ref_lr_sched = dict( + ref_lr_sched = LRSchedulerConfig( scheduler=scheduler_a, interval="epoch", frequency=1, @@ -218,10 +219,10 @@ def test_optimizer_return_options(tmpdir): opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == len(freq) == 2 assert opt[0] == opt_a - ref_lr_sched["opt_idx"] = 0 + ref_lr_sched.opt_idx = 0 assert lr_sched[0] == ref_lr_sched - ref_lr_sched["scheduler"] = scheduler_b - ref_lr_sched["opt_idx"] = 1 + ref_lr_sched.scheduler = scheduler_b + ref_lr_sched.opt_idx = 1 assert lr_sched[1] == ref_lr_sched assert freq == [1, 5] @@ -309,11 +310,11 @@ def configure_optimizers(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.lr_schedulers[0]["opt_idx"] == 0 - assert trainer.lr_schedulers[1]["opt_idx"] == 1 + assert trainer.lr_schedulers[0].opt_idx == 0 + assert trainer.lr_schedulers[1].opt_idx == 1 # Step count is 1 greater than the expected value because scheduler.step() is called once during initialization - assert trainer.lr_schedulers[0]["scheduler"]._step_count == expected_steps[0] - assert trainer.lr_schedulers[1]["scheduler"]._step_count == expected_steps[1] + assert trainer.lr_schedulers[0].scheduler._step_count == expected_steps[0] + assert trainer.lr_schedulers[1].scheduler._step_count == expected_steps[1] @pytest.mark.parametrize("fn", ("validate", "test", "predict")) @@ -483,7 +484,7 @@ def test_lr_scheduler_with_extra_keys_warns(tmpdir): "lr_scheduler": {"scheduler": optim.lr_scheduler.StepLR(optimizer, 1), "foo": 1, "bar": 2}, } trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the lr scheduler dict: \[.+\]"): + with pytest.raises(MisconfigurationException, match=r"Found unsupported keys in the lr scheduler dict: \{.+\}"): trainer.fit(model) @@ -761,7 +762,7 @@ def configure_optimizers(self): model = CustomBoringModel() model.trainer = Trainer() - with pytest.raises(TypeError, match="provided lr scheduler `CustomScheduler` is invalid"): + with pytest.raises(ValueError, match="provided lr scheduler `CustomScheduler` is invalid"): _init_optimizers_and_lr_schedulers(model) From f63663bcd2f5f45ff40dcaea64e79065e430dc41 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 12 Jan 2022 22:11:17 +0100 Subject: [PATCH 02/19] Deprecate `trainer.lr_schedulers` --- pytorch_lightning/callbacks/lr_monitor.py | 31 ++++++++++--------- .../callbacks/stochastic_weight_avg.py | 15 ++++----- pytorch_lightning/core/lightning.py | 4 +-- pytorch_lightning/core/optimizer.py | 4 +-- .../loops/epoch/training_epoch_loop.py | 20 ++++++------ pytorch_lightning/strategies/horovod.py | 6 ++-- pytorch_lightning/strategies/sharded_spawn.py | 3 +- pytorch_lightning/strategies/strategy.py | 1 + .../connectors/checkpoint_connector.py | 4 +-- pytorch_lightning/trainer/trainer.py | 16 +++++++--- 
pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 6 ++-- tests/callbacks/test_lr_monitor.py | 10 +++--- tests/callbacks/test_stochastic_weight_avg.py | 6 ++-- tests/deprecated_api/test_remove_1-8.py | 6 ++++ tests/models/test_amp.py | 4 +-- tests/models/test_restore.py | 4 +-- tests/strategies/test_deepspeed_strategy.py | 10 +++--- tests/trainer/optimization/test_optimizers.py | 16 +++++----- tests/utilities/test_cli.py | 18 +++++------ 20 files changed, 99 insertions(+), 87 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index b63b495a70879..cbf17bd228694 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -29,6 +29,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig class LearningRateMonitor(Callback): @@ -111,8 +112,10 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No if self.log_momentum: def _check_no_key(key: str) -> bool: - if trainer.lr_schedulers: - return any(key not in config.scheduler.optimizer.defaults for config in trainer.lr_schedulers) + if trainer.lr_scheduler_configs: + return any( + key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs + ) return any(key not in optimizer.defaults for optimizer in trainer.optimizers) @@ -129,7 +132,7 @@ def _check_no_key(key: str) -> bool: sched_hparam_keys, optimizers_with_scheduler, optimizers_with_scheduler_types, - ) = self._find_names_from_schedulers(trainer.lr_schedulers) + ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs) names.extend(sched_hparam_keys) # Find names for leftover optimizers @@ -173,10 +176,10 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa scheduler_hparam_keys, optimizers_with_scheduler, optimizers_with_scheduler_types, - ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False) + ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs, add_lr_sch_names=False) self._remap_keys(scheduler_hparam_keys) - for name, config in zip(scheduler_hparam_keys, trainer.lr_schedulers): + for name, config in zip(scheduler_hparam_keys, trainer.lr_scheduler_configs): if interval in [config.interval, "any"]: opt = config.scheduler.optimizer current_stat = self._get_lr_momentum_stat(opt, name) @@ -261,22 +264,22 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: return {n for n in names if names.count(n) > 1} def _find_names_from_schedulers( - self, lr_schedulers: List, add_lr_sch_names: bool = True + self, lr_scheduler_configs: List[LRSchedulerConfig], add_lr_sch_names: bool = True ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] seen_optimizers: List[Optimizer] = [] seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) - for scheduler in lr_schedulers: - sch = scheduler["scheduler"] - if scheduler["name"] is not None: - name = scheduler["name"] + for config in lr_scheduler_configs: + sch = config.scheduler + if config.name is not None: + name = config.name else: name = "lr-" + 
sch.optimizer.__class__.__name__ updated_names = self._check_duplicates_and_update_name( - sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names + sch.optimizer, name, seen_optimizers, seen_optimizer_types, config, add_lr_sch_names ) names.append(updated_names) @@ -313,14 +316,14 @@ def _check_duplicates_and_update_name( name: str, seen_optimizers: List[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int], - scheduler: Dict[str, Any] = None, + lr_scheduler_config: LRSchedulerConfig, add_lr_sch_names: bool = True, ) -> List[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) - if scheduler is not None and scheduler["name"] is None: + if lr_scheduler_config is not None and lr_scheduler_config.name is None: seen_optimizer_types[optimizer_cls] += 1 - elif scheduler is None: + elif lr_scheduler_config is None: seen_optimizer_types[optimizer_cls] += 1 # Multiple param groups for the same optimizer diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 4796dcc5cc078..d2f19f83540e9 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -142,13 +142,10 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: self._average_model = deepcopy(pl_module) def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - optimizers = trainer.optimizers - lr_schedulers = trainer.lr_schedulers - - if len(optimizers) != 1: + if len(trainer.optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") - if len(lr_schedulers) > 1: + if len(trainer.lr_scheduler_configs) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") if isinstance(self._swa_epoch_start, float): @@ -185,17 +182,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 - if trainer.lr_schedulers: - scheduler_cfg = trainer.lr_schedulers[0] + if trainer.lr_scheduler_configs: + scheduler_cfg = trainer.lr_scheduler_configs[0] if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") rank_zero_info( f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" f" for `{self._swa_scheduler.__class__.__name__}`" ) - trainer.lr_schedulers[0] = default_scheduler_cfg + trainer.lr_scheduler_configs[0] = default_scheduler_cfg else: - trainer.lr_schedulers.append(default_scheduler_cfg) + trainer.lr_scheduler_configs.append(default_scheduler_cfg) self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 74cf9d5385cc3..6e4dd5eda5be6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -172,11 +172,11 @@ def lr_schedulers(self) -> Optional[Union[_LRScheduler, List[_LRScheduler]]]: A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no schedulers were returned in :meth:`configure_optimizers`. """ - if not self.trainer.lr_schedulers: + if not self.trainer.lr_scheduler_configs: return None # ignore other keys "interval", "frequency", etc. 
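# --- Editor's note: illustrative sketch, not part of the patch. --------------
# With the dataclass in place, `LightningModule.lr_schedulers()` reads the
# `.scheduler` attribute off each config instead of indexing a dict; interval,
# frequency, etc. stay on the config object. Minimal self-contained stand-in
# (`_SketchConfig` and `_lr_schedulers` are hypothetical names):
from dataclasses import dataclass
from typing import Any, List, Optional, Union


@dataclass
class _SketchConfig:
    scheduler: Any
    interval: str = "epoch"


def _lr_schedulers(configs: List[_SketchConfig]) -> Optional[Union[Any, List[Any]]]:
    if not configs:
        return None
    schedulers = [config.scheduler for config in configs]
    # a single scheduler is unwrapped, multiple schedulers come back as a list
    return schedulers[0] if len(schedulers) == 1 else schedulers


assert _lr_schedulers([]) is None
assert _lr_schedulers([_SketchConfig("sched_a")]) == "sched_a"
assert _lr_schedulers([_SketchConfig("a"), _SketchConfig("b", "step")]) == ["a", "b"]
# --- End editor's note. -------------------------------------------------------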
- lr_schedulers = [config.scheduler for config in self.trainer.lr_schedulers] + lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs] # single scheduler if len(lr_schedulers) == 1: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index c06b0d94cc5be..1669ad2cb2346 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -345,8 +345,8 @@ def _validate_scheduler_api(lr_schedulers: List[LRSchedulerConfig], model: "pl.L ) -def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[LRSchedulerConfig]) -> None: - for config in lr_schedulers: +def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: + for config in lr_scheduler_configs: for opt_idx, opt in enumerate(optimizers): if config.scheduler.optimizer is opt: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ce1999f817613..528f23b24932b 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -412,31 +412,31 @@ def _update_learning_rates( so they have to be updated separately. opt_indices: indices of the optimizers to update. """ - if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization: + if not self.trainer.lr_scheduler_configs or not self.trainer.lightning_module.automatic_optimization: return if opt_indices is None: opt_indices = [] - for lr_scheduler in self.trainer.lr_schedulers: - if lr_scheduler.opt_idx not in opt_indices: + for config in self.trainer.lr_scheduler_configs: + if config.opt_idx not in opt_indices: continue - if update_plateau_schedulers ^ lr_scheduler.reduce_on_plateau: + if update_plateau_schedulers ^ config.reduce_on_plateau: continue current_idx = self.batch_idx if interval == "step" else self.trainer.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero - if lr_scheduler.interval == interval and current_idx % lr_scheduler.frequency == 0: + if config.interval == interval and current_idx % config.frequency == 0: monitor_val = None - if lr_scheduler.reduce_on_plateau: + if config.reduce_on_plateau: # If instance of ReduceLROnPlateau, we need a monitor - monitor_key = lr_scheduler.monitor + monitor_key = config.monitor monitor_val = self._get_monitor_value(monitor_key) if monitor_val is None: - if lr_scheduler.strict: + if config.strict: avail_metrics = list(self.trainer.callback_metrics) raise MisconfigurationException( f"ReduceLROnPlateau conditioned on metric {monitor_key}" @@ -456,8 +456,8 @@ def _update_learning_rates( # update LR self.trainer._call_lightning_module_hook( "lr_scheduler_step", - lr_scheduler.scheduler, - lr_scheduler.opt_idx, + config.scheduler, + config.opt_idx, monitor_val, ) self.scheduler_progress.increment_completed() diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py index a1c34fa87b8d5..72f39e207b691 100644 --- a/pytorch_lightning/strategies/horovod.py +++ b/pytorch_lightning/strategies/horovod.py @@ -101,9 +101,9 @@ def _unpack_lightning_optimizer(opt): param_group["lr"] *= self.world_size # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR - lr_schedulers = self.lightning_module.trainer.lr_schedulers - 
for scheduler in lr_schedulers: - scheduler = scheduler["scheduler"] + lr_scheduler_configs = self.lightning_module.trainer.lr_scheduler_configs + for config in lr_scheduler_configs: + scheduler = config.scheduler scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] # Horovod: broadcast parameters & optimizer state to ensure consistent initialization diff --git a/pytorch_lightning/strategies/sharded_spawn.py b/pytorch_lightning/strategies/sharded_spawn.py index 6a6e1c3ade44e..2be4277fd8e81 100644 --- a/pytorch_lightning/strategies/sharded_spawn.py +++ b/pytorch_lightning/strategies/sharded_spawn.py @@ -39,11 +39,10 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): def configure_ddp(self) -> None: trainer = self.lightning_module.trainer - self.model, optimizers = self._setup_model_and_optimizers( + self.model, self.optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), optimizers=trainer.optimizers, ) - trainer.optimizers = optimizers def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Wraps the model and optimizers with fairscale components. diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 2d61f75c934eb..a45e1a2cfce80 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -52,6 +52,7 @@ def __init__( self.checkpoint_io = checkpoint_io self.precision_plugin = precision_plugin self.optimizers: List[Optimizer] = [] + # FIXME: rename to _config? self.lr_schedulers: List[LRSchedulerConfig] = [] self.optimizer_frequencies: List[int] = [] if is_overridden("post_dispatch", self, parent=Strategy): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5a6bc6b1796e7..c3b3f2988e847 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -305,7 +305,7 @@ def restore_lr_schedulers(self) -> None: # restore the lr schedulers lr_schedulers = self._loaded_checkpoint["lr_schedulers"] - for config, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): + for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): config.scheduler.load_state_dict(lrs_state) # ---------------------------------- @@ -368,7 +368,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump lr schedulers lr_schedulers = [] - for config in self.trainer.lr_schedulers: + for config in self.trainer.lr_scheduler_configs: lr_schedulers.append(config.scheduler.state_dict()) checkpoint["lr_schedulers"] = lr_schedulers diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 13c637f2bbbf6..32bfab681948a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -19,6 +19,7 @@ import warnings from argparse import ArgumentParser, Namespace from copy import deepcopy +from dataclasses import asdict from datetime import timedelta from pathlib import Path from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Type, Union @@ -2016,14 +2017,19 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: self.strategy.optimizers = new_optims - # FIXME: bc compat @property - def lr_schedulers(self) -> List[LRSchedulerConfig]: + def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: + # FIXME: 
should we have this property at all? return self.strategy.lr_schedulers - @lr_schedulers.setter - def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None: - self.strategy.lr_schedulers = new_schedulers + @property + def lr_schedulers(self) -> List[Dict[str, Any]]: + rank_zero_deprecation( + "`Trainer.lr_schedulers` is deprecated in v1.6 and will be removed in v1.8." + " You can use `trainer.lr_scheduler_configs` instead which contains dataclasses instead of dictionaries.", + stacklevel=5, + ) + return [asdict(config) for config in self.strategy.lr_schedulers] @property def optimizer_frequencies(self) -> List[int]: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 84467310568f7..d4ffe7bc97f1f 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -123,7 +123,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule trainer.logger = DummyLogger() if trainer.logger is not None else None trainer.callbacks = [] # not needed before full run trainer.limit_train_batches = 1.0 - trainer.optimizers, trainer.lr_schedulers = [], [] # required for saving + trainer.optimizers, trainer.strategy.lr_schedulers = [], [] # required for saving trainer.model = model # required for saving diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 58735d95960de..b53ab75e659ec 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -127,9 +127,9 @@ def func(trainer): scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) trainer.strategy.optimizers = [optimizer] - trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] + trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] trainer.strategy.optimizer_frequencies = [] - _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_schedulers) + _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs) return func @@ -335,7 +335,7 @@ def on_batch_start(self, trainer, pl_module): if self.progress_bar_refresh_rate and self.progress_bar is None: self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training) - self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0]) + self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """Called when the training batch ends, logs the calculated loss.""" diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index d35b1e8eefc38..82a4a5b99894a 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -40,7 +40,7 @@ def test_lr_monitor_single_lr(tmpdir): assert lr_monitor.lrs, "No learning rates logged" assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default" - assert len(lr_monitor.lrs) == len(trainer.lr_schedulers) + assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-SGD"] @@ -76,7 +76,7 @@ def configure_optimizers(self): trainer.fit(model) assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" - assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers) + assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) 
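# --- Editor's note: illustrative sketch, not part of the patch. --------------
# Patch 02 keeps `Trainer.lr_schedulers` as a deprecated, read-only view: it
# warns and returns plain dicts built with `dataclasses.asdict`, while the new
# `Trainer.lr_scheduler_configs` exposes the dataclasses directly. A minimal
# stand-in of that shim (`_SketchTrainer`/`_SketchConfig` are hypothetical, and
# the real property warns via `rank_zero_deprecation`, not `warnings.warn`):
import warnings
from dataclasses import asdict, dataclass
from typing import Any, Dict, List


@dataclass
class _SketchConfig:
    scheduler: Any
    interval: str = "epoch"
    frequency: int = 1


class _SketchTrainer:
    def __init__(self, configs: List[_SketchConfig]) -> None:
        self._configs = configs

    @property
    def lr_scheduler_configs(self) -> List[_SketchConfig]:
        return self._configs

    @property
    def lr_schedulers(self) -> List[Dict[str, Any]]:
        warnings.warn(
            "`Trainer.lr_schedulers` is deprecated in v1.6 and will be removed in v1.8."
            " Use `trainer.lr_scheduler_configs` instead.",
            DeprecationWarning,
        )
        return [asdict(config) for config in self._configs]


trainer = _SketchTrainer([_SketchConfig(scheduler="sched")])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = trainer.lr_schedulers
assert legacy[0]["interval"] == trainer.lr_scheduler_configs[0].interval
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# --- End editor's note. -------------------------------------------------------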
assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) @@ -103,7 +103,7 @@ def configure_optimizers(self): trainer.fit(model) assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" - assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers) + assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values) @@ -237,7 +237,7 @@ def configure_optimizers(self): trainer.fit(model) assert lr_monitor.lrs, "No learning rates logged" - assert len(lr_monitor.lrs) == len(trainer.lr_schedulers) + assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": @@ -316,7 +316,7 @@ def configure_optimizers(self): trainer.fit(model, datamodule=dm) assert lr_monitor.lrs, "No learning rates logged" - assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers) + assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly" diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index ace2d359647b4..a568737dbba8c 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -78,9 +78,9 @@ def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: - assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) - assert trainer.lr_schedulers[0]["interval"] == "epoch" - assert trainer.lr_schedulers[0]["frequency"] == 1 + assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR) + assert trainer.lr_scheduler_configs[0].interval == "epoch" + assert trainer.lr_scheduler_configs[0].frequency == 1 def on_train_epoch_end(self, trainer, *args): super().on_train_epoch_end(trainer, *args) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 52c57a1f70247..e7c45e0d0c9b4 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -154,6 +154,12 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir): _ = trainer.should_rank_save_checkpoint +def test_v1_8_0_deprecated_lr_scheduler(): + trainer = Trainer() + with pytest.deprecated_call(match=r"`Trainer.lr_schedulers` is deprecated in v1.6 and will be removed in v1.8."): + assert trainer.lr_schedulers == [] + + def test_v1_8_0_trainer_optimizers_mixin(): trainer = Trainer() model = BoringModel() diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index b1d0116eb165a..17135b98c16f5 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -201,5 +201,5 @@ def configure_optimizers(self): assert trainer.state.finished, f"Training failed with {trainer.state}" assert bwd_mock.call_count == 10 - assert isinstance(trainer.lr_schedulers[0]["scheduler"].optimizer, optim.Adam) - assert isinstance(trainer.lr_schedulers[1]["scheduler"].optimizer, optim.SGD) + assert isinstance(trainer.lr_scheduler_configs[0].scheduler.optimizer, optim.Adam) + assert isinstance(trainer.lr_scheduler_configs[1].scheduler.optimizer, optim.SGD) diff --git a/tests/models/test_restore.py 
b/tests/models/test_restore.py
index d6511d8db3cc2..dccdf9f601433 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -174,8 +174,8 @@ def _check_optimizers(self):

     def _check_schedulers(self):
         return all(
-            self._is_equal(self.trainer.lr_schedulers[i]["scheduler"].state_dict(), state_dict["lr_schedulers"][i])
-            for i in range(len(self.trainer.lr_schedulers))
+            self._is_equal(config.scheduler.state_dict(), state_dict["lr_schedulers"][i])
+            for i, config in enumerate(self.trainer.lr_scheduler_configs)
         )

     def _check_model_state_dict(self):
diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py
index 50f0c94405e7c..e21e1cf7de04f 100644
--- a/tests/strategies/test_deepspeed_strategy.py
+++ b/tests/strategies/test_deepspeed_strategy.py
@@ -276,9 +276,9 @@ def on_train_start(self, trainer, pl_module) -> None:
             assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
             assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
-            assert isinstance(trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR)
+            assert isinstance(trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR)
             # check that the lr_scheduler config was preserved
-            assert trainer.lr_schedulers[0]["name"] == "Sean"
+            assert trainer.lr_scheduler_configs[0].name == "Sean"

     class TestModel(BoringModel):
         def configure_optimizers(self):
@@ -314,7 +314,7 @@ def on_train_start(self, trainer, pl_module) -> None:
             assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
             assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
-            assert isinstance(trainer.lr_schedulers[0]["scheduler"], WarmupLR)
+            assert isinstance(trainer.lr_scheduler_configs[0].scheduler, WarmupLR)

     model = BoringModel()
     trainer = Trainer(
@@ -716,8 +716,8 @@ def on_train_batch_start(
             assert trainer.current_epoch == 1

     # assert lr-scheduler states are loaded correctly
-    original_lr_scheduler = initial_trainer.lr_schedulers[0]["scheduler"]
-    current_lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
+    original_lr_scheduler = initial_trainer.lr_scheduler_configs[0].scheduler
+    current_lr_scheduler = trainer.lr_scheduler_configs[0].scheduler
     assert original_lr_scheduler.state_dict() == current_lr_scheduler.state_dict()

     model = ModelParallelClassificationModel()
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 2228caad24201..1fc5fc1dc593f 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -44,7 +44,7 @@ def test_optimizer_with_scheduling(tmpdir):
     init_lr = 0.1
     adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups]

-    assert len(trainer.lr_schedulers) == 1
+    assert len(trainer.lr_scheduler_configs) == 1
     assert all(a == adjusted_lr[0] for a in adjusted_lr)
     assert init_lr * 0.1 == adjusted_lr[0]
@@ -74,7 +74,7 @@ def configure_optimizers(self):
     adjusted_lr1 = [pg["lr"] for pg in trainer.optimizers[0].param_groups]
     adjusted_lr2 = [pg["lr"] for pg in trainer.optimizers[1].param_groups]

-    assert len(trainer.lr_schedulers) == 2
+    assert len(trainer.lr_scheduler_configs) == 2
     assert all(a == adjusted_lr1[0] for a in adjusted_lr1)
     assert all(a == adjusted_lr2[0] for a in adjusted_lr2)
     assert model.init_lr * 0.1 == adjusted_lr1[0]
@@ -134,7 +134,7 @@ def configure_optimizers(self):
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"

-    lr_scheduler = trainer.lr_schedulers[0]
+    lr_scheduler = trainer.lr_scheduler_configs[0]
     assert lr_scheduler == LRSchedulerConfig(
         scheduler=lr_scheduler.scheduler,
         monitor="foo",
@@ -310,11 +310,11 @@ def configure_optimizers(self):
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"

-    assert trainer.lr_schedulers[0].opt_idx == 0
-    assert trainer.lr_schedulers[1].opt_idx == 1
+    assert trainer.lr_scheduler_configs[0].opt_idx == 0
+    assert trainer.lr_scheduler_configs[1].opt_idx == 1
     # Step count is 1 greater than the expected value because scheduler.step() is called once during initialization
-    assert trainer.lr_schedulers[0].scheduler._step_count == expected_steps[0]
-    assert trainer.lr_schedulers[1].scheduler._step_count == expected_steps[1]
+    assert trainer.lr_scheduler_configs[0].scheduler._step_count == expected_steps[0]
+    assert trainer.lr_scheduler_configs[1].scheduler._step_count == expected_steps[1]

 @pytest.mark.parametrize("fn", ("validate", "test", "predict"))
@@ -333,7 +333,7 @@ def configure_optimizers(self):
     train_fn = getattr(trainer, fn)
     train_fn(TestModel(), datamodule=BoringDataModule(), ckpt_path=None)

-    assert len(trainer.lr_schedulers) == 0
+    assert len(trainer.lr_scheduler_configs) == 0
     assert len(trainer.optimizers) == 0
     assert len(trainer.optimizer_frequencies) == 0
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index 849d75416176b..1281d0b74c7ea 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -641,7 +641,7 @@ def add_arguments_to_parser(self, parser):
     else:
         assert len(cli.trainer.optimizers) == 1
         assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
-        assert len(cli.trainer.lr_schedulers) == 0
+        assert len(cli.trainer.lr_scheduler_configs) == 0

 def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir):
@@ -658,9 +658,9 @@ def add_arguments_to_parser(self, parser):
     assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
     assert len(cli.trainer.optimizers) == 1
     assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
-    assert len(cli.trainer.lr_schedulers) == 1
-    assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR)
-    assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8
+    assert len(cli.trainer.lr_scheduler_configs) == 1
+    assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.ExponentialLR)
+    assert cli.trainer.lr_scheduler_configs[0].scheduler.gamma == 0.8

 def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
@@ -684,9 +684,9 @@ def add_arguments_to_parser(self, parser):
     assert len(cli.trainer.optimizers) == 1
     assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
-    assert len(cli.trainer.lr_schedulers) == 1
-    assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR)
-    assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50
+    assert len(cli.trainer.lr_scheduler_configs) == 1
+    assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR)
+    assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50

 @pytest.mark.parametrize("use_registries", [False, True])
@@ -1387,8 +1387,8 @@ def test_cli_reducelronplateau():
     ):
         cli = LightningCLI(BoringModel, run=False)

     config = cli.model.configure_optimizers()
-    assert isinstance(config["lr_scheduler"]["scheduler"], ReduceLROnPlateau)
-    assert config["lr_scheduler"]["scheduler"].monitor == "foo"
+    assert isinstance(config["lr_scheduler"].scheduler, ReduceLROnPlateau)
+    assert config["lr_scheduler"].scheduler.monitor == "foo"

 def test_cli_configureoptimizers_can_be_overridden():

From d7643ed6de629d8fb3a4c4121c1b29b356aeecdd Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:16:12 +0100
Subject: [PATCH 03/19] Minor fixes

---
 pytorch_lightning/core/optimizer.py  | 2 +-
 pytorch_lightning/trainer/trainer.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 1669ad2cb2346..ce81ba8e8c4dc 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -332,7 +332,7 @@ def _validate_scheduler_api(lr_schedulers: List[LRSchedulerConfig], model: "pl.L
     for config in lr_schedulers:
         scheduler = config.scheduler
         if not isinstance(scheduler, _SupportsStateDict):
-            raise ValueError(
+            raise TypeError(
                 f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                 " It should have `state_dict` and `load_state_dict` methods defined."
             )
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 32bfab681948a..d7bb41a3da1a0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -2019,7 +2019,6 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:

     @property
     def lr_scheduler_configs(self) -> List[LRSchedulerConfig]:
-        # FIXME: should we have this property at all?
         return self.strategy.lr_schedulers

     @property

From 95f3e639ea50f3a0a8c486f3c7d70062905057d2 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:22:50 +0100
Subject: [PATCH 04/19] Minor fixes

---
 pytorch_lightning/callbacks/lr_monitor.py     | 2 +-
 pytorch_lightning/core/optimizer.py           | 5 ++++-
 tests/trainer/optimization/test_optimizers.py | 2 +-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index cbf17bd228694..4e529a7175f36 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -316,7 +316,7 @@ def _check_duplicates_and_update_name(
         name: str,
         seen_optimizers: List[Optimizer],
         seen_optimizer_types: DefaultDict[Type[Optimizer], int],
-        lr_scheduler_config: LRSchedulerConfig,
+        lr_scheduler_config: Optional[LRSchedulerConfig],
         add_lr_sch_names: bool = True,
     ) -> List[str]:
         seen_optimizers.append(optimizer)
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index ce81ba8e8c4dc..f92ac09cd5120 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -267,7 +267,10 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
             # check provided keys
             extra_keys = scheduler.keys() - {field.name for field in fields(LRSchedulerConfig)}
             if extra_keys:
-                raise MisconfigurationException(f"Found unsupported keys in the lr scheduler dict: {extra_keys}")
+                raise MisconfigurationException(
+                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}. HINT: remove them from the output"
+                    " of `configure_optimizers`."
+                )
             if "scheduler" not in scheduler:
                 raise MisconfigurationException(
                     'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 1fc5fc1dc593f..482cdb383ce21 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -762,7 +762,7 @@ def configure_optimizers(self):
     model = CustomBoringModel()
     model.trainer = Trainer()

-    with pytest.raises(ValueError, match="provided lr scheduler `CustomScheduler` is invalid"):
+    with pytest.raises(TypeError, match="provided lr scheduler `CustomScheduler` is invalid"):
         _init_optimizers_and_lr_schedulers(model)

From dacf417696d0803fd9f285da24dd6e6c545beb6e Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:35:33 +0100
Subject: [PATCH 05/19] Fix mypy

---
 pytorch_lightning/core/optimizer.py  | 9 +++++++--
 pytorch_lightning/utilities/types.py | 6 ++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index f92ac09cd5120..61c96eaafc9ae 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -25,7 +25,12 @@
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerConfig, LRSchedulerTypeTuple
+from pytorch_lightning.utilities.types import (
+    _SupportsStateDict,
+    LRSchedulerConfig,
+    LRSchedulerTypeTuple,
+    ReduceLROnPlateau,
+)

 def do_nothing_closure() -> None:
@@ -295,7 +300,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
                     category=RuntimeWarning,
                 )
                 scheduler = LRSchedulerConfig(**scheduler)
-            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+            elif isinstance(scheduler, ReduceLROnPlateau):
                 if monitor is None:
                     raise MisconfigurationException(
                         "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index 6e0ba9b126bc4..c239dcc62425b 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -60,7 +60,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
-class _LRScheduler(_SupportsStateDict):
+@runtime_checkable
+class _LRScheduler(_SupportsStateDict, Protocol):
     optimizer: Optimizer

     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
@@ -69,7 +70,8 @@ def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:

 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
-class ReduceLROnPlateau(_SupportsStateDict):
+@runtime_checkable
+class ReduceLROnPlateau(_SupportsStateDict, Protocol):
     in_cooldown: bool
     optimizer: Optimizer

From dd17196c3db4fbc478370188e7127369eb1574dc Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:42:21 +0100
Subject: [PATCH 06/19] Update tests

---
 tests/models/test_restore.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index dccdf9f601433..ab1c05acb8fb0 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -168,24 +168,26 @@ def _is_equal(self, a, b):

     def _check_optimizers(self):
         return all(
-            self._is_equal(self.trainer.optimizers[i].state_dict(), state_dict["optimizer_states"][i])
-            for i in range(len(self.trainer.optimizers))
+            self._is_equal(optimizer.state_dict(), state)
+            for optimizer, state in zip(self.trainer.optimizers, state_dict["optimizer_states"])
         )

     def _check_schedulers(self):
         return all(
-            self._is_equal(config.scheduler.state_dict(), state_dict["lr_schedulers"][i])
-            for i, config in enumerate(self.trainer.lr_scheduler_configs)
+            self._is_equal(config.scheduler.state_dict(), state)
+            for config, state in zip(self.trainer.lr_scheduler_configs, state_dict["lr_schedulers"])
         )

     def _check_model_state_dict(self):
-        for k in self.state_dict():
-            yield self._is_equal(self.state_dict()[k], state_dict["state_dict"][k])
+        return all(
+            self._is_equal(actual, expected)
+            for actual, expected in zip(self.state_dict(), state_dict["state_dict"])
+        )

     def _test_on_val_test_predict_tune_start(self):
         assert self.trainer.current_epoch == state_dict["epoch"]
         assert self.trainer.global_step == state_dict["global_step"]
-        assert all(self._check_model_state_dict())
+        assert self._check_model_state_dict()

         # no optimizes and schedulers are loaded otherwise
         if self.trainer.state.fn != TrainerFn.TUNING:
@@ -200,7 +202,7 @@ def on_train_start(self):
         else:
             assert self.trainer.current_epoch == state_dict["epoch"]
             assert self.trainer.global_step == state_dict["global_step"]
-            assert all(self._check_model_state_dict())
+            assert self._check_model_state_dict()
             assert self._check_optimizers()
             assert self._check_schedulers()

From 494af7ca5169b87e7554e05b9831cb676018a63a Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:45:25 +0100
Subject: [PATCH 07/19] Minor improvements

---
 pytorch_lightning/trainer/trainer.py | 3 ++-
 pytorch_lightning/tuner/lr_finder.py | 3 +--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d7bb41a3da1a0..64d6843a35822 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -19,7 +19,6 @@
 import warnings
 from argparse import ArgumentParser, Namespace
 from copy import deepcopy
-from dataclasses import asdict
 from datetime import timedelta
 from pathlib import Path
 from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Type, Union
@@ -2028,6 +2027,8 @@ def lr_schedulers(self) -> List[Dict[str, Any]]:
             " You can use `trainer.lr_scheduler_configs` instead which contains dataclasses instead of dictionaries.",
             stacklevel=5,
         )
+        from dataclasses import asdict
+
         return [asdict(config) for config in self.strategy.lr_schedulers]

     @property
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index b53ab75e659ec..8aa8c0ebd84a6 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -30,11 +30,10 @@
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities.types import LRSchedulerConfig

 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
-from pytorch_lightning.utilities.types import LRSchedulerConfig
-
 if importlib.util.find_spec("ipywidgets") is not None:
     from tqdm.auto import tqdm
 else:

From b939e8e35c878b36215432f0ebc5f57512f9d204 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:53:33 +0100
Subject: [PATCH 08/19] Undo, CHANGELOG

---
 CHANGELOG.md                                  |  5 ++++-
 pytorch_lightning/core/optimizer.py           | 11 +++++++----
 tests/trainer/optimization/test_optimizers.py |  2 +-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5c8d607648034..c5a7d0541cc40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -245,6 +245,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#11000](https://github.com/PyTorchLightning/pytorch-lightning/pull/11000))

+- Deprecated `Trainer.lr_schedulers` in favor of `Trainer.lr_scheduler_configs` which returns a list of dataclasses instead of dictionaries ([#11443](https://github.com/PyTorchLightning/pytorch-lightning/pull/11443))
+
+
 - Deprecated `Trainer.verbose_evaluate` in favor of `EvaluationLoop(verbose=...)` ([#10931](https://github.com/PyTorchLightning/pytorch-lightning/pull/10931))
@@ -425,7 +428,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
-- Fixed the default value used with `log(reduce_fx=min|max)` ([#11310](https://github.com/PyTorchLightning/pytorch-lightning/pull/11310))
+- Fixed the _configure_schedulers_automatic_optault value used with `log(reduce_fx=min|max)` ([#11310](https://github.com/PyTorchLightning/pytorch-lightning/pull/11310))
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
 - Fixed dataloaders not getting reloaded the correct amount of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 61c96eaafc9ae..7b668973efb42 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -270,12 +270,15 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
     for scheduler in schedulers:
         if isinstance(scheduler, dict):
             # check provided keys
-            extra_keys = scheduler.keys() - {field.name for field in fields(LRSchedulerConfig)}
+            supported_keys = {field.name for field in fields(LRSchedulerConfig)}
+            extra_keys = scheduler.keys() - supported_keys
             if extra_keys:
-                raise MisconfigurationException(
-                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}. HINT: remove them from the output"
-                    " of `configure_optimizers`."
+                rank_zero_warn(
+                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}. "
+                    " HINT: remove them from the output of `configure_optimizers`.",
+                    category=RuntimeWarning,
                 )
+                scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
             if "scheduler" not in scheduler:
                 raise MisconfigurationException(
                     'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 482cdb383ce21..74abfda11b241 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -484,7 +484,7 @@ def test_lr_scheduler_with_extra_keys_warns(tmpdir):
         "lr_scheduler": {"scheduler": optim.lr_scheduler.StepLR(optimizer, 1), "foo": 1, "bar": 2},
     }
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    with pytest.raises(MisconfigurationException, match=r"Found unsupported keys in the lr scheduler dict: \{.+\}"):
+    with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the lr scheduler dict: \{.+\}"):
         trainer.fit(model)

From 347e1a875709c9fdb7b94556e1875ed6cdc7671b Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 12 Jan 2022 22:54:30 +0100
Subject: [PATCH 09/19] lol

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c5a7d0541cc40..cf6c7c0c2b1b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -428,7 +428,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
-- Fixed the _configure_schedulers_automatic_optault value used with `log(reduce_fx=min|max)` ([#11310](https://github.com/PyTorchLightning/pytorch-lightning/pull/11310))
+- Fixed the default value used with `log(reduce_fx=min|max)` ([#11310](https://github.com/PyTorchLightning/pytorch-lightning/pull/11310))
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
 - Fixed dataloaders not getting reloaded the correct amount of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))

From a34a47c3d47aa8933060cd6fb64528d3aa24e3fd Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 13 Jan 2022 16:11:35 +0100
Subject: [PATCH 10/19] Fix error

---
 pytorch_lightning/strategies/strategy.py | 2 +-
 pytorch_lightning/tuner/lr_finder.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py
index a45e1a2cfce80..34fe3a1ad50db 100644
--- a/pytorch_lightning/strategies/strategy.py
+++ b/pytorch_lightning/strategies/strategy.py
@@ -52,7 +52,7 @@ def __init__(
         self.checkpoint_io = checkpoint_io
         self.precision_plugin = precision_plugin
         self.optimizers: List[Optimizer] = []
-        # FIXME: rename to _config?
+        # TODO: rename to `lr_scheduler_configs` to match the property in the `Trainer`
         self.lr_schedulers: List[LRSchedulerConfig] = []
         self.optimizer_frequencies: List[int] = []
         if is_overridden("post_dispatch", self, parent=Strategy):
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 8aa8c0ebd84a6..9be7799a6e5a7 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -126,7 +126,7 @@ def func(trainer):
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)

         trainer.strategy.optimizers = [optimizer]
-        trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
+        trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
         trainer.strategy.optimizer_frequencies = []
         _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
         return func

From 2680412b2d943430201c100b4d4d4ab4bdde2248 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 14 Jan 2022 00:19:48 +0100
Subject: [PATCH 11/19] Undo unrelated change

---
 pytorch_lightning/strategies/sharded_spawn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/strategies/sharded_spawn.py b/pytorch_lightning/strategies/sharded_spawn.py
index 2be4277fd8e81..6a6e1c3ade44e 100644
--- a/pytorch_lightning/strategies/sharded_spawn.py
+++ b/pytorch_lightning/strategies/sharded_spawn.py
@@ -39,10 +39,11 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
-        self.model, self.optimizers = self._setup_model_and_optimizers(
+        self.model, optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
+        trainer.optimizers = optimizers

     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
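PATCH 08 above changes what happens when `configure_optimizers` returns an lr scheduler dict carrying keys that are not `LRSchedulerConfig` fields: instead of raising a `MisconfigurationException`, the extra keys are now dropped and a `RuntimeWarning` is emitted. The following is a minimal sketch of the kind of user code affected, assuming an installation with these patches applied; the `ExtraKeyModel` class, its layer sizes, and the training step are illustrative only and are not part of the series.

import torch
from pytorch_lightning import LightningModule


class ExtraKeyModel(LightningModule):
    """Illustrative module whose scheduler dict carries an unsupported key."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return {
            "optimizer": optimizer,
            # "scheduler" and "interval" are `LRSchedulerConfig` fields; "foo" is not,
            # so after PATCH 08 it is dropped with a RuntimeWarning instead of raising.
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "foo": 1},
        }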
From 5498cca79621dbf8d062f210e2e880d5307db937 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 14 Jan 2022 00:24:32 +0100
Subject: [PATCH 12/19] Use direct reference

---
 pytorch_lightning/strategies/horovod.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index 72f39e207b691..19fa1ca3d2b8e 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -101,7 +101,7 @@ def _unpack_lightning_optimizer(opt):
             param_group["lr"] *= self.world_size

         # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
-        lr_scheduler_configs = self.lightning_module.trainer.lr_scheduler_configs
+        lr_scheduler_configs = self.lr_schedulers
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

From 43f103978be7c8beb41772223d485d1274e3012b Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Sat, 15 Jan 2022 03:03:55 +0100
Subject: [PATCH 13/19] Fix test

---
 tests/utilities/test_cli.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index 1281d0b74c7ea..2f501a699d0c2 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -1387,8 +1387,8 @@ def test_cli_reducelronplateau():
     ):
         cli = LightningCLI(BoringModel, run=False)

     config = cli.model.configure_optimizers()
-    assert isinstance(config["lr_scheduler"].scheduler, ReduceLROnPlateau)
-    assert config["lr_scheduler"].scheduler.monitor == "foo"
+    assert isinstance(config["lr_scheduler"]["scheduler"], ReduceLROnPlateau)
+    assert config["lr_scheduler"]["scheduler"].monitor == "foo"

 def test_cli_configureoptimizers_can_be_overridden():

From d0d365bad0d3b2fa240300517732e3b226f5c1c7 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 18 Jan 2022 16:26:40 +0100
Subject: [PATCH 14/19] Bad merge

---
 pytorch_lightning/trainer/trainer.py | 1 +
 pytorch_lightning/tuner/lr_finder.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9fe9ebcef6876..4e19f16b29c6d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -2006,6 +2006,7 @@ def optimizers(self) -> List[Optimizer]:
     def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
         self.strategy.optimizers = new_optims

+    @property
     def lightning_optimizers(self) -> Dict[int, LightningOptimizer]:
         rank_zero_deprecation(
             "`Trainer.lightning_optimizers` is deprecated in v1.6 and will be removed in v1.8", stacklevel=5
         )
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 9be7799a6e5a7..bc20885c27b62 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -125,9 +125,9 @@ def func(trainer):
         args = (optimizer, self.lr_max, self.num_training)
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)

-        trainer.strategy.optimizers = [optimizer]
+        trainer.optimizers = [optimizer]
         trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
-        trainer.strategy.optimizer_frequencies = []
+        trainer.optimizer_frequencies = []
         _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
         return func

From d06b46944029286096c27debec16e31f1259caeb Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 18 Jan 2022 16:32:27 +0100
Subject: [PATCH 15/19] Bad merge

---
 pytorch_lightning/tuner/lr_finder.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index bc20885c27b62..f61b0e4702bf9 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -125,9 +125,9 @@ def func(trainer):
         args = (optimizer, self.lr_max, self.num_training)
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)

-        trainer.optimizers = [optimizer]
+        trainer.stategy.optimizers = [optimizer]
         trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
-        trainer.optimizer_frequencies = []
+        trainer.strategy.optimizer_frequencies = []
         _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
         return func
@@ -227,7 +227,7 @@ def lr_find(
         trainer.progress_bar_callback.disable()

     # Required for saving the model
-    trainer.strategy.optimizers, trainer.strategy.lr_schedulers = [], []
+    trainer.optimizers, trainer.strategy.lr_schedulers = [], []
     trainer.model = model

     # Dump model checkpoint

From c779ac9a8ff20507e8e35e6c5076eb51c73cc87e Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 18 Jan 2022 16:40:02 +0100
Subject: [PATCH 16/19] Apply suggestions from code review

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/core/optimizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index ff4c0a33e0be9..3f2aef982ce78 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -268,7 +268,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
             extra_keys = scheduler.keys() - supported_keys
             if extra_keys:
                 rank_zero_warn(
-                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}. "
+                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
                     " HINT: remove them from the output of `configure_optimizers`.",
                     category=RuntimeWarning,
                 )

From f1c3cb88eea6d571689bdc56c0fde2b9d99acd27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 18 Jan 2022 16:46:19 +0100
Subject: [PATCH 17/19] Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/core/lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index b512960ec9cfb..589408913d07f 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -164,7 +164,7 @@ def optimizers(
         # multiple opts
         return opts

-    def lr_schedulers(self) -> Optional[Union[_LRScheduler, List[_LRScheduler]]]:
+    def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
         """Returns the learning rate scheduler(s) that are being used during training. Useful for manual
         optimization.
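The `lr_schedulers()` hook retyped in PATCH 17 is what manual-optimization users call to step their schedulers themselves; per the annotation above, a single scheduler is returned directly and multiple schedulers come back as a list. Below is a short sketch of that usage, assuming these patches are applied; the model, optimizer choice, and per-step scheduler stepping are illustrative only.

import torch
from pytorch_lightning import LightningModule


class ManualOptimModel(LightningModule):
    """Illustrative manual-optimization module that steps its own scheduler."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enable manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()  # single scheduler configured -> returned directly
        loss = self(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        sch.step()
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]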
From 9f8db5c0ca2080b198c6ea77d5cb102e55d689eb Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 18 Jan 2022 17:17:18 +0100
Subject: [PATCH 18/19] mypy

---
 pytorch_lightning/core/lightning.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 589408913d07f..05cc8d87eaac6 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -53,13 +53,7 @@
 from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
 from pytorch_lightning.utilities.parsing import collect_init_args
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import (
-    _LRScheduler,
-    _METRIC_COLLECTION,
-    EPOCH_OUTPUT,
-    LRSchedulerTypeUnion,
-    STEP_OUTPUT,
-)
+from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

 warning_cache = WarningCache()

From 55059c613eb7065e87051d1fb4e1b3080c7b37c0 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 18 Jan 2022 17:18:21 +0100
Subject: [PATCH 19/19] Typo

---
 pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index f61b0e4702bf9..a15e65ef986b3 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -125,7 +125,7 @@ def func(trainer):
         args = (optimizer, self.lr_max, self.num_training)
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)

-        trainer.stategy.optimizers = [optimizer]
+        trainer.strategy.optimizers = [optimizer]
         trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
         trainer.strategy.optimizer_frequencies = []
         _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
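Taken together, the series replaces the scheduler-config dictionaries stored on the trainer with `LRSchedulerConfig` dataclasses exposed through `trainer.lr_scheduler_configs`, while the dict-returning `Trainer.lr_schedulers` property is deprecated. A hedged sketch of how downstream code migrates follows; only the two access patterns come from the diffs above, and the callback name and print statement are illustrative.

import pytorch_lightning as pl


class LogSchedulerInterval(pl.Callback):
    """Illustrative callback comparing the old and new scheduler access."""

    def on_train_start(self, trainer, pl_module):
        # old (now deprecated): dict-based access
        # config = trainer.lr_schedulers[0]
        # scheduler, interval = config["scheduler"], config["interval"]

        # new: dataclass-based access via the configs property
        config = trainer.lr_scheduler_configs[0]
        scheduler, interval = config.scheduler, config.interval
        print(f"{scheduler.__class__.__name__} steps once per {interval}")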