
Commit 01109cd

Authored by Lucklyric, pre-commit-ci[bot], ananthsub, and carmocca
Fix/mismatched toggle optimizer (#7563)
* fix: avoid potential mismatched toggling of optimizer
  Refs #7405
  chore: update CHANGELOG
  [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
  fix: resolve a conflict
  chore: update changelog
* feat: add a test that fails in master
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix typo in tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: ananthsub <[email protected]>
* Polish tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* Polish tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* fix: change placeholder in optimizer_step from positional args to keyword args

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 2242423 commit 01109cd
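
The gist of the change: `untoggle_optimizer(opt_idx)` used to be called inside `training_step_and_backward`, i.e. inside the closure that the training loop hands to `optimizer.step`. A user-defined `optimizer_step` that skips `optimizer.step(closure=...)` for a particular optimizer therefore never ran the closure, and the parameters that had been toggled off for that optimizer stayed frozen. The sketch below is not the Lightning source; it is a minimal, simplified rendering of the before/after control flow, with `run_user_optimizer_step` standing in for the user's `optimizer_step` hook.

# Minimal sketch (assumed/simplified, not the actual Lightning implementation)
# of why untoggling inside the closure can be skipped, and how moving it out fixes that.

def buggy_flow(module, optimizer, opt_idx, closure, run_user_optimizer_step):
    module.toggle_optimizer(optimizer, opt_idx)        # freeze params owned by other optimizers

    def wrapped_closure():
        loss = closure()                               # training_step + backward
        module.untoggle_optimizer(opt_idx)             # only reached if the closure actually runs
        return loss

    # a custom optimizer_step may never call optimizer.step(closure=wrapped_closure),
    # so the untoggle above may never happen -> mismatched requires_grad state
    run_user_optimizer_step(optimizer, wrapped_closure)


def fixed_flow(module, optimizer, opt_idx, closure, run_user_optimizer_step):
    module.toggle_optimizer(optimizer, opt_idx)
    run_user_optimizer_step(optimizer, closure)        # the closure may or may not run
    module.untoggle_optimizer(opt_idx)                 # always reverts the toggle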

3 files changed: +69 -5 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
 
 - Changed the `Trainer`'s `checkpoint_callback` argument to allow only boolean values ([#7539](https://github.com/PyTorchLightning/pytorch-lightning/pull/7539))
 

pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 5 deletions
@@ -724,7 +724,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
             # -------------------
             # calculate loss (train step + train step end)
             # -------------------
-
             # automatic_optimization=True: perform ddp sync only when performing optimizer_step
             # automatic_optimization=False: don't block synchronization here
             with self.block_ddp_sync_behaviour():
@@ -737,6 +736,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
@@ -837,10 +839,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                         "training_step returned None. If this was on purpose, ignore this warning..."
                     )
 
-        if len(self.trainer.optimizers) > 1:
-            # revert back to previous state
-            self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-
         return result
 
     def _check_finite(self, loss: torch.Tensor) -> None:
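
For context on what the "previous state" being reverted is: toggling marks every parameter that belongs only to the other optimizers as `requires_grad=False`, and untoggling makes them trainable again. The snippet below is a simplified, assumed illustration of that idea (the real `LightningModule.toggle_optimizer`/`untoggle_optimizer` also remember and restore the user's original `requires_grad` flags); it only shows why a missed untoggle leaves the next optimizer's parameters frozen.

# Simplified illustration (not the Lightning source) of toggle/untoggle semantics
# when training with multiple optimizers.
import torch


def toggle(optimizers, opt_idx):
    """Freeze every parameter that is not owned by optimizers[opt_idx]."""
    current = {id(p) for g in optimizers[opt_idx].param_groups for p in g["params"]}
    for i, opt in enumerate(optimizers):
        if i == opt_idx:
            continue
        for g in opt.param_groups:
            for p in g["params"]:
                if id(p) not in current:
                    p.requires_grad = False


def untoggle(optimizers, opt_idx):
    """Unfreeze the other optimizers' parameters again (simplified: the real
    hook restores the exact requires_grad state saved when toggling)."""
    for i, opt in enumerate(optimizers):
        if i == opt_idx:
            continue
        for g in opt.param_groups:
            for p in g["params"]:
                p.requires_grad = True


# Usage sketch: if untoggle(opts, 0) were skipped, layer_b would stay frozen
# and its backward pass would produce no gradients.
layer_a, layer_b = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
opts = [torch.optim.SGD(layer_a.parameters(), lr=0.1), torch.optim.SGD(layer_b.parameters(), lr=0.1)]
toggle(opts, 0)
assert all(not p.requires_grad for p in layer_b.parameters())
untoggle(opts, 0)
assert all(p.requires_grad for p in layer_b.parameters())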

tests/trainer/optimization/test_multiple_optimizers.py

Lines changed: 65 additions & 0 deletions
@@ -168,3 +168,68 @@ def training_step(self, batch, batch_idx):
 
     with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
         trainer.fit(TestModel())
+
+
+def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
+    """
+    This test ensures custom optimizer_step works,
+    even when optimizer.step is not called for a particular optimizer.
+    """
+
+    class TestModel(BoringModel):
+        training_step_called = [0, 0]
+        optimizer_step_called = [0, 0]
+
+        def __init__(self):
+            super().__init__()
+            self.layer_a = torch.nn.Linear(32, 2)
+            self.layer_b = torch.nn.Linear(32, 2)
+
+        def configure_optimizers(self):
+            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
+            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
+            return opt_a, opt_b
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.training_step_called[optimizer_idx] += 1
+            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
+            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+            return loss
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def optimizer_step(
+            self,
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_idx,
+            optimizer_closure,
+            **_,
+        ):
+            # update first optimizer every step
+            if optimizer_idx == 0:
+                self.optimizer_step_called[optimizer_idx] += 1
+                optimizer.step(closure=optimizer_closure)
+
+            # update second optimizer every 2 steps
+            if optimizer_idx == 1:
+                if batch_idx % 2 == 0:
+                    self.optimizer_step_called[optimizer_idx] += 1
+                    optimizer.step(closure=optimizer_closure)
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=4,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert model.training_step_called == [4, 2]
+    assert model.optimizer_step_called == [4, 2]
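
A note on the assertions at the end of the new test: the closure (training step + backward) only runs when it is passed to `optimizer.step`, so skipping the step for the second optimizer on odd batches also skips its training step, which is why `training_step_called == [4, 2]`. If the goal were to run the training step but skip only the parameter update, one option (an assumption about the hook's contract, not part of this commit) is to invoke the closure directly:

# Hypothetical variant of the test's optimizer_step hook (not part of the commit):
# run the closure on every batch, but update the second optimizer only every 2 steps.
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
    if optimizer_idx == 1 and batch_idx % 2 != 0:
        optimizer_closure()                        # training_step + backward, no parameter update
    else:
        optimizer.step(closure=optimizer_closure)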
