Commit ec9da3b

integrate #7563
1 parent ca2ff4b

File tree: 1 file changed (+3 −5 lines)


pytorch_lightning/loops/batch_loop.py

Lines changed: 3 additions & 5 deletions
@@ -124,7 +124,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         # -------------------
         # calculate loss (train step + train step end)
         # -------------------
-
         # automatic_optimization=True: perform ddp sync only when performing optimizer_step
         # automatic_optimization=False: don't block synchronization here
         with self.block_ddp_sync_behaviour():
@@ -137,6 +136,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
@@ -448,10 +450,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                     "training_step returned None. If this was on purpose, ignore this warning..."
                 )
 
-            if len(self.trainer.optimizers) > 1:
-                # revert back to previous state
-                self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-
         return result
 
     def _check_finite(self, loss: torch.Tensor) -> None:
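For context: the relocated untoggle_optimizer call is part of Lightning's optimizer toggling. When a LightningModule defines more than one optimizer, toggle_optimizer freezes every parameter that does not belong to the optimizer currently stepping, and untoggle_optimizer reverts them to their previous requires_grad state. This commit moves the revert out of the training_step_and_backward closure into _run_optimization, so it runs once, immediately after optimizer_step completes. Below is a minimal plain-PyTorch sketch of that pattern; the toggle/untoggle helpers and the two-module setup are illustrative stand-ins, not the Lightning implementation.

import torch

# Two sub-modules, each trained by its own optimizer (a GAN-style setup).
gen = torch.nn.Linear(4, 4)
disc = torch.nn.Linear(4, 1)
optimizers = [
    torch.optim.SGD(gen.parameters(), lr=0.1),
    torch.optim.SGD(disc.parameters(), lr=0.1),
]
params_per_opt = [list(gen.parameters()), list(disc.parameters())]

def toggle(opt_idx):
    # Freeze every parameter that does not belong to the active optimizer,
    # remembering the previous requires_grad flags so they can be restored.
    previous = {}
    for i, params in enumerate(params_per_opt):
        for p in params:
            previous[p] = p.requires_grad
            p.requires_grad = (i == opt_idx)
    return previous

def untoggle(previous):
    # Revert back to the previous state (the role of untoggle_optimizer).
    for p, flag in previous.items():
        p.requires_grad = flag

x = torch.randn(8, 4)
for opt_idx, optimizer in enumerate(optimizers):
    previous = toggle(opt_idx)
    optimizer.zero_grad()
    loss = disc(gen(x)).mean()
    loss.backward()
    optimizer.step()
    # After this commit, the revert happens here, right after the step,
    # instead of inside the training_step_and_backward closure.
    untoggle(previous)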
