|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import weakref |
14 | 15 | from contextlib import contextmanager |
15 | | -from typing import Any, Callable, Generator, Optional |
| 16 | +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union |
16 | 17 | from weakref import proxy |
17 | 18 | |
| 19 | +import torch |
| 20 | +from torch import optim |
18 | 21 | from torch.optim import Optimizer |
19 | 22 | |
20 | 23 | import pytorch_lightning as pl |
21 | | -from pytorch_lightning.utilities import AMPType |
| 24 | +from pytorch_lightning.utilities import AMPType, rank_zero_warn |
22 | 25 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
23 | 26 | |
24 | 27 | |
@@ -54,7 +57,8 @@ def optimizer(self) -> Optimizer: |
54 | 57 | return self._optimizer |
55 | 58 | |
56 | 59 | def _on_trainer_init(self, trainer: "pl.Trainer") -> None: |
57 | | - self._trainer = proxy(trainer) |
| 60 | + # check whether the trainer is already a weakref proxy, since `proxy` cannot be called on a proxy |
| 61 | + self._trainer = trainer if isinstance(trainer, weakref.ProxyType) else proxy(trainer) |
58 | 62 | for opt_idx, opt in enumerate(trainer.optimizers): |
59 | 63 | if opt == self._optimizer: |
60 | 64 | self._optimizer_idx = opt_idx |
@@ -162,3 +166,227 @@ def closure_dis(): |
162 | 166 | assert trainer is not None |
163 | 167 | with trainer.profiler.profile(profiler_action): |
164 | 168 | trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) |
| 169 | + |
| 170 | + |
| 171 | +def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]: |
| 172 | + """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" |
| 173 | + model.trainer._lightning_optimizers = None |
| 174 | + optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) |
| 175 | + |
| 176 | + if optim_conf is None: |
| 177 | + rank_zero_warn( |
| 178 | + "`LightningModule.configure_optimizers` returned `None`; this fit will run with no optimizer", |
| 179 | + ) |
| 180 | + optim_conf = _MockOptimizer() |
| 181 | + |
| 182 | + optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf) |
| 183 | + lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization) |
| 184 | + _validate_scheduler_optimizer(optimizers, lr_schedulers) |
| 185 | + return optimizers, lr_schedulers, optimizer_frequencies |
| 186 | + |
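# Illustrative sketch: a typical `configure_optimizers` whose output is parsed by the helper
# above into `(optimizers, lr_schedulers, optimizer_frequencies)`. `DemoModel` and "val_loss"
# are hypothetical names; the helper itself is called by the Trainer, which is why it expects
# `model.trainer` to already be set. Training/validation steps are omitted for brevity.
import pytorch_lightning as pl
from torch import nn, optim

class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def configure_optimizers(self):
        opt = optim.Adam(self.parameters(), lr=1e-3)
        sch = optim.lr_scheduler.ReduceLROnPlateau(opt)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "monitor": "val_loss"}}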
| 187 | + |
| 188 | +def _configure_optimizers( |
| 189 | + optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] |
| 190 | +) -> Tuple[List, List, List, Optional[str]]: |
| 191 | + optimizers, lr_schedulers, optimizer_frequencies = [], [], [] |
| 192 | + monitor = None |
| 193 | + |
| 194 | + # single output, single optimizer |
| 195 | + if isinstance(optim_conf, Optimizer): |
| 196 | + optimizers = [optim_conf] |
| 197 | + # two lists, optimizer + lr schedulers |
| 198 | + elif ( |
| 199 | + isinstance(optim_conf, (list, tuple)) |
| 200 | + and len(optim_conf) == 2 |
| 201 | + and isinstance(optim_conf[0], list) |
| 202 | + and all(isinstance(opt, Optimizer) for opt in optim_conf[0]) |
| 203 | + ): |
| 204 | + opt, sch = optim_conf |
| 205 | + optimizers = opt |
| 206 | + lr_schedulers = sch if isinstance(sch, list) else [sch] |
| 207 | + # single dictionary |
| 208 | + elif isinstance(optim_conf, dict): |
| 209 | + _validate_optim_conf(optim_conf) |
| 210 | + optimizers = [optim_conf["optimizer"]] |
| 211 | + monitor = optim_conf.get("monitor", None) |
| 212 | + lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else [] |
| 213 | + # multiple dictionaries |
| 214 | + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): |
| 215 | + for opt_dict in optim_conf: |
| 216 | + _validate_optim_conf(opt_dict) |
| 217 | + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] |
| 218 | + scheduler_dict = ( |
| 219 | + lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx) |
| 220 | + if isinstance(scheduler, dict) |
| 221 | + else {"scheduler": scheduler, "opt_idx": opt_idx} |
| 222 | + ) |
| 223 | + |
| 224 | + lr_schedulers = [ |
| 225 | + scheduler_dict(opt_dict["lr_scheduler"], opt_idx) |
| 226 | + for opt_idx, opt_dict in enumerate(optim_conf) |
| 227 | + if "lr_scheduler" in opt_dict |
| 228 | + ] |
| 229 | + optimizer_frequencies = [ |
| 230 | + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None |
| 231 | + ] |
| 232 | + # assert that if frequencies are present, they are given for all optimizers |
| 233 | + if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): |
| 234 | + raise ValueError("A frequency must be given to each optimizer.") |
| 235 | + # single list or tuple, multiple optimizer |
| 236 | + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf): |
| 237 | + optimizers = list(optim_conf) |
| 238 | + # unknown configuration |
| 239 | + else: |
| 240 | + raise MisconfigurationException( |
| 241 | + "Unknown configuration for model optimizers." |
| 242 | + " Output from `model.configure_optimizers()` should be one of:\n" |
| 243 | + " * `Optimizer`\n" |
| 244 | + " * [`Optimizer`]\n" |
| 245 | + " * ([`Optimizer`], [`_LRScheduler`])\n" |
| 246 | + ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n' |
| 247 | + ' * A list of the previously described dict format, with an optional "frequency" key (int)' |
| 248 | + ) |
| 249 | + return optimizers, lr_schedulers, optimizer_frequencies, monitor |
| 250 | + |
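# Illustrative sketch of the return shapes accepted above; the parameter, optimizers and
# "val_loss" monitor below are hypothetical stand-ins.
import torch
from torch import nn, optim

_params = [nn.Parameter(torch.zeros(1))]
_opt = optim.SGD(_params, lr=0.1)
_sch = optim.lr_scheduler.StepLR(_opt, step_size=10)

_configure_optimizers(_opt)              # -> ([_opt], [], [], None)
_configure_optimizers(([_opt], [_sch]))  # -> ([_opt], [_sch], [], None)
_configure_optimizers({"optimizer": _opt, "lr_scheduler": _sch, "monitor": "val_loss"})
# -> ([_opt], [_sch], [], "val_loss")
_configure_optimizers(
    [{"optimizer": _opt, "frequency": 1}, {"optimizer": optim.Adam(_params, lr=1e-3), "frequency": 2}]
)
# -> ([SGD, Adam], [], [1, 2], None)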
| 251 | + |
| 252 | +def _configure_schedulers( |
| 253 | + schedulers: list, monitor: Optional[str], is_manual_optimization: bool |
| 254 | +) -> List[Dict[str, Any]]: |
| 255 | + """Convert each scheduler into a dict structure with the relevant information.""" |
| 256 | + lr_schedulers = [] |
| 257 | + default_config = _get_default_scheduler_config() |
| 258 | + # TODO: move is_manual_optimization check out of for loop |
| 259 | + for scheduler in schedulers: |
| 260 | + if is_manual_optimization: |
| 261 | + if isinstance(scheduler, dict): |
| 262 | + invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"} |
| 263 | + keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys] |
| 264 | + |
| 265 | + if keys_to_warn: |
| 266 | + rank_zero_warn( |
| 267 | + f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored." |
| 268 | + " You need to call `lr_scheduler.step()` manually in manual optimization.", |
| 269 | + category=RuntimeWarning, |
| 270 | + ) |
| 271 | + |
| 272 | + scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys} |
| 273 | + lr_schedulers.append({**default_config, **scheduler}) |
| 274 | + else: |
| 275 | + lr_schedulers.append({**default_config, "scheduler": scheduler}) |
| 276 | + else: |
| 277 | + if isinstance(scheduler, dict): |
| 278 | + # check provided keys |
| 279 | + extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] |
| 280 | + if extra_keys: |
| 281 | + rank_zero_warn( |
| 282 | + f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning |
| 283 | + ) |
| 284 | + if "scheduler" not in scheduler: |
| 285 | + raise MisconfigurationException( |
| 286 | + 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' |
| 287 | + ) |
| 288 | + if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"): |
| 289 | + raise MisconfigurationException( |
| 290 | + 'The "interval" key in lr scheduler dict must be "step" or "epoch"' |
| 291 | + f' but is "{scheduler["interval"]}"' |
| 292 | + ) |
| 293 | + scheduler["reduce_on_plateau"] = isinstance( |
| 294 | + scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau |
| 295 | + ) |
| 296 | + if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None: |
| 297 | + raise MisconfigurationException( |
| 298 | + "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used." |
| 299 | + ' For example: {"optimizer": optimizer, "lr_scheduler":' |
| 300 | + ' {"scheduler": scheduler, "monitor": "your_loss"}}' |
| 301 | + ) |
| 302 | + is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR) |
| 303 | + if is_one_cycle and scheduler.get("interval", "epoch") == "epoch": |
| 304 | + rank_zero_warn( |
| 305 | + "A `OneCycleLR` scheduler is using 'interval': 'epoch'." |
| 306 | + " Are you sure you didn't mean 'interval': 'step'?", |
| 307 | + category=RuntimeWarning, |
| 308 | + ) |
| 309 | + lr_schedulers.append({**default_config, **scheduler}) |
| 310 | + elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): |
| 311 | + if monitor is None: |
| 312 | + raise MisconfigurationException( |
| 313 | + "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`" |
| 314 | + " scheduler is used. For example:" |
| 315 | + ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' |
| 316 | + ) |
| 317 | + lr_schedulers.append( |
| 318 | + {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} |
| 319 | + ) |
| 320 | + elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): |
| 321 | + lr_schedulers.append({**default_config, "scheduler": scheduler}) |
| 322 | + else: |
| 323 | + raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') |
| 324 | + return lr_schedulers |
| 325 | + |
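# Illustrative sketch: each user-supplied entry is merged over `_get_default_scheduler_config()`
# and `reduce_on_plateau` is inferred from the scheduler type. The optimizer and "val_loss"
# monitor below are hypothetical.
import torch
from torch import nn, optim

_opt = optim.SGD([nn.Parameter(torch.zeros(1))], lr=0.1)
_plateau = optim.lr_scheduler.ReduceLROnPlateau(_opt)

_configure_schedulers([{"scheduler": _plateau, "monitor": "val_loss", "interval": "epoch"}], None, False)
# -> [{"scheduler": _plateau, "name": None, "interval": "epoch", "frequency": 1,
#      "reduce_on_plateau": True, "monitor": "val_loss", "strict": True, "opt_idx": None}]
# Passing a bare `ReduceLROnPlateau` while `monitor` is None raises a `MisconfigurationException`.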
| 326 | + |
| 327 | +def _get_default_scheduler_config() -> Dict[str, Any]: |
| 328 | + return { |
| 329 | + "scheduler": None, |
| 330 | + "name": None, # no custom name |
| 331 | + "interval": "epoch", # after epoch is over |
| 332 | + "frequency": 1, # every epoch/batch |
| 333 | + "reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler |
| 334 | + "monitor": None, # value to monitor for ReduceLROnPlateau |
| 335 | + "strict": True, # enforce that the monitor exists for ReduceLROnPlateau |
| 336 | + "opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified |
| 337 | + } |
| 338 | + |
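# Illustrative sketch: user scheduler dicts are layered over these defaults via
# `{**default_config, **scheduler}`, so any key that is left out keeps its default value.
_cfg = {**_get_default_scheduler_config(), "interval": "step", "frequency": 100}
assert _cfg["interval"] == "step" and _cfg["reduce_on_plateau"] is False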
| 339 | + |
| 340 | +def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None: |
| 341 | + if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers): |
| 342 | + raise MisconfigurationException( |
| 343 | + "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." |
| 344 | + ) |
| 345 | + |
| 346 | + |
| 347 | +def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: |
| 348 | + valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"} |
| 349 | + extra_keys = optim_conf.keys() - valid_keys |
| 350 | + if extra_keys: |
| 351 | + rank_zero_warn( |
| 352 | + f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning |
| 353 | + ) |
| 354 | + |
| 355 | + |
| 356 | +def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None: |
| 357 | + def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer: |
| 358 | + if not isinstance(optimizer, LightningOptimizer): |
| 359 | + optimizer = LightningOptimizer(optimizer) # type: ignore [assignment] |
| 360 | + optimizer._on_trainer_init(trainer) |
| 361 | + return optimizer # type: ignore [return-value] |
| 362 | + |
| 363 | + trainer._lightning_optimizers = { # type: ignore [assignment] |
| 364 | + opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers) |
| 365 | + } |
| 366 | + |
| 367 | + |
| 368 | +class _MockOptimizer(Optimizer): |
| 369 | + """The `_MockOptimizer` is used in place of an optimizer in the event that `None` is returned from |
| 370 | + `configure_optimizers`.""" |
| 371 | + |
| 372 | + def __init__(self) -> None: |
| 373 | + super().__init__([torch.zeros(1)], {}) |
| 374 | + |
| 375 | + def add_param_group(self, param_group: Dict[Any, Any]) -> None: |
| 376 | + pass # Do Nothing |
| 377 | + |
| 378 | + def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: |
| 379 | + pass # Do Nothing |
| 380 | + |
| 381 | + def state_dict(self) -> Dict[str, Any]: |
| 382 | + return {} # Return Empty |
| 383 | + |
| 384 | + def step(self, closure: Callable = None) -> None: |
| 385 | + if closure is not None: |
| 386 | + closure() |
| 387 | + |
| 388 | + def zero_grad(self, set_to_none: Optional[bool] = False) -> None: |
| 389 | + pass # Do Nothing |
| 390 | + |
| 391 | + def __repr__(self) -> str: |
| 392 | + return "No Optimizer" |
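# Illustrative sketch of the no-op behaviour that keeps the loops running when
# `configure_optimizers` returns `None`:
_mock = _MockOptimizer()
_mock.step(closure=lambda: None)  # only invokes the closure
_mock.zero_grad()                 # does nothing
assert _mock.state_dict() == {}
repr(_mock)                       # "No Optimizer"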