
Commit 01c31ae

CAIQT and carmocca authored
Fix LightningModule.{un,}toggle_model when only 1 optimizer is used (#12088)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 17bb815 commit 01c31ae

3 files changed: 22 additions, 2 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -666,6 +666,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the mid-epoch warning call while resuming training ([#11556](https://github.com/PyTorchLightning/pytorch-lightning/pull/11556))


+- Fixed `LightningModule.{un,}toggle_model` when only 1 optimizer is used ([#12088](https://github.com/PyTorchLightning/pytorch-lightning/pull/12088))
+
+
 - Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))


pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
@@ -1384,7 +1384,7 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], opti
         # Iterate over all optimizer parameters to preserve their `requires_grad` information
         # in case these are pre-defined during `configure_optimizers`
         param_requires_grad_state = {}
-        for opt in self.optimizers(use_pl_optimizer=False):
+        for opt in self.trainer.optimizers:
             for group in opt.param_groups:
                 for param in group["params"]:
                     # If a param already appear in param_requires_grad_state, continue

@@ -1408,7 +1408,7 @@ def untoggle_optimizer(self, optimizer_idx: int) -> None:
         Args:
             optimizer_idx: The index of the optimizer to untoggle.
         """
-        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+        for opt_idx, opt in enumerate(self.trainer.optimizers):
             if optimizer_idx != opt_idx:
                 for group in opt.param_groups:
                     for param in group["params"]:

tests/core/test_lightning_module.py

Lines changed: 17 additions & 0 deletions
@@ -87,6 +87,23 @@ def test_property_loggers(tmpdir):
     assert model.loggers == [logger]


+def test_1_optimizer_toggle_model():
+    """Test toggle_model runs when only one optimizer is used."""
+    model = BoringModel()
+    trainer = Mock()
+    model.trainer = trainer
+    params = model.parameters()
+    optimizer = torch.optim.SGD(params, lr=0.1)
+    trainer.optimizers = [optimizer]
+
+    assert not model._param_requires_grad_state
+    # toggle optimizer was failing with a single optimizer
+    model.toggle_optimizer(optimizer, 0)
+    assert model._param_requires_grad_state
+    model.untoggle_optimizer(0)
+    assert not model._param_requires_grad_state
+
+
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx, optimizer_idx=None):
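
For context, a hedged usage sketch (not part of this commit): under manual optimization, `toggle_optimizer`/`untoggle_optimizer` can be called directly from `training_step`, and after this fix that also works when `configure_optimizers` returns a single optimizer. The model and hyperparameters below are illustrative assumptions:

    import torch
    from pytorch_lightning import LightningModule


    class ManualModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            # Previously raised with a single optimizer; fixed by this commit.
            self.toggle_optimizer(opt, 0)
            loss = self.layer(batch).sum()
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
            self.untoggle_optimizer(0)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)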
