
Commit 6ad1b4e

Merge branch 'master' into feature/5311-flatten-dict
2 parents 4fb5040 + 062800a

File tree: 9 files changed (+168, -50 lines)


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a check for optimizer attached to lr_scheduler ([#5338](https://github.com/PyTorchLightning/pytorch-lightning/pull/5338))
+
 - Added `resume_from_checkpoint` accept non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))
 
 
@@ -23,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))
+
 - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))
 

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 38 additions & 26 deletions
@@ -33,11 +33,11 @@ class LearningRateMonitor(Callback):
     Automatically monitor and logs learning rate for learning rate schedulers during training.
 
     Args:
-        logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
-            at the same interval, set to `None` to log at individual interval
-            according to the `interval` key of each scheduler. Defaults to ``None``.
+        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
+            at the same interval, set to ``None`` to log at individual interval
+            according to the ``interval`` key of each scheduler. Defaults to ``None``.
         log_momentum: option to also log the momentum values of the optimizer, if the optimizer
-            has the `momentum` attribute. Defaults to ``False``.
+            has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
 
     Example::
 
@@ -47,17 +47,19 @@ class LearningRateMonitor(Callback):
     >>> trainer = Trainer(callbacks=[lr_monitor])
 
     Logging names are automatically determined based on optimizer class name.
-    In case of multiple optimizers of same type, they will be named `Adam`,
-    `Adam-1` etc. If a optimizer has multiple parameter groups they will
-    be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
-    `name` keyword in the construction of the learning rate schdulers
+    In case of multiple optimizers of same type, they will be named ``Adam``,
+    ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
+    be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
+    ``name`` keyword in the construction of the learning rate schdulers
 
     Example::
 
         def configure_optimizer(self):
            optimizer = torch.optim.Adam(...)
-           lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
-                           'name': 'my_logging_name'}
+           lr_scheduler = {
+               'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
+               'name': 'my_logging_name'
+           }
            return [optimizer], [lr_scheduler]
 
     """
@@ -80,16 +82,28 @@ def on_train_start(self, trainer, *args, **kwargs):
         """
         if not trainer.logger:
             raise MisconfigurationException(
-                'Cannot use LearningRateMonitor callback with Trainer that has no logger.'
+                'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'
             )
 
         if not trainer.lr_schedulers:
             rank_zero_warn(
-                'You are using LearningRateMonitor callback with models that'
+                'You are using `LearningRateMonitor` callback with models that'
                 ' have no learning rate schedulers. Please see documentation'
                 ' for `configure_optimizers` method.', RuntimeWarning
             )
 
+        if self.log_momentum:
+            def _check_no_key(key):
+                return any(
+                    key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers
+                )
+
+            if _check_no_key('momentum') and _check_no_key('betas'):
+                rank_zero_warn(
+                    "You have set log_momentum=True, but some optimizers do not"
+                    " have momentum. This will log a value 0 for the momentum.", RuntimeWarning
+                )
+
         # Find names for schedulers
         names = self._find_names(trainer.lr_schedulers)
 
@@ -121,19 +135,17 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
 
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
             if scheduler['interval'] == interval or interval == 'any':
-                param_groups = scheduler['scheduler'].optimizer.param_groups
-                if len(param_groups) != 1:
-                    for i, pg in enumerate(param_groups):
-                        lr = self._extract_lr(param_group=pg, name=f'{name}/pg{i + 1}')
-                        latest_stat.update(lr)
-                        momentum = self._extract_momentum(param_group=pg, name=f'{name}-momentum/pg{i + 1}')
-                        latest_stat.update(momentum)
-
-                else:
-                    pg = param_groups[0]
-                    lr = self._extract_lr(param_group=pg, name=name)
+                opt = scheduler['scheduler'].optimizer
+                param_groups = opt.param_groups
+                use_betas = 'betas' in opt.defaults
+
+                for i, pg in enumerate(param_groups):
+                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
+                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
                     latest_stat.update(lr)
-                    momentum = self._extract_momentum(param_group=pg, name=f'{name}-momentum')
+                    momentum = self._extract_momentum(
+                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
+                    )
                     latest_stat.update(momentum)
 
         return latest_stat
@@ -143,11 +155,11 @@ def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
         self.lrs[name].append(lr)
         return {name: lr}
 
-    def _extract_momentum(self, param_group, name: str) -> Dict[str, float]:
+    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
        if not self.log_momentum:
            return {}
 
-        momentum = param_group.get('momentum')
+        momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
        self.last_momentum_values[name] = momentum
        return {name: momentum}
 
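The substantive change in this file is that momentum can now be read from adaptive optimizers, which expose a `betas` tuple instead of a `momentum` key. A minimal standalone sketch of that lookup (not the library code itself, just the same `defaults`/`param_groups` logic on plain torch objects):

import torch
from torch import nn, optim

params = [nn.Parameter(torch.zeros(1))]
adam = optim.Adam(params, lr=1e-2, betas=(0.9, 0.999))
sgd = optim.SGD(params, lr=1e-2, momentum=0.8)

def extract_momentum(optimizer, param_group):
    # Adaptive optimizers (Adam, AdamW, ...) carry 'betas' in their defaults,
    # so beta1 stands in for momentum; plain SGD carries 'momentum';
    # anything with neither falls back to 0.
    use_betas = 'betas' in optimizer.defaults
    return param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)

print(extract_momentum(adam, adam.param_groups[0]))  # 0.9 (beta1)
print(extract_momentum(sgd, sgd.param_groups[0]))    # 0.8

With this in place, `LearningRateMonitor(log_momentum=True)` logs keys such as `lr-Adam-momentum`, as exercised by the tests further down.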

pytorch_lightning/core/lightning.py

Lines changed: 14 additions & 6 deletions
@@ -14,15 +14,15 @@
 
 """nn.Module with additional great features."""
 
-from abc import ABC
-from argparse import Namespace
 import collections
 import copy
 import inspect
 import os
-from pathlib import Path
 import re
 import tempfile
+from abc import ABC
+from argparse import Namespace
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
@@ -1327,9 +1327,17 @@ def tbptt_split_batch(self, batch, split_size):
 
         return splits
 
-    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
-        model_summary = ModelSummary(self, mode=mode)
-        log.info("\n" + str(model_summary))
+    def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
+        model_summary = None
+
+        if mode in ModelSummary.MODES:
+            model_summary = ModelSummary(self, mode=mode)
+            log.info("\n" + str(model_summary))
+        elif mode is not None:
+            raise MisconfigurationException(
+                f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}"
+            )
+
         return model_summary
 
     def freeze(self) -> None:
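For context, the reworked `summarize` treats `None` as "skip the summary" and raises for any string outside `ModelSummary.MODES`. A rough usage sketch of the post-commit behavior, built around a throwaway `TinyModel` (hypothetical, not part of the diff):

from torch import nn
from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.utilities.exceptions import MisconfigurationException

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

model = TinyModel()

summary = model.summarize(mode='top')      # logs the summary and returns a ModelSummary
assert isinstance(summary, ModelSummary)

assert model.summarize(mode=None) is None  # summary silently skipped, nothing logged

try:
    model.summarize(mode='typo')           # any other value is rejected
except MisconfigurationException as err:
    print(err)  # "`mode` can be None, <modes>, got typo"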

pytorch_lightning/trainer/optimizers.py

Lines changed: 9 additions & 0 deletions
@@ -75,7 +75,9 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
                 ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
                 ' * A list of the previously described dict format, with an optional "frequency" key (int)'
             )
+
         lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+        _validate_scheduler_optimizer(optimizers, lr_schedulers)
 
         return optimizers, lr_schedulers, optimizer_frequencies
 
@@ -183,3 +185,10 @@ def zero_grad(self):
 
     def __repr__(self):
         return 'No Optimizer'
+
+
+def _validate_scheduler_optimizer(optimizers, lr_schedulers):
+    if any(sch['scheduler'].optimizer not in optimizers for sch in lr_schedulers):
+        raise MisconfigurationException(
+            "Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
+        )
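The new `_validate_scheduler_optimizer` guard means every scheduler must wrap one of the optimizers actually returned from `configure_optimizers`. A hedged sketch of a configuration this check is meant to reject (the model and names are illustrative, not taken from the diff):

from torch import nn, optim
from pytorch_lightning import LightningModule

class MismatchedSchedulerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        used = optim.SGD(self.parameters(), lr=1e-2)
        forgotten = optim.Adam(self.parameters(), lr=1e-3)
        # The scheduler wraps `forgotten`, which is never returned below, so
        # init_optimizers should now raise a MisconfigurationException at fit time.
        scheduler = optim.lr_scheduler.StepLR(forgotten, step_size=1)
        return [used], [scheduler]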

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -311,7 +311,6 @@ def __init__(
         self.plugin_connector = PluginConnector(self)
 
         # training state
-        self.weights_summary = weights_summary
         self.model = None
         self.shown_warnings = set()
 
@@ -374,7 +373,8 @@ def __init__(
             max_steps,
             min_steps,
             num_sanity_val_steps,
-            automatic_optimization
+            automatic_optimization,
+            weights_summary,
         )
         self.evaluation_loop.on_trainer_init()
 

pytorch_lightning/trainer/training_loop.py

Lines changed: 16 additions & 6 deletions
@@ -49,7 +49,14 @@ def __init__(self, trainer):
         self._cur_grad_norm_dict = None
 
     def on_trainer_init(
-        self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization
+        self,
+        max_epochs,
+        min_epochs,
+        max_steps,
+        min_steps,
+        num_sanity_val_steps,
+        automatic_optimization,
+        weights_summary,
     ):
         self.trainer.global_step = 0
         self.trainer.current_epoch = 0
@@ -73,6 +80,12 @@ def on_trainer_init(
         else:
             self.trainer.num_sanity_val_steps = num_sanity_val_steps
 
+        self.trainer.weights_summary = weights_summary
+        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
+            raise MisconfigurationException(
+                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
+            )
+
     @property
     def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
@@ -161,11 +174,8 @@ def setup_training(self, model: LightningModule):
         ref_model.on_pretrain_routine_start()
 
         # print model summary
-        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
-            if self.trainer.weights_summary in ModelSummary.MODES:
-                ref_model.summarize(mode=self.trainer.weights_summary)
-            else:
-                raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))
+        if self.trainer.is_global_zero and not self.trainer.testing:
+            ref_model.summarize(mode=self.trainer.weights_summary)
 
         # track model now.
         # if cluster resets state, the model will update with the saved weights
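Net effect of the `trainer.py` and `training_loop.py` changes: `weights_summary` is validated once in `on_trainer_init`, so a bad value fails at `Trainer` construction instead of partway into `setup_training`. A small sketch of the expected behavior after this commit (assuming the release's default `ModelSummary.MODES` and a standard local environment):

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Trainer(weights_summary='top')   # fine: one of ModelSummary.MODES
Trainer(weights_summary=None)    # fine: the summary is simply skipped

# Invalid values are now rejected before fit() is ever called.
with pytest.raises(MisconfigurationException, match='`weights_summary` can be None'):
    Trainer(weights_summary='typo')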

tests/callbacks/test_lr_monitor.py

Lines changed: 57 additions & 9 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
+from torch import optim
 
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
@@ -47,19 +48,34 @@ def test_lr_monitor_single_lr(tmpdir):
         'Names of learning rates not set correctly'
 
 
-def test_lr_monitor_single_lr_with_momentum(tmpdir):
-    """ Test that learning rates and momentum are extracted and logged for single lr scheduler. """
-    tutils.reset_seed()
+@pytest.mark.parametrize('opt', ['SGD', 'Adam'])
+def test_lr_monitor_single_lr_with_momentum(tmpdir, opt):
+    """
+    Test that learning rates and momentum are extracted and logged for single lr scheduler.
+    """
+    class LogMomentumModel(BoringModel):
+        def __init__(self, opt):
+            super().__init__()
+            self.opt = opt
 
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__onecycle_scheduler
+        def configure_optimizers(self):
+            if self.opt == 'SGD':
+                opt_kwargs = {'momentum': 0.9}
+            elif self.opt == 'Adam':
+                opt_kwargs = {'betas': (0.9, 0.999)}
 
+            optimizer = getattr(optim, self.opt)(self.parameters(), lr=1e-2, **opt_kwargs)
+            lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, total_steps=10_000)
+            return [optimizer], [lr_scheduler]
+
+    model = LogMomentumModel(opt=opt)
     lr_monitor = LearningRateMonitor(log_momentum=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
-        limit_val_batches=0.1,
-        limit_train_batches=0.5,
+        limit_val_batches=2,
+        limit_train_batches=5,
+        log_every_n_steps=1,
         callbacks=[lr_monitor],
     )
     result = trainer.fit(model)
@@ -69,7 +85,39 @@ def test_lr_monitor_single_lr_with_momentum(tmpdir):
         'Expected momentum to be logged'
     assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers), \
         'Number of momentum values logged does not match number of lr schedulers'
-    assert all([k in ['lr-SGD-momentum'] for k in lr_monitor.last_momentum_values.keys()]), \
+    assert all(k == f'lr-{opt}-momentum' for k in lr_monitor.last_momentum_values.keys()), \
+        'Names of momentum values not set correctly'
+
+
+def test_log_momentum_no_momentum_optimizer(tmpdir):
+    """
+    Test that if optimizer doesn't have momentum then a warning is raised with log_momentum=True.
+    """
+    class LogMomentumModel(BoringModel):
+        def configure_optimizers(self):
+            optimizer = optim.ASGD(self.parameters(), lr=1e-2)
+            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = LogMomentumModel()
+    lr_monitor = LearningRateMonitor(log_momentum=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_val_batches=2,
+        limit_train_batches=5,
+        log_every_n_steps=1,
+        callbacks=[lr_monitor],
+    )
+    with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."):
+        result = trainer.fit(model)
+        assert result
+
+    assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), \
+        'Expected momentum to be logged'
+    assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers), \
+        'Number of momentum values logged does not match number of lr schedulers'
+    assert all(k == 'lr-ASGD-momentum' for k in lr_monitor.last_momentum_values.keys()), \
         'Names of momentum values not set correctly'
 
 
@@ -105,7 +153,7 @@ def test_lr_monitor_no_logger(tmpdir):
         logger=False
     )
 
-    with pytest.raises(MisconfigurationException, match='Trainer that has no logger'):
+    with pytest.raises(MisconfigurationException, match='`Trainer` that has no logger'):
         trainer.fit(model)
 

tests/core/test_memory.py

Lines changed: 11 additions & 1 deletion
@@ -15,8 +15,9 @@
 import torch
 import torch.nn as nn
 
-from pytorch_lightning import LightningModule
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.models import ParityModuleRNN
 
 
@@ -68,6 +69,15 @@ def forward(self, x):
         return self.reduce(self.embed(x))
 
 
+def test_invalid_weights_summmary():
+    """ Test that invalid value for weights_summary raises an error. """
+    with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
+        UnorderedModel().summarize(mode='temp')
+
+    with pytest.raises(MisconfigurationException, match='`weights_summary` can be None, .* got temp'):
+        Trainer(weights_summary='temp')
+
+
 @pytest.mark.parametrize(['mode'], [
     pytest.param(ModelSummary.MODE_FULL),
     pytest.param(ModelSummary.MODE_TOP),
