
Commit a6a28e0

Deprecate TrainerOptimizersMixin and move functionality to core/optimizer.py (#11155)
1 parent 81301db commit a6a28e0

File tree: 13 files changed, +321 / -271 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))
 
 
+- Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py`([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
+
+
 - Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))
 
 ### Removed
@@ -351,6 +354,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
 
 
+- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
+
 
 ## [1.5.7] - 2021-12-21
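For downstream code that imported these helpers from the deprecated mixin module, the migration implied by this commit is essentially an import-path change, as the file diffs below show. A minimal sketch (note that these remain private, underscore-prefixed helpers):

# Deprecated location, previously next to TrainerOptimizersMixin:
# from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config

# New location introduced by this commit:
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers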

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

pytorch_lightning/core/optimizer.py

Lines changed: 231 additions & 3 deletions
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import weakref
 from contextlib import contextmanager
-from typing import Any, Callable, Generator, Optional
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 from weakref import proxy
 
+import torch
+from torch import optim
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -54,7 +57,8 @@ def optimizer(self) -> Optimizer:
         return self._optimizer
 
     def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
-        self._trainer = proxy(trainer)
+        # check if trainer is already of type weakproxy since we can't call proxy on a weakproxy
+        self._trainer = trainer if isinstance(trainer, weakref.ProxyType) else proxy(trainer)
         for opt_idx, opt in enumerate(trainer.optimizers):
             if opt == self._optimizer:
                 self._optimizer_idx = opt_idx
@@ -162,3 +166,227 @@ def closure_dis():
         assert trainer is not None
         with trainer.profiler.profile(profiler_action):
             trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+
+
+def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
+    """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
+    model.trainer._lightning_optimizers = None
+    optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
+
+    if optim_conf is None:
+        rank_zero_warn(
+            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
+        )
+        optim_conf = _MockOptimizer()
+
+    optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
+    lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization)
+    _validate_scheduler_optimizer(optimizers, lr_schedulers)
+    return optimizers, lr_schedulers, optimizer_frequencies
+
+
+def _configure_optimizers(
+    optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
+) -> Tuple[List, List, List, Optional[str]]:
+    optimizers, lr_schedulers, optimizer_frequencies = [], [], []
+    monitor = None
+
+    # single output, single optimizer
+    if isinstance(optim_conf, Optimizer):
+        optimizers = [optim_conf]
+    # two lists, optimizer + lr schedulers
+    elif (
+        isinstance(optim_conf, (list, tuple))
+        and len(optim_conf) == 2
+        and isinstance(optim_conf[0], list)
+        and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
+    ):
+        opt, sch = optim_conf
+        optimizers = opt
+        lr_schedulers = sch if isinstance(sch, list) else [sch]
+    # single dictionary
+    elif isinstance(optim_conf, dict):
+        _validate_optim_conf(optim_conf)
+        optimizers = [optim_conf["optimizer"]]
+        monitor = optim_conf.get("monitor", None)
+        lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
+    # multiple dictionaries
+    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
+        for opt_dict in optim_conf:
+            _validate_optim_conf(opt_dict)
+        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
+        scheduler_dict = (
+            lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
+            if isinstance(scheduler, dict)
+            else {"scheduler": scheduler, "opt_idx": opt_idx}
+        )
+
+        lr_schedulers = [
+            scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
+            for opt_idx, opt_dict in enumerate(optim_conf)
+            if "lr_scheduler" in opt_dict
+        ]
+        optimizer_frequencies = [
+            opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
+        ]
+        # assert that if frequencies are present, they are given for all optimizers
+        if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
+            raise ValueError("A frequency must be given to each optimizer.")
+    # single list or tuple, multiple optimizer
+    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
+        optimizers = list(optim_conf)
+    # unknown configuration
+    else:
+        raise MisconfigurationException(
+            "Unknown configuration for model optimizers."
+            " Output from `model.configure_optimizers()` should be one of:\n"
+            " * `Optimizer`\n"
+            " * [`Optimizer`]\n"
+            " * ([`Optimizer`], [`_LRScheduler`])\n"
+            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
+            ' * A list of the previously described dict format, with an optional "frequency" key (int)'
+        )
+    return optimizers, lr_schedulers, optimizer_frequencies, monitor
+
+
+def _configure_schedulers(
+    schedulers: list, monitor: Optional[str], is_manual_optimization: bool
+) -> List[Dict[str, Any]]:
+    """Convert each scheduler into dict structure with relevant information."""
+    lr_schedulers = []
+    default_config = _get_default_scheduler_config()
+    # TODO: move is_manual_optimization check out of for loop
+    for scheduler in schedulers:
+        if is_manual_optimization:
+            if isinstance(scheduler, dict):
+                invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
+                keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
+
+                if keys_to_warn:
+                    rank_zero_warn(
+                        f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
+                        " You need to call `lr_scheduler.step()` manually in manual optimization.",
+                        category=RuntimeWarning,
+                    )
+
+                scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
+                lr_schedulers.append({**default_config, **scheduler})
+            else:
+                lr_schedulers.append({**default_config, "scheduler": scheduler})
+        else:
+            if isinstance(scheduler, dict):
+                # check provided keys
+                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
+                if extra_keys:
+                    rank_zero_warn(
+                        f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
+                    )
+                if "scheduler" not in scheduler:
+                    raise MisconfigurationException(
+                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
+                    )
+                if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
+                    raise MisconfigurationException(
+                        'The "interval" key in lr scheduler dict must be "step" or "epoch"'
+                        f' but is "{scheduler["interval"]}"'
+                    )
+                scheduler["reduce_on_plateau"] = isinstance(
+                    scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
+                )
+                if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
+                    raise MisconfigurationException(
+                        "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
+                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
+                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
+                    )
+                is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
+                if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
+                    rank_zero_warn(
+                        "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
+                        " Are you sure you didn't mean 'interval': 'step'?",
+                        category=RuntimeWarning,
+                    )
+                lr_schedulers.append({**default_config, **scheduler})
+            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                if monitor is None:
+                    raise MisconfigurationException(
+                        "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
+                        " scheduler is used. For example:"
+                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                    )
+                lr_schedulers.append(
+                    {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
+                )
+            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
+                lr_schedulers.append({**default_config, "scheduler": scheduler})
+            else:
+                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
+    return lr_schedulers
+
+
+def _get_default_scheduler_config() -> Dict[str, Any]:
+    return {
+        "scheduler": None,
+        "name": None,  # no custom name
+        "interval": "epoch",  # after epoch is over
+        "frequency": 1,  # every epoch/batch
+        "reduce_on_plateau": False,  # most often not ReduceLROnPlateau scheduler
+        "monitor": None,  # value to monitor for ReduceLROnPlateau
+        "strict": True,  # enforce that the monitor exists for ReduceLROnPlateau
+        "opt_idx": None,  # necessary to store opt_idx when optimizer frequencies are specified
+    }
+
+
+def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
+    if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
+        raise MisconfigurationException(
+            "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
+        )
+
+
+def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
+    valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
+    extra_keys = optim_conf.keys() - valid_keys
+    if extra_keys:
+        rank_zero_warn(
+            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
+        )
+
+
+def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None:
+    def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer:
+        if not isinstance(optimizer, LightningOptimizer):
+            optimizer = LightningOptimizer(optimizer)  # type: ignore [assignment]
+        optimizer._on_trainer_init(trainer)
+        return optimizer  # type: ignore [return-value]
+
+    trainer._lightning_optimizers = {  # type: ignore [assignment]
+        opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers)
+    }
+
+
+class _MockOptimizer(Optimizer):
+    """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
+    `configure_optimizers`."""
+
+    def __init__(self) -> None:
+        super().__init__([torch.zeros(1)], {})
+
+    def add_param_group(self, param_group: Dict[Any, Any]) -> None:
+        pass  # Do Nothing
+
+    def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
+        pass  # Do Nothing
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {}  # Return Empty
+
+    def step(self, closure: Callable = None) -> None:
+        if closure is not None:
+            closure()
+
+    def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
+        pass  # Do Nothing
+
+    def __repr__(self) -> str:
+        return "No Optimizer"

pytorch_lightning/strategies/ddp.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import pytorch_lightning as pl
-from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -347,7 +347,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
             del optimizer
         trainer = self.lightning_module.trainer
         trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        _convert_to_lightning_optimizers(trainer)
 
     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
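The strategy-side change is mechanical: `trainer.convert_to_lightning_optimizers()` becomes `_convert_to_lightning_optimizers(trainer)`. The per-optimizer wrapping that the helper performs can be sketched in isolation; the SGD optimizer below is a placeholder, and the weak trainer proxy is only attached afterwards via `_on_trainer_init`:

from torch import nn, optim

from pytorch_lightning.core.optimizer import LightningOptimizer

optimizer = optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)  # placeholder optimizer
# _convert_to_lightning_optimizers applies this wrapping to every entry of trainer.optimizers.
wrapped = optimizer if isinstance(optimizer, LightningOptimizer) else LightningOptimizer(optimizer)
print(wrapped.optimizer is optimizer)  # True: the wrapper delegates to the original optimizer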

pytorch_lightning/strategies/deepspeed.py

Lines changed: 2 additions & 4 deletions
@@ -27,11 +27,11 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
 import pytorch_lightning as pl
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
-from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -446,9 +446,7 @@ def init_deepspeed(self):
             self._initialize_deepspeed_inference(model)
 
     def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
-        optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
-            self.lightning_module
-        )
+        optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
         if len(optimizers) > 1 or len(schedulers) > 1:
             raise MisconfigurationException(
                 "DeepSpeed currently only supports single optimizer, single optional scheduler."

pytorch_lightning/strategies/sharded.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
@@ -50,7 +50,7 @@ def configure_ddp(self) -> None:
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        _convert_to_lightning_optimizers(trainer)
 
     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.

pytorch_lightning/strategies/training_type_plugin.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -377,7 +378,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
         return dataloader
 
     def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        return trainer.init_optimizers(model)
+        return _init_optimizers_and_lr_schedulers(model)
 
     @property
     def restore_checkpoint_after_setup(self) -> bool:
