
Commit 032cdad

fix: avoid potential mismatched toggling of optimizer
Refs #7405

chore: update CHANGELOG
[pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
fix: resolve a conflict
chore: update changelog
1 parent 20f6337 commit 032cdad
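
With multiple optimizers, Lightning toggles the parameters of the optimizer currently being stepped and must untoggle them afterwards; this commit pairs the `untoggle_optimizer(opt_idx)` call with the optimizer step instead of leaving it inside the closure. A minimal sketch of that pairing, assuming a `module` exposing Lightning's `toggle_optimizer`/`untoggle_optimizer` methods and a hypothetical `run_step` helper (not the actual Trainer code):

```python
def run_step(module, optimizer, opt_idx, closure, num_optimizers):
    # Hypothetical helper illustrating the intended pairing: toggle before
    # stepping, untoggle right after the step returns, so the two calls stay
    # balanced even if the optimizer evaluates `closure` several times
    # (or not at all) inside optimizer.step().
    if num_optimizers > 1:
        module.toggle_optimizer(optimizer, opt_idx)
    optimizer.step(closure)
    if num_optimizers > 1:
        # revert back to previous state
        module.untoggle_optimizer(opt_idx)
```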

File tree

2 files changed: +4, -5 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)
 
 - Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))
 
pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 5 deletions
@@ -726,7 +726,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
             # -------------------
             # calculate loss (train step + train step end)
             # -------------------
-
             # automatic_optimization=True: perform ddp sync only when performing optimizer_step
             # automatic_optimization=False: don't block synchronization here
             with self.block_ddp_sync_behaviour():
@@ -739,6 +738,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
@@ -839,10 +841,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                         "training_step returned None. If this was on purpose, ignore this warning..."
                     )
 
-            if len(self.trainer.optimizers) > 1:
-                # revert back to previous state
-                self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-
         return result
 
     def _check_finite(self, loss: torch.Tensor) -> None:
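
The diff above moves `untoggle_optimizer(opt_idx)` out of `training_step_and_backward` (the body of the optimizer closure) and runs it once, right after `self.optimizer_step(...)` returns. A plausible motivation is that some optimizers evaluate their closure more than once per step, so state reverted inside the closure can fall out of sync with the single toggle. The standalone PyTorch snippet below (not Lightning code) simply demonstrates that `torch.optim.LBFGS` calls its closure repeatedly within one `step`; the `toggled` flag is a made-up stand-in for the toggle/untoggle bookkeeping:

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), max_iter=3)
x, y = torch.randn(8, 4), torch.randn(8, 1)

closure_calls = 0

def closure():
    # LBFGS may invoke this several times during a single optimizer.step()
    global closure_calls
    closure_calls += 1
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

toggled = True            # stand-in for toggle_optimizer(...) before the step
optimizer.step(closure)   # the closure above is evaluated one or more times here
toggled = False           # untoggle exactly once, after the step returns

print(f"closure ran {closure_calls} times for a single optimizer.step()")
```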
