diff --git a/CHANGELOG.md b/CHANGELOG.md index 95dee9bd163c1..34bce9d691bda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -254,6 +254,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#11000](https://github.com/PyTorchLightning/pytorch-lightning/pull/11000)) +- Deprecated `Trainer.lr_schedulers` in favor of `Trainer.lr_scheduler_configs` which returns a list of dataclasses instead of dictionaries ([#11443](https://github.com/PyTorchLightning/pytorch-lightning/pull/11443)) + + - Deprecated `Trainer.verbose_evaluate` in favor of `EvaluationLoop(verbose=...)` ([#10931](https://github.com/PyTorchLightning/pytorch-lightning/pull/10931)) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 9d4e70fb880ce..4e529a7175f36 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -29,6 +29,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig class LearningRateMonitor(Callback): @@ -111,8 +112,10 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No if self.log_momentum: def _check_no_key(key: str) -> bool: - if trainer.lr_schedulers: - return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers) + if trainer.lr_scheduler_configs: + return any( + key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs + ) return any(key not in optimizer.defaults for optimizer in trainer.optimizers) @@ -129,7 +132,7 @@ def _check_no_key(key: str) -> bool: sched_hparam_keys, optimizers_with_scheduler, optimizers_with_scheduler_types, - ) = self._find_names_from_schedulers(trainer.lr_schedulers) + ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs) names.extend(sched_hparam_keys) # Find names for leftover optimizers @@ -173,12 +176,12 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa scheduler_hparam_keys, optimizers_with_scheduler, optimizers_with_scheduler_types, - ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False) + ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs, add_lr_sch_names=False) self._remap_keys(scheduler_hparam_keys) - for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers): - if interval in [scheduler["interval"], "any"]: - opt = scheduler["scheduler"].optimizer + for name, config in zip(scheduler_hparam_keys, trainer.lr_scheduler_configs): + if interval in [config.interval, "any"]: + opt = config.scheduler.optimizer current_stat = self._get_lr_momentum_stat(opt, name) latest_stat.update(current_stat) @@ -261,22 +264,22 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: return {n for n in names if names.count(n) > 1} def _find_names_from_schedulers( - self, lr_schedulers: List, add_lr_sch_names: bool = True + self, lr_scheduler_configs: List[LRSchedulerConfig], add_lr_sch_names: bool = True ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] seen_optimizers: List[Optimizer] = 
[] seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) - for scheduler in lr_schedulers: - sch = scheduler["scheduler"] - if scheduler["name"] is not None: - name = scheduler["name"] + for config in lr_scheduler_configs: + sch = config.scheduler + if config.name is not None: + name = config.name else: name = "lr-" + sch.optimizer.__class__.__name__ updated_names = self._check_duplicates_and_update_name( - sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names + sch.optimizer, name, seen_optimizers, seen_optimizer_types, config, add_lr_sch_names ) names.append(updated_names) @@ -313,14 +316,14 @@ def _check_duplicates_and_update_name( name: str, seen_optimizers: List[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int], - scheduler: Dict[str, Any] = None, + lr_scheduler_config: Optional[LRSchedulerConfig], add_lr_sch_names: bool = True, ) -> List[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) - if scheduler is not None and scheduler["name"] is None: + if lr_scheduler_config is not None and lr_scheduler_config.name is None: seen_optimizer_types[optimizer_cls] += 1 - elif scheduler is None: + elif lr_scheduler_config is None: seen_optimizer_types[optimizer_cls] += 1 # Multiple param groups for the same optimizer diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 504d477d3ff63..d2f19f83540e9 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,9 +24,9 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.optimizer import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -142,13 +142,10 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: self._average_model = deepcopy(pl_module) def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - optimizers = trainer.optimizers - lr_schedulers = trainer.lr_schedulers - - if len(optimizers) != 1: + if len(trainer.optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") - if len(lr_schedulers) > 1: + if len(trainer.lr_scheduler_configs) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") if isinstance(self._swa_epoch_start, float): @@ -182,21 +179,20 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, ) - default_scheduler_cfg = _get_default_scheduler_config() - assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1 - default_scheduler_cfg["scheduler"] = self._swa_scheduler + default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler) + assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 - if trainer.lr_schedulers: - scheduler_cfg = trainer.lr_schedulers[0] - if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1: + if trainer.lr_scheduler_configs: + scheduler_cfg = 
trainer.lr_scheduler_configs[0] + if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}") rank_zero_info( - f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`" + f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`" f" for `{self._swa_scheduler.__class__.__name__}`" ) - trainer.lr_schedulers[0] = default_scheduler_cfg + trainer.lr_scheduler_configs[0] = default_scheduler_cfg else: - trainer.lr_schedulers.append(default_scheduler_cfg) + trainer.lr_scheduler_configs.append(default_scheduler_cfg) self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 99e51ec76d4c3..05cc8d87eaac6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -158,7 +158,7 @@ def optimizers( # multiple opts return opts - def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: + def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. @@ -166,11 +166,11 @@ def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no schedulers were returned in :meth:`configure_optimizers`. """ - if not self.trainer.lr_schedulers: + if not self.trainer.lr_scheduler_configs: return None # ignore other keys "interval", "frequency", etc. - lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers] + lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs] # single scheduler if len(lr_schedulers) == 1: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 3dd7000acf311..3f2aef982ce78 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager +from dataclasses import fields from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union from weakref import proxy @@ -23,7 +24,12 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerTypeTuple +from pytorch_lightning.utilities.types import ( + _SupportsStateDict, + LRSchedulerConfig, + LRSchedulerTypeTuple, + ReduceLROnPlateau, +) def do_nothing_closure() -> None: @@ -167,7 +173,7 @@ def closure_dis(): def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]: +) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) @@ -178,10 +184,11 @@ def _init_optimizers_and_lr_schedulers( optim_conf = _MockOptimizer() optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf) - _configure_schedulers = ( - _configure_schedulers_automatic_opt if model.automatic_optimization else _configure_schedulers_manual_opt + lr_schedulers = ( + _configure_schedulers_automatic_opt(lr_schedulers, monitor) + if model.automatic_optimization + else _configure_schedulers_manual_opt(lr_schedulers) ) - lr_schedulers = _configure_schedulers(lr_schedulers, monitor) _set_scheduler_opt_idx(optimizers, lr_schedulers) _validate_scheduler_api(lr_schedulers, model) return optimizers, lr_schedulers, optimizer_frequencies @@ -251,18 +258,21 @@ def _configure_optimizers( return optimizers, lr_schedulers, optimizer_frequencies, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" lr_schedulers = [] - default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): # check provided keys - extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] + supported_keys = {field.name for field in fields(LRSchedulerConfig)} + extra_keys = scheduler.keys() - supported_keys if extra_keys: rank_zero_warn( - f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning + f"Found unsupported keys in the lr scheduler dict: {extra_keys}." 
+ " HINT: remove them from the output of `configure_optimizers`.", + category=RuntimeWarning, ) + scheduler = {k: v for k, v in scheduler.items() if k in supported_keys} if "scheduler" not in scheduler: raise MisconfigurationException( 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' @@ -286,27 +296,24 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] " Are you sure you didn't mean 'interval': 'step'?", category=RuntimeWarning, ) - lr_schedulers.append({**default_config, **scheduler}) - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): + scheduler = LRSchedulerConfig(**scheduler) + elif isinstance(scheduler, ReduceLROnPlateau): if monitor is None: raise MisconfigurationException( "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`" " scheduler is used. For example:" ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' ) - lr_schedulers.append( - {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} - ) + scheduler = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor) else: - lr_schedulers.append({**default_config, "scheduler": scheduler}) - + scheduler = LRSchedulerConfig(scheduler) + lr_schedulers.append(scheduler) return lr_schedulers -def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: +def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" lr_schedulers = [] - default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"} @@ -319,17 +326,16 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - category=RuntimeWarning, ) - scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys} - lr_schedulers.append({**default_config, **scheduler}) + scheduler = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys}) else: - lr_schedulers.append({**default_config, "scheduler": scheduler}) - + scheduler = LRSchedulerConfig(scheduler) + lr_schedulers.append(scheduler) return lr_schedulers -def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None: - for scheduler_config in lr_schedulers: - scheduler = scheduler_config["scheduler"] +def _validate_scheduler_api(lr_schedulers: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: + for config in lr_schedulers: + scheduler = config.scheduler if not isinstance(scheduler, _SupportsStateDict): raise TypeError( f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid." 
@@ -344,31 +350,18 @@ def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.Ligh ) -def _get_default_scheduler_config() -> Dict[str, Any]: - return { - "scheduler": None, - "name": None, # no custom name - "interval": "epoch", # after epoch is over - "frequency": 1, # every epoch/batch - "reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler - "monitor": None, # value to monitor for ReduceLROnPlateau - "strict": True, # enforce that the monitor exists for ReduceLROnPlateau - "opt_idx": None, # opt_idx assigned internally if not assigned by user - } - - -def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None: - for sch in lr_schedulers: +def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: + for config in lr_scheduler_configs: for opt_idx, opt in enumerate(optimizers): - if sch["scheduler"].optimizer is opt: - if sch["opt_idx"] is not None and sch["opt_idx"] != opt_idx: + if config.scheduler.optimizer is opt: + if config.opt_idx is not None and config.opt_idx != opt_idx: raise MisconfigurationException( "`opt_idx` set inside scheduler config does not match with the index" " of the respective optimizer returned from `configure_optimizers`." ) - sch["opt_idx"] = opt_idx + config.opt_idx = opt_idx break else: raise MisconfigurationException( diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index b806b5ff5da3a..b23608e0efd8a 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -410,31 +410,31 @@ def _update_learning_rates( so they have to be updated separately. opt_indices: indices of the optimizers to update. 
""" - if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization: + if not self.trainer.lr_scheduler_configs or not self.trainer.lightning_module.automatic_optimization: return if opt_indices is None: opt_indices = [] - for lr_scheduler in self.trainer.lr_schedulers: - if lr_scheduler["opt_idx"] not in opt_indices: + for config in self.trainer.lr_scheduler_configs: + if config.opt_idx not in opt_indices: continue - if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]: + if update_plateau_schedulers ^ config.reduce_on_plateau: continue current_idx = self.batch_idx if interval == "step" else self.trainer.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero - if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0: + if config.interval == interval and current_idx % config.frequency == 0: monitor_val = None - if lr_scheduler["reduce_on_plateau"]: + if config.reduce_on_plateau: # If instance of ReduceLROnPlateau, we need a monitor - monitor_key = lr_scheduler["monitor"] + monitor_key = config.monitor monitor_val = self._get_monitor_value(monitor_key) if monitor_val is None: - if lr_scheduler.get("strict", True): + if config.strict: avail_metrics = list(self.trainer.callback_metrics) raise MisconfigurationException( f"ReduceLROnPlateau conditioned on metric {monitor_key}" @@ -454,8 +454,8 @@ def _update_learning_rates( # update LR self.trainer._call_lightning_module_hook( "lr_scheduler_step", - lr_scheduler["scheduler"], - lr_scheduler["opt_idx"], + config.scheduler, + config.opt_idx, monitor_val, ) self.scheduler_progress.increment_completed() diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 568d84941f76e..6709f11282a2a 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -26,7 +26,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -444,15 +444,15 @@ def init_deepspeed(self): else: self._initialize_deepspeed_inference(model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]: - optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) - if len(optimizers) > 1 or len(schedulers) > 1: + def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]: + optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) + if len(optimizers) > 1 or len(lr_schedulers) > 1: raise MisconfigurationException( "DeepSpeed currently only supports single optimizer, single optional scheduler." 
) return ( optimizers[0], - schedulers[0] if schedulers else _get_default_scheduler_config(), + lr_schedulers[0] if lr_schedulers else None, optimizer_frequencies[0] if optimizer_frequencies else None, ) @@ -461,27 +461,32 @@ def zero_stage_3(self) -> bool: return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3 def _initialize_deepspeed_train(self, model): + optimizer, scheduler = None, None if "optimizer" in self.config: rank_zero_info( "You have specified an optimizer and/or scheduler within the DeepSpeed config." " It is recommended to define it in `LightningModule.configure_optimizers`." ) - optimizer, lr_scheduler = None, _get_default_scheduler_config() + lr_scheduler = None else: optimizer, lr_scheduler, _ = self._init_optimizers() + if lr_scheduler is not None: + scheduler = lr_scheduler.scheduler - scheduler = lr_scheduler["scheduler"] model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler) self._set_deepspeed_activation_checkpointing() # although we set these here, deepspeed manages the specific optimizer logic - self.lightning_module.trainer.optimizers = [deepspeed_optimizer] + self.optimizers = [deepspeed_optimizer] deepspeed_scheduler = model.lr_scheduler if deepspeed_scheduler is not None: # disable deepspeed lr scheduling as lightning manages scheduling model.lr_scheduler = None - lr_scheduler["scheduler"] = deepspeed_scheduler + if lr_scheduler is None: + lr_scheduler = LRSchedulerConfig(deepspeed_scheduler) + else: + lr_scheduler.scheduler = deepspeed_scheduler self.lr_schedulers = [lr_scheduler] self.model = model @@ -523,11 +528,10 @@ def _initialize_deepspeed_inference(self, model): " Using `configure_optimizers` to define optimizer and scheduler." ) optimizer, lr_scheduler, _ = self._init_optimizers() - scheduler = lr_scheduler["scheduler"] - inference_config = { - # todo: this is required for DeepSpeed throughput timers - "train_micro_batch_size_per_gpu": 1 - } + if lr_scheduler is not None: + scheduler = lr_scheduler.scheduler + # todo: this is required for DeepSpeed throughput timers + inference_config = {"train_micro_batch_size_per_gpu": 1} if "fp16" in self.config: inference_config.update({"fp16": self.config["fp16"]}) if self.zero_stage_3: diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py index a1c34fa87b8d5..19fa1ca3d2b8e 100644 --- a/pytorch_lightning/strategies/horovod.py +++ b/pytorch_lightning/strategies/horovod.py @@ -101,9 +101,9 @@ def _unpack_lightning_optimizer(opt): param_group["lr"] *= self.world_size # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR - lr_schedulers = self.lightning_module.trainer.lr_schedulers - for scheduler in lr_schedulers: - scheduler = scheduler["scheduler"] + lr_scheduler_configs = self.lr_schedulers + for config in lr_scheduler_configs: + scheduler = config.scheduler scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] # Horovod: broadcast parameters & optimizer state to ensure consistent initialization diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 04ab5969798d3..5019890ad4798 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -53,6 +53,7 @@ def __init__( self.precision_plugin = precision_plugin self._optimizers: List[Optimizer] = [] self._lightning_optimizers: Dict[int, LightningOptimizer] = {} + # TODO: rename to `lr_scheduler_configs` to 
match the property in the `Trainer` self.lr_schedulers: List[LRSchedulerConfig] = [] self.optimizer_frequencies: List[int] = [] if is_overridden("post_dispatch", self, parent=Strategy): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 125548471b529..c3b3f2988e847 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -305,8 +305,8 @@ def restore_lr_schedulers(self) -> None: # restore the lr schedulers lr_schedulers = self._loaded_checkpoint["lr_schedulers"] - for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): - scheduler["scheduler"].load_state_dict(lrs_state) + for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): + config.scheduler.load_state_dict(lrs_state) # ---------------------------------- # PRIVATE OPS @@ -368,8 +368,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump lr schedulers lr_schedulers = [] - for scheduler in self.trainer.lr_schedulers: - lr_schedulers.append(scheduler["scheduler"].state_dict()) + for config in self.trainer.lr_scheduler_configs: + lr_schedulers.append(config.scheduler.state_dict()) checkpoint["lr_schedulers"] = lr_schedulers self.trainer.precision_plugin.on_save_checkpoint(checkpoint) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8347b43e22f13..4e19f16b29c6d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2014,19 +2014,26 @@ def lightning_optimizers(self) -> Dict[int, LightningOptimizer]: return self.strategy._lightning_optimizers @property - def lr_schedulers(self) -> List[LRSchedulerConfig]: + def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: return self.strategy.lr_schedulers - @lr_schedulers.setter - def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None: - self.strategy.lr_schedulers = new_schedulers + @property + def lr_schedulers(self) -> List[Dict[str, Any]]: + rank_zero_deprecation( + "`Trainer.lr_schedulers` is deprecated in v1.6 and will be removed in v1.8." 
+ " You can use `trainer.lr_scheduler_configs` instead which contains dataclasses instead of dictionaries.", + stacklevel=5, + ) + from dataclasses import asdict + + return [asdict(config) for config in self.strategy.lr_schedulers] @property - def optimizer_frequencies(self) -> list: + def optimizer_frequencies(self) -> List[int]: return self.strategy.optimizer_frequencies @optimizer_frequencies.setter - def optimizer_frequencies(self, new_freqs: list) -> None: + def optimizer_frequencies(self, new_freqs: List[int]) -> None: self.strategy.optimizer_frequencies = new_freqs @property diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 84467310568f7..d4ffe7bc97f1f 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -123,7 +123,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule trainer.logger = DummyLogger() if trainer.logger is not None else None trainer.callbacks = [] # not needed before full run trainer.limit_train_batches = 1.0 - trainer.optimizers, trainer.lr_schedulers = [], [] # required for saving + trainer.optimizers, trainer.strategy.lr_schedulers = [], [] # required for saving trainer.model = model # required for saving diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 7bf1bcf34ed96..a15e65ef986b3 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -24,16 +24,13 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.optimizer import ( - _get_default_scheduler_config, - _init_optimizers_and_lr_schedulers, - _set_scheduler_opt_idx, -) +from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr +from pytorch_lightning.utilities.types import LRSchedulerConfig # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -127,13 +124,11 @@ def func(trainer): args = (optimizer, self.lr_max, self.num_training) scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) - sched_config = _get_default_scheduler_config() - sched_config.update({"scheduler": scheduler, "interval": "step", "opt_idx": 0}) trainer.strategy.optimizers = [optimizer] - trainer.strategy.lr_schedulers = [sched_config] + trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] trainer.strategy.optimizer_frequencies = [] - _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_schedulers) + _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs) return func @@ -232,7 +227,7 @@ def lr_find( trainer.progress_bar_callback.disable() # Required for saving the model - trainer.optimizers, trainer.lr_schedulers = [], [] + trainer.optimizers, trainer.strategy.lr_schedulers = [], [] trainer.model = model # Dump model checkpoint @@ -339,7 +334,7 @@ def on_batch_start(self, trainer, pl_module): if self.progress_bar_refresh_rate and self.progress_bar is None: self.progress_bar = tqdm(desc="Finding best initial lr", 
total=self.num_training) - self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0]) + self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """Called when the training batch ends, logs the calculated loss.""" diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 3e12629a947b2..43cce711ec944 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -16,6 +16,7 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union @@ -23,7 +24,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric -from typing_extensions import Protocol, runtime_checkable, TypedDict +from typing_extensions import Protocol, runtime_checkable _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] @@ -60,7 +61,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class _LRScheduler(_SupportsStateDict): +@runtime_checkable +class _LRScheduler(_SupportsStateDict, Protocol): optimizer: Optimizer def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: @@ -69,7 +71,8 @@ def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau(_SupportsStateDict): +@runtime_checkable +class ReduceLROnPlateau(_SupportsStateDict, Protocol): in_cooldown: bool optimizer: Optimizer @@ -95,12 +98,20 @@ def __init__( LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] -class LRSchedulerConfig(TypedDict): +@dataclass +class LRSchedulerConfig: scheduler: Union[_LRScheduler, ReduceLROnPlateau] - name: Optional[str] - interval: str - frequency: int - reduce_on_plateau: bool - monitor: Optional[str] - strict: bool - opt_idx: Optional[int] + # no custom name + name: Optional[str] = None + # after epoch is over + interval: str = "epoch" + # every epoch/batch + frequency: int = 1 + # most often not ReduceLROnPlateau scheduler + reduce_on_plateau: bool = False + # value to monitor for ReduceLROnPlateau + monitor: Optional[str] = None + # enforce that the monitor exists for ReduceLROnPlateau + strict: bool = True + # opt_idx assigned internally if not assigned by user + opt_idx: Optional[int] = None diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index d35b1e8eefc38..82a4a5b99894a 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -40,7 +40,7 @@ def test_lr_monitor_single_lr(tmpdir): assert lr_monitor.lrs, "No learning rates logged" assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default" - assert len(lr_monitor.lrs) == len(trainer.lr_schedulers) + assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-SGD"] @@ -76,7 +76,7 @@ def configure_optimizers(self): trainer.fit(model) assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected 
momentum to be logged" - assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers) + assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) @@ -103,7 +103,7 @@ def configure_optimizers(self): trainer.fit(model) assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged" - assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers) + assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values) @@ -237,7 +237,7 @@ def configure_optimizers(self): trainer.fit(model) assert lr_monitor.lrs, "No learning rates logged" - assert len(lr_monitor.lrs) == len(trainer.lr_schedulers) + assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": @@ -316,7 +316,7 @@ def configure_optimizers(self): trainer.fit(model, datamodule=dm) assert lr_monitor.lrs, "No learning rates logged" - assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers) + assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly" diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 1915f4f58d3cd..38af28c8eeb41 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -78,9 +78,9 @@ def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: - assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) - assert trainer.lr_schedulers[0]["interval"] == "epoch" - assert trainer.lr_schedulers[0]["frequency"] == 1 + assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR) + assert trainer.lr_scheduler_configs[0].interval == "epoch" + assert trainer.lr_scheduler_configs[0].frequency == 1 def on_train_epoch_end(self, trainer, *args): super().on_train_epoch_end(trainer, *args) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 363466d74d4b3..fa0d982478759 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -154,6 +154,12 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir): _ = trainer.should_rank_save_checkpoint +def test_v1_8_0_deprecated_lr_scheduler(): + trainer = Trainer() + with pytest.deprecated_call(match=r"`Trainer.lr_schedulers` is deprecated in v1.6 and will be removed in v1.8."): + assert trainer.lr_schedulers == [] + + def test_v1_8_0_trainer_optimizers_mixin(): trainer = Trainer() model = BoringModel() diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index b1d0116eb165a..17135b98c16f5 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -201,5 +201,5 @@ def configure_optimizers(self): assert trainer.state.finished, f"Training failed with {trainer.state}" assert bwd_mock.call_count == 10 - assert isinstance(trainer.lr_schedulers[0]["scheduler"].optimizer, optim.Adam) - assert isinstance(trainer.lr_schedulers[1]["scheduler"].optimizer, optim.SGD) + assert 
isinstance(trainer.lr_scheduler_configs[0].scheduler.optimizer, optim.Adam)
+    assert isinstance(trainer.lr_scheduler_configs[1].scheduler.optimizer, optim.SGD)
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index d6511d8db3cc2..ab1c05acb8fb0 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -168,24 +168,26 @@ def _is_equal(self, a, b):
         def _check_optimizers(self):
             return all(
-                self._is_equal(self.trainer.optimizers[i].state_dict(), state_dict["optimizer_states"][i])
-                for i in range(len(self.trainer.optimizers))
+                self._is_equal(optimizer.state_dict(), state)
+                for optimizer, state in zip(self.trainer.optimizers, state_dict["optimizer_states"])
             )

         def _check_schedulers(self):
             return all(
-                self._is_equal(self.trainer.lr_schedulers[i]["scheduler"].state_dict(), state_dict["lr_schedulers"][i])
-                for i in range(len(self.trainer.lr_schedulers))
+                self._is_equal(config.scheduler.state_dict(), state)
+                for config, state in zip(self.trainer.lr_scheduler_configs, state_dict["lr_schedulers"])
             )

         def _check_model_state_dict(self):
-            for k in self.state_dict():
-                yield self._is_equal(self.state_dict()[k], state_dict["state_dict"][k])
+            return all(
+                self._is_equal(actual, expected)
+                for actual, expected in zip(self.state_dict().values(), state_dict["state_dict"].values())
+            )

         def _test_on_val_test_predict_tune_start(self):
             assert self.trainer.current_epoch == state_dict["epoch"]
             assert self.trainer.global_step == state_dict["global_step"]
-            assert all(self._check_model_state_dict())
+            assert self._check_model_state_dict()

             # no optimizes and schedulers are loaded otherwise
             if self.trainer.state.fn != TrainerFn.TUNING:
@@ -200,7 +202,7 @@ def on_train_start(self):
             else:
                 assert self.trainer.current_epoch == state_dict["epoch"]
                 assert self.trainer.global_step == state_dict["global_step"]
-                assert all(self._check_model_state_dict())
+                assert self._check_model_state_dict()
                 assert self._check_optimizers()
                 assert self._check_schedulers()
diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py
index 50f0c94405e7c..e21e1cf7de04f 100644
--- a/tests/strategies/test_deepspeed_strategy.py
+++ b/tests/strategies/test_deepspeed_strategy.py
@@ -276,9 +276,9 @@ def on_train_start(self, trainer, pl_module) -> None:
             assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
             assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
-            assert isinstance(trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR)
+            assert isinstance(trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR)
             # check that the lr_scheduler config was preserved
-            assert trainer.lr_schedulers[0]["name"] == "Sean"
+            assert trainer.lr_scheduler_configs[0].name == "Sean"

     class TestModel(BoringModel):
         def configure_optimizers(self):
@@ -314,7 +314,7 @@ def on_train_start(self, trainer, pl_module) -> None:
             assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
             assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
-            assert isinstance(trainer.lr_schedulers[0]["scheduler"], WarmupLR)
+            assert isinstance(trainer.lr_scheduler_configs[0].scheduler, WarmupLR)

     model = BoringModel()
     trainer = Trainer(
@@ -716,8 +716,8 @@ def on_train_batch_start(
             assert trainer.current_epoch == 1

             # assert lr-scheduler states are loaded correctly
-            original_lr_scheduler = 
initial_trainer.lr_scheduler_configs[0].scheduler + current_lr_scheduler = trainer.lr_scheduler_configs[0].scheduler assert original_lr_scheduler.state_dict() == current_lr_scheduler.state_dict() model = ModelParallelClassificationModel() diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index e960eabcb9b62..74abfda11b241 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -26,6 +26,7 @@ _init_optimizers_and_lr_schedulers, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -43,7 +44,7 @@ def test_optimizer_with_scheduling(tmpdir): init_lr = 0.1 adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups] - assert len(trainer.lr_schedulers) == 1 + assert len(trainer.lr_scheduler_configs) == 1 assert all(a == adjusted_lr[0] for a in adjusted_lr) assert init_lr * 0.1 == adjusted_lr[0] @@ -73,7 +74,7 @@ def configure_optimizers(self): adjusted_lr1 = [pg["lr"] for pg in trainer.optimizers[0].param_groups] adjusted_lr2 = [pg["lr"] for pg in trainer.optimizers[1].param_groups] - assert len(trainer.lr_schedulers) == 2 + assert len(trainer.lr_scheduler_configs) == 2 assert all(a == adjusted_lr1[0] for a in adjusted_lr1) assert all(a == adjusted_lr2[0] for a in adjusted_lr2) assert model.init_lr * 0.1 == adjusted_lr1[0] @@ -133,9 +134,9 @@ def configure_optimizers(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - lr_scheduler = trainer.lr_schedulers[0] - assert lr_scheduler == dict( - scheduler=lr_scheduler["scheduler"], + lr_scheduler = trainer.lr_scheduler_configs[0] + assert lr_scheduler == LRSchedulerConfig( + scheduler=lr_scheduler.scheduler, monitor="foo", interval="epoch", frequency=1, @@ -175,7 +176,7 @@ def test_optimizer_return_options(tmpdir): assert opt == [opt_a, opt_b] assert len(lr_sched) == len(freq) == 0 - ref_lr_sched = dict( + ref_lr_sched = LRSchedulerConfig( scheduler=scheduler_a, interval="epoch", frequency=1, @@ -218,10 +219,10 @@ def test_optimizer_return_options(tmpdir): opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == len(freq) == 2 assert opt[0] == opt_a - ref_lr_sched["opt_idx"] = 0 + ref_lr_sched.opt_idx = 0 assert lr_sched[0] == ref_lr_sched - ref_lr_sched["scheduler"] = scheduler_b - ref_lr_sched["opt_idx"] = 1 + ref_lr_sched.scheduler = scheduler_b + ref_lr_sched.opt_idx = 1 assert lr_sched[1] == ref_lr_sched assert freq == [1, 5] @@ -309,11 +310,11 @@ def configure_optimizers(self): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.lr_schedulers[0]["opt_idx"] == 0 - assert trainer.lr_schedulers[1]["opt_idx"] == 1 + assert trainer.lr_scheduler_configs[0].opt_idx == 0 + assert trainer.lr_scheduler_configs[1].opt_idx == 1 # Step count is 1 greater than the expected value because scheduler.step() is called once during initialization - assert trainer.lr_schedulers[0]["scheduler"]._step_count == expected_steps[0] - assert trainer.lr_schedulers[1]["scheduler"]._step_count == expected_steps[1] + assert trainer.lr_scheduler_configs[0].scheduler._step_count == expected_steps[0] + assert trainer.lr_scheduler_configs[1].scheduler._step_count == expected_steps[1] @pytest.mark.parametrize("fn", 
("validate", "test", "predict")) @@ -332,7 +333,7 @@ def configure_optimizers(self): train_fn = getattr(trainer, fn) train_fn(TestModel(), datamodule=BoringDataModule(), ckpt_path=None) - assert len(trainer.lr_schedulers) == 0 + assert len(trainer.lr_scheduler_configs) == 0 assert len(trainer.optimizers) == 0 assert len(trainer.optimizer_frequencies) == 0 @@ -483,7 +484,7 @@ def test_lr_scheduler_with_extra_keys_warns(tmpdir): "lr_scheduler": {"scheduler": optim.lr_scheduler.StepLR(optimizer, 1), "foo": 1, "bar": 2}, } trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the lr scheduler dict: \[.+\]"): + with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the lr scheduler dict: \{.+\}"): trainer.fit(model) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 849d75416176b..2f501a699d0c2 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -641,7 +641,7 @@ def add_arguments_to_parser(self, parser): else: assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.lr_schedulers) == 0 + assert len(cli.trainer.lr_scheduler_configs) == 0 def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir): @@ -658,9 +658,9 @@ def add_arguments_to_parser(self, parser): assert cli.model.configure_optimizers is not BoringModel.configure_optimizers assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.lr_schedulers) == 1 - assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR) - assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8 + assert len(cli.trainer.lr_scheduler_configs) == 1 + assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.ExponentialLR) + assert cli.trainer.lr_scheduler_configs[0].scheduler.gamma == 0.8 def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir): @@ -684,9 +684,9 @@ def add_arguments_to_parser(self, parser): assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.lr_schedulers) == 1 - assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR) - assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 + assert len(cli.trainer.lr_scheduler_configs) == 1 + assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR) + assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50 @pytest.mark.parametrize("use_registries", [False, True])