
Commit 28242f0

ethanwharris and Borda authored
Remove default optimizer, add None optimizer option (#1279)
* Add warning when using default optimizer

* Refactor optimizer tests to test_optimizers

* Remove default optimizer, add option to use no optimizer

* Update CHANGELOG.md

* Update pytorch_lightning/trainer/optimizers.py

  Co-Authored-By: Jirka Borovec <[email protected]>

* Fix style

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 80dc979 commit 28242f0

File tree

11 files changed: +267 -168 lines changed


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+
+### Changed
+
+- Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
 - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
 - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
 - Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
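
To illustrate the behaviour described in the changelog entries above, here is a minimal sketch (not part of this commit; the module name, layer size, and loss are made up) of a LightningModule that opts out of optimization by returning `None` from `configure_optimizers`. With this commit, `fit` will then run with no optimizer and emit a `UserWarning`, instead of silently falling back to `Adam(lr=1e-3)`.

import torch
from torch.nn import functional as F
from pytorch_lightning import LightningModule


class NoOptimizerModel(LightningModule):
    """Hypothetical module: runs `fit` without any optimizer."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        # New in #1279: returning None means no optimizer is stepped
        # (internally a no-op `_MockOptimizer` stands in).
        return None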

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 10 deletions
@@ -4,13 +4,12 @@
 import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence

 import torch
 import torch.distributed as torch_distrib
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel
-from torch.optim import Adam
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader

@@ -905,21 +904,20 @@ def configure_apex(self, amp, model, optimizers, amp_level):

         return model, optimizers

-    def configure_optimizers(self) -> Union[
-        Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], List]
-    ]:
+    def configure_optimizers(self) -> Optional[Union[
+        Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]
+    ]]:
         r"""
         Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.

-        If you don't define this method Lightning will automatically use Adam(lr=1e-3)
-
-        Return: any of these 5 options:
+        Return: any of these 6 options:
             - Single optimizer.
             - List or Tuple - List of optimizers.
             - Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
             - Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
             - Tuple of dictionaries as described, with an optional `frequency` key.
+            - None - Fit will run without any optimizer.

         Note:
             The `frequency` value is an int corresponding to the number of sequential batches
@@ -932,7 +930,7 @@ def configure_optimizers(self) -> Union[
         Examples:
             .. code-block:: python

-                # most cases (default if not defined)
+                # most cases
                 def configure_optimizers(self):
                     opt = Adam(self.parameters(), lr=1e-3)
                     return opt
@@ -1005,7 +1003,6 @@ def configure_optimizers(self):
            }

        """
-        return Adam(self.parameters(), lr=1e-3)

    def optimizer_step(
            self,
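
The docstring now lists six return formats. As a hedged illustration of the tuple-of-dictionaries form with the optional `frequency` key (written in the same method-fragment style as the docstring examples; the attribute names, learning rates, and step size below are hypothetical and not taken from the commit), `configure_optimizers` could look like:

from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR


def configure_optimizers(self):
    # Hypothetical GAN-style setup: step the generator optimizer for 2
    # sequential batches, then the discriminator optimizer for 1 batch.
    gen_opt = Adam(self.generator.parameters(), lr=1e-3)
    dis_opt = SGD(self.discriminator.parameters(), lr=1e-2)
    return (
        {'optimizer': gen_opt, 'lr_scheduler': StepLR(gen_opt, step_size=10), 'frequency': 2},
        {'optimizer': dis_opt, 'frequency': 1},
    )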

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 1 addition & 2 deletions
@@ -304,8 +304,7 @@ def ddp_train(self, gpu_idx, model):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
-            self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         # MODEL
         # copy model to each gpu

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 3 additions & 6 deletions
@@ -462,8 +462,7 @@ def single_gpu_train(self, model):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
-            self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         if self.use_amp:
             # An example
@@ -489,8 +488,7 @@ def tpu_train(self, tpu_core_idx, model):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
-            self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         # init 16 bit for TPU
         if self.precision == 16:
@@ -508,8 +506,7 @@ def dp_train(self, model):

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
-            self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         model.cuda(self.root_gpu)

pytorch_lightning/trainer/optimizers.py

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
+import warnings
+from abc import ABC
+from typing import List, Tuple
+
+import torch
+from torch import optim
+from torch.optim.optimizer import Optimizer
+
+from pytorch_lightning.core.lightning import LightningModule
+
+
+class TrainerOptimizersMixin(ABC):
+
+    def init_optimizers(
+            self,
+            model: LightningModule
+    ) -> Tuple[List, List, List]:
+        optim_conf = model.configure_optimizers()
+
+        if optim_conf is None:
+            warnings.warn('`LightningModule.configure_optimizers` returned `None`, '
+                          'this fit will run with no optimizer', UserWarning)
+            optim_conf = _MockOptimizer()
+
+        # single output, single optimizer
+        if isinstance(optim_conf, Optimizer):
+            return [optim_conf], [], []
+
+        # two lists, optimizer + lr schedulers
+        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
+                and isinstance(optim_conf[0], list):
+            optimizers, lr_schedulers = optim_conf
+            lr_schedulers = self.configure_schedulers(lr_schedulers)
+            return optimizers, lr_schedulers, []
+
+        # single dictionary
+        elif isinstance(optim_conf, dict):
+            optimizer = optim_conf["optimizer"]
+            lr_scheduler = optim_conf.get("lr_scheduler", [])
+            if lr_scheduler:
+                lr_schedulers = self.configure_schedulers([lr_scheduler])
+            return [optimizer], lr_schedulers, []
+
+        # multiple dictionaries
+        elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
+            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
+            # take only the lr schedulers that exist and are defined - not None
+            lr_schedulers = [
+                opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
+            ]
+            # take only the frequencies that exist and are defined - not None
+            optimizer_frequencies = [
+                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")
+            ]
+
+            # clean scheduler list
+            if lr_schedulers:
+                lr_schedulers = self.configure_schedulers(lr_schedulers)
+            # assert that if frequencies are present, they are given for all optimizers
+            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
+                raise ValueError("A frequency must be given to each optimizer.")
+            return optimizers, lr_schedulers, optimizer_frequencies
+
+        # single list or tuple, multiple optimizer
+        elif isinstance(optim_conf, (list, tuple)):
+            return list(optim_conf), [], []
+
+        # unknown configuration
+        else:
+            raise ValueError(
+                'Unknown configuration for model optimizers.'
+                ' Output from `model.configure_optimizers()` should either be:'
+                ' * single output, single `torch.optim.Optimizer`'
+                ' * single output, list of `torch.optim.Optimizer`'
+                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
+                ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
+                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
+                ' a list of `torch.optim.lr_scheduler`'
+                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
+
+    def configure_schedulers(self, schedulers: list):
+        # Convert each scheduler into a dict structure with relevant information
+        lr_schedulers = []
+        default_config = {'interval': 'epoch',  # default every epoch
+                          'frequency': 1,  # default every epoch/batch
+                          'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
+                          'monitor': 'val_loss'}  # default value to monitor for ReduceLROnPlateau
+        for scheduler in schedulers:
+            if isinstance(scheduler, dict):
+                if 'scheduler' not in scheduler:
+                    raise ValueError('Lr scheduler should have key `scheduler`'
+                                     ' with item being a lr scheduler')
+                scheduler['reduce_on_plateau'] = isinstance(
+                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)
+
+                lr_schedulers.append({**default_config, **scheduler})
+
+            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                lr_schedulers.append({**default_config, 'scheduler': scheduler,
+                                      'reduce_on_plateau': True})
+
+            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
+                lr_schedulers.append({**default_config, 'scheduler': scheduler})
+            else:
+                raise ValueError(f'Input {scheduler} to lr schedulers '
+                                 'is an invalid input.')
+        return lr_schedulers
+
+
+class _MockOptimizer(Optimizer):
+    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None`
+    is returned from `configure_optimizers`.
+    """
+
+    def __init__(self):
+        super().__init__([torch.zeros(1)], {})
+
+    def add_param_group(self, param_group):
+        pass  # Do Nothing
+
+    def load_state_dict(self, state_dict):
+        pass  # Do Nothing
+
+    def state_dict(self):
+        return {}  # Return Empty
+
+    def step(self, closure=None):
+        if closure is not None:
+            closure()
+
+    def zero_grad(self):
+        pass  # Do Nothing
+
+    def __repr__(self):
+        return 'No Optimizer'
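
The `_MockOptimizer` above is the private stand-in Lightning installs when `configure_optimizers` returns `None`. A short sketch of its behaviour follows; importing a private class directly is for illustration only and assumes an install that includes this commit.

from pytorch_lightning.trainer.optimizers import _MockOptimizer

opt = _MockOptimizer()

# Every mutating call is a no-op, so no parameters are ever updated.
opt.add_param_group({'params': []})
opt.zero_grad()
opt.step()  # does nothing when no closure is given

# A closure passed to `step` is still executed, so the usual
# training_step/backward flow keeps working.
calls = []
opt.step(closure=lambda: calls.append(True))
assert calls == [True]

assert opt.state_dict() == {}  # nothing to checkpoint
print(opt)  # -> No Optimizer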

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 91 deletions
@@ -1,16 +1,14 @@
+import distutils
 import inspect
 import os
 import sys
 import warnings
 from argparse import ArgumentParser
-from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence
-import distutils
+from typing import Union, Optional, List, Dict, Tuple, Iterable, Any

 import torch
 import torch.distributed as torch_distrib
 import torch.multiprocessing as mp
-from torch import optim
-from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

@@ -29,11 +27,12 @@
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
+from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
+from pytorch_lightning.trainer.supporters import TensorRunningMean
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import TensorRunningMean

 try:
     from apex import amp
@@ -54,6 +53,7 @@

 class Trainer(
     TrainerIOMixin,
+    TrainerOptimizersMixin,
     TrainerDPMixin,
     TrainerDDPMixin,
     TrainerLoggingMixin,
@@ -712,8 +712,7 @@ def fit(

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
-            self.init_optimizers(model.configure_optimizers())
+        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         self.run_pretrain_routine(model)

@@ -757,90 +756,6 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):

         model.test_dataloader = _PatchDataLoader(test_dataloaders)

-    def init_optimizers(
-            self,
-            optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
-    ) -> Tuple[List, List, List]:
-
-        # single output, single optimizer
-        if isinstance(optim_conf, Optimizer):
-            return [optim_conf], [], []
-
-        # two lists, optimizer + lr schedulers
-        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
-            optimizers, lr_schedulers = optim_conf
-            lr_schedulers = self.configure_schedulers(lr_schedulers)
-            return optimizers, lr_schedulers, []
-
-        # single dictionary
-        elif isinstance(optim_conf, dict):
-            optimizer = optim_conf["optimizer"]
-            lr_scheduler = optim_conf.get("lr_scheduler", [])
-            if lr_scheduler:
-                lr_schedulers = self.configure_schedulers([lr_scheduler])
-            return [optimizer], lr_schedulers, []
-
-        # multiple dictionaries
-        elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
-            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
-            # take only lr wif exists and ot they are defined - not None
-            lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")]
-            # take only freq wif exists and ot they are defined - not None
-            optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")]
-
-            # clean scheduler list
-            if lr_schedulers:
-                lr_schedulers = self.configure_schedulers(lr_schedulers)
-            # assert that if frequencies are present, they are given for all optimizers
-            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
-                raise ValueError("A frequency must be given to each optimizer.")
-            return optimizers, lr_schedulers, optimizer_frequencies
-
-        # single list or tuple, multiple optimizer
-        elif isinstance(optim_conf, (list, tuple)):
-            return list(optim_conf), [], []
-
-        # unknown configuration
-        else:
-            raise ValueError(
-                'Unknown configuration for model optimizers.'
-                ' Output from `model.configure_optimizers()` should either be:'
-                ' * single output, single `torch.optim.Optimizer`'
-                ' * single output, list of `torch.optim.Optimizer`'
-                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
-                ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
-                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
-                ' a list of `torch.optim.lr_scheduler`'
-                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')
-
-    def configure_schedulers(self, schedulers: list):
-        # Convert each scheduler into dict sturcture with relevant information
-        lr_schedulers = []
-        default_config = {'interval': 'epoch',  # default every epoch
-                          'frequency': 1,  # default every epoch/batch
-                          'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
-                          'monitor': 'val_loss'}  # default value to monitor for ReduceLROnPlateau
-        for scheduler in schedulers:
-            if isinstance(scheduler, dict):
-                if 'scheduler' not in scheduler:
-                    raise ValueError(f'Lr scheduler should have key `scheduler`',
-                                     ' with item being a lr scheduler')
-                scheduler['reduce_on_plateau'] = isinstance(
-                    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)
-
-                lr_schedulers.append({**default_config, **scheduler})
-
-            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
-                lr_schedulers.append({**default_config, 'scheduler': scheduler,
-                                      'reduce_on_plateau': True})
-
-            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
-                lr_schedulers.append({**default_config, 'scheduler': scheduler})
-            else:
-                raise ValueError(f'Input {scheduler} to lr schedulers '
-                                 'is a invalid input.')
-        return lr_schedulers
-
     def run_pretrain_routine(self, model: LightningModule):
         """Sanity check a few things before starting actual training.

tests/base/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
     LightTestMultipleOptimizersWithSchedulingMixin,
     LightTestOptimizersWithMixedSchedulingMixin,
     LightTestReduceLROnPlateauMixin,
+    LightTestNoneOptimizerMixin,
    LightZeroLenDataloader
 )
