
Commit 0302b8b

akihironitta, awaelchli, and tchaton authored
Disable lr_scheduler.step() in manual optimization (#6825)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
1 parent 14e6b46 commit 0302b8b
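In practice this means the Trainer no longer steps schedulers when `automatic_optimization` is turned off; stepping moves into `training_step`. Below is a minimal sketch of that user-side pattern. The model and its loss are made up for illustration, while `self.optimizers()`, `self.lr_schedulers()`, `self.manual_backward()`, and `trainer.is_last_batch` are the APIs referenced by this commit and its CHANGELOG entries.

import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opting out of automatic optimization now also opts out of
        # Trainer-driven lr_scheduler.step() calls
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # schedulers are stepped explicitly, e.g. once per epoch,
        # using the Trainer attribute added in this commit
        sch = self.lr_schedulers()
        if self.trainer.is_last_batch:
            sch.step()
        return loss

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
        return [opt], [sch]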

7 files changed, +88 -5 lines


CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -96,6 +96,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764))


+- Added `is_last_batch` attribute to `Trainer` ([#6825](https://github.com/PyTorchLightning/pytorch-lightning/pull/6825))
+
+
 - Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567))


@@ -129,6 +132,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


+- Disabled `lr_scheduler.step()` in manual optimization ([#6825](https://github.com/PyTorchLightning/pytorch-lightning/pull/6825))
+
+
 - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))

docs/source/common/optimizers.rst

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ To manually optimize, do the following:
 * ``optimizer.step()`` to update your model parameters

 Here is a minimal example of manual optimization.
-
+
 .. testcode:: python

     from pytorch_lightning import LightningModule

pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
             interval: either 'epoch' or 'step'.
             monitor_metrics: dict of possible values to monitor
         """
-        if not self.trainer.lr_schedulers:
+        if not self.trainer.lr_schedulers or not self.trainer.train_loop.automatic_optimization:
            return

        for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):

pytorch_lightning/trainer/optimizers.py

Lines changed: 19 additions & 3 deletions
@@ -80,7 +80,8 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
                 ' * A list of the previously described dict format, with an optional "frequency" key (int)'
             )

-        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+        is_manual_optimization = not self.train_loop.automatic_optimization
+        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
         _validate_scheduler_optimizer(optimizers, lr_schedulers)

         return optimizers, lr_schedulers, optimizer_frequencies
@@ -98,8 +99,13 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
             for opt_idx, opt in enumerate(self.optimizers)
         }

-    def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
-        # Convert each scheduler into dict structure with relevant information
+    def configure_schedulers(
+        self,
+        schedulers: list,
+        monitor: Optional[str],
+        is_manual_optimization: bool,
+    ) -> List[Dict[str, Any]]:
+        """Convert each scheduler into dict structure with relevant information"""
         lr_schedulers = []
         default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
@@ -117,6 +123,16 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                         f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                         f' but is "{scheduler["interval"]}"'
                     )
+                if is_manual_optimization:
+                    invalid_keys = {'interval', 'frequency', 'reduce_on_plateau', 'monitor', 'strict'}
+                    keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
+
+                    if keys_to_warn:
+                        rank_zero_warn(
+                            f'The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored.'
+                            ' You need to call `lr_scheduler.step()` manually in manual optimization.',
+                            RuntimeWarning,
+                        )

                 scheduler['reduce_on_plateau'] = isinstance(
                     scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
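A usage note, not part of the diff: the warning only fires for scheduler dicts that carry the listed keys, so under manual optimization the quiet option is to return the scheduler without Lightning-managed keys and step it yourself. A small sketch, assuming a model with a `self.layer` module:

def configure_optimizers(self):
    opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
    sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
    # returning {"scheduler": sch, "interval": "epoch"} would trigger the
    # RuntimeWarning above, since "interval" is ignored without automatic
    # optimization; return the scheduler plainly and call sch.step() manually
    return [opt], [sch]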

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 0 deletions
@@ -472,6 +472,7 @@ def run_training_epoch(self):
         for batch_idx, (batch, is_last_batch) in train_dataloader:

             self.trainer.batch_idx = batch_idx
+            self.trainer.is_last_batch = is_last_batch

             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 38 additions & 0 deletions
@@ -1170,3 +1170,41 @@ def configure_optimizers(self):
     )

     trainer.fit(model)
+
+
+def test_lr_scheduler_step_not_called(tmpdir):
+    """
+    Test `lr_scheduler.step()` is not called in manual optimization.
+    """
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            opt = self.optimizers()
+
+            output = self(batch)
+            loss = self.loss(batch, output)
+
+            opt.zero_grad()
+            self.manual_backward(loss)
+            opt.step()
+
+    model = TestModel()
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        fast_dev_run=2,
+    )
+
+    with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
+        trainer.fit(model)
+
+    # If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
+    # `.step()` is called once during its instantiation.
+    # Thus, the call count should be 1, not 0.
+    assert lr_step.call_count == 1
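For context on the expected count, here is a standalone sketch (relying only on stock `torch.optim` behaviour, not on anything in this diff) showing that merely constructing a `StepLR` already records one `step()` call:

import torch
from unittest.mock import patch

layer = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
    torch.optim.lr_scheduler.StepLR(opt, step_size=1)

# the _LRScheduler base class calls self.step() at the end of __init__,
# so one call is recorded before any training step has run
assert lr_step.call_count == 1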

tests/trainer/optimization/test_optimizers.py

Lines changed: 22 additions & 0 deletions
@@ -476,3 +476,25 @@ def configure_optimizers(self):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     with pytest.raises(MisconfigurationException, match="attatched with an optimizer that wasn't returned"):
         trainer.fit(model)
+
+
+def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
+    """
+    Test warning when invalid scheduler keys are provided in manual optimization.
+    """
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def configure_optimizers(self):
+            opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
+            return [opt], [{"scheduler": sch, "interval": "epoch"}]
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.warns(RuntimeWarning, match='the keys will be ignored'):
+        trainer.fit(model)
