1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Moved the call to `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))

- Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))

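In effect, `untoggle_optimizer(opt_idx)` is now called by the training loop right after `optimizer_step` returns, rather than from inside the closure that is handed to the optimizer. A minimal sketch of why the call site matters (simplified control flow with illustrative helper names, not the actual Trainer internals):

```python
# Illustrative sketch only: a simplified stand-in for the Lightning training loop,
# not the real Trainer code. `make_closure` stands for the training_step + backward work.

def run_optimization_before(module, optimizer, opt_idx, make_closure):
    def closure():
        loss = make_closure()
        # old call site: untoggling only happens if the closure is actually executed
        module.untoggle_optimizer(opt_idx)
        return loss

    # a custom `optimizer_step` may decide not to call `optimizer.step(closure)` at all,
    # in which case the closure never runs and the other optimizers' parameters stay frozen
    module.optimizer_step(optimizer, opt_idx, closure)


def run_optimization_after(module, optimizer, opt_idx, make_closure):
    module.optimizer_step(optimizer, opt_idx, make_closure)
    # new call site: the toggled state is always restored, even if the closure never ran
    module.untoggle_optimizer(opt_idx)
```

The `training_loop.py` hunks below make exactly this move: the call is removed from `training_step_and_backward` (which runs inside the closure) and added to `_run_optimization` right after `self.optimizer_step(...)`.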
8 changes: 3 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -726,7 +726,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
# -------------------
# calculate loss (train step + train step end)
# -------------------

# automatic_optimization=True: perform ddp sync only when performing optimizer_step
# automatic_optimization=False: don't block synchronization here
with self.block_ddp_sync_behaviour():
@@ -739,6 +738,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
        else:
            if self.trainer.lightning_module.automatic_optimization:
                self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
                if len(self.trainer.optimizers) > 1:
                    # revert back to previous state
                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
            else:
                result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

@@ -839,10 +841,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
"training_step returned None. If this was on purpose, ignore this warning..."
)

if len(self.trainer.optimizers) > 1:
# revert back to previous state
self.trainer.lightning_module.untoggle_optimizer(opt_idx)

return result

def _check_finite(self, loss: torch.Tensor) -> None:
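For background, when several optimizers are configured Lightning toggles `requires_grad` so that each optimizer only updates its own parameters, and `untoggle_optimizer` restores the previous flags afterwards. A rough sketch of that mechanism (a simplified illustration under assumed behaviour, not the actual `LightningModule` implementation):

```python
import torch


class ToggleSketch(torch.nn.Module):
    """Toy module illustrating the toggle/untoggle idea with two sub-modules."""

    def __init__(self):
        super().__init__()
        self.layer_a = torch.nn.Linear(32, 2)
        self.layer_b = torch.nn.Linear(32, 2)
        self._saved_flags = None

    def toggle(self, active_params):
        # remember the current flags, then freeze every parameter that the
        # currently active optimizer does not own
        self._saved_flags = [p.requires_grad for p in self.parameters()]
        active = {id(p) for p in active_params}
        for p in self.parameters():
            p.requires_grad = id(p) in active

    def untoggle(self):
        # restore the flags recorded by `toggle`; if this never runs, the other
        # optimizer's parameters remain frozen for the following steps
        for p, flag in zip(self.parameters(), self._saved_flags):
            p.requires_grad = flag


# usage: freeze everything except layer_a, then restore the original state
m = ToggleSketch()
m.toggle(m.layer_a.parameters())
assert not next(m.layer_b.parameters()).requires_grad
m.untoggle()
assert next(m.layer_b.parameters()).requires_grad
```

The test added below relies on this behaviour: when the second optimizer skips its `optimizer.step`, the loop still needs to untoggle so the first optimizer can keep training on the next batch.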
65 changes: 65 additions & 0 deletions tests/trainer/optimization/test_multiple_optimizers.py
@@ -168,3 +168,68 @@ def training_step(self, batch, batch_idx):

    with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
        trainer.fit(TestModel())


def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
    """
    This test ensures a custom optimizer_step works,
    even when optimizer.step is not called for a particular optimizer.
    """

    class TestModel(BoringModel):
        training_step_called = [0, 0]
        optimizer_step_called = [0, 0]

        def __init__(self):
            super().__init__()
            self.layer_a = torch.nn.Linear(32, 2)
            self.layer_b = torch.nn.Linear(32, 2)

        def configure_optimizers(self):
            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
            return opt_a, opt_b

        def training_step(self, batch, batch_idx, optimizer_idx):
            self.training_step_called[optimizer_idx] += 1
            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
            return loss

        def training_epoch_end(self, outputs) -> None:
            # outputs should be a list with an entry per optimizer
            assert len(outputs) == 2

        def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            **_,
        ):
            # update first optimizer every step
            if optimizer_idx == 0:
                self.optimizer_step_called[optimizer_idx] += 1
                optimizer.step(closure=optimizer_closure)

            # update second optimizer every 2 steps
            if optimizer_idx == 1:
                if batch_idx % 2 == 0:
                    self.optimizer_step_called[optimizer_idx] += 1
                    optimizer.step(closure=optimizer_closure)

    model = TestModel()
    model.val_dataloader = None

    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=4,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)
    assert model.training_step_called == [4, 2]
    assert model.optimizer_step_called == [4, 2]