Commit ced2c94

fix missing call to untoggle_optimizer when accumulating gradients (#8284)

* add fix
* toggle test
* re-structure
* update changelog
* update comment
* remove debugging assertion

1 parent d7a0786 commit ced2c94
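For context on why the missing call matters: with multiple optimizers, Lightning toggles the active optimizer's parameters on and freezes everyone else's at the start of each optimization, and `untoggle_optimizer` restores the previous state afterwards. The sketch below is a rough, standalone illustration of that toggle/untoggle behaviour in plain PyTorch (it is not the Lightning implementation); before this fix, the restore step was skipped on gradient-accumulation batches because it lived inside the optimizer-step branch.

```python
import torch
from torch import nn

# two sub-modules, each trained by its own optimizer (e.g. a GAN-style setup)
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 4)
optimizers = [torch.optim.SGD(gen.parameters(), lr=0.1), torch.optim.SGD(disc.parameters(), lr=0.1)]


def toggle(opt_idx: int) -> None:
    # roughly what LightningModule.toggle_optimizer does: only the parameters
    # of the active optimizer keep requires_grad=True
    for i, opt in enumerate(optimizers):
        for group in opt.param_groups:
            for p in group["params"]:
                p.requires_grad = i == opt_idx


def untoggle(opt_idx: int) -> None:
    # roughly what LightningModule.untoggle_optimizer does: re-enable gradients
    # for the parameters that belong to the other optimizers
    for i, opt in enumerate(optimizers):
        if i != opt_idx:
            for group in opt.param_groups:
                for p in group["params"]:
                    p.requires_grad = True


toggle(0)  # optimizer 0 is active: disc is frozen for this step
assert not next(disc.parameters()).requires_grad

# Before this commit the untoggle only ran on batches where the optimizer
# actually stepped, so with accumulate_grad_batches > 1 the frozen state could
# leak into the next optimizer's step. The fix calls it unconditionally at the
# end of every optimization.
untoggle(0)
assert next(disc.parameters()).requires_grad
```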

3 files changed: +16 -11 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))
 
 
+- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))
+
+
 ## [1.3.8] - 2021-07-01
 
 ### Fixed

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 11 additions & 9 deletions
@@ -204,20 +204,17 @@ def _run_optimization(
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
-                if len(self.trainer.optimizers) > 1:
-                    # revert back to previous state
-                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
-            if not result:
-                # user decided to skip optimization
-                return result
-
-            # update running loss + reset accumulated loss
+            if result:
+                # if no result, user decided to skip optimization
+                # otherwise update running loss + reset accumulated loss
                 self._update_running_loss(result.loss)
+                self._process_closure_result(result)
 
-        self._process_closure_result(result)
+        # untoggle model params
+        self._run_optimization_end(opt_idx)
         return result
 
     def _training_step_and_backward_closure(
@@ -509,6 +506,11 @@ def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer
         model = self.trainer.lightning_module
         model.toggle_optimizer(optimizer, opt_idx)
 
+    def _run_optimization_end(self, opt_idx: int) -> None:
+        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
+            model = self.trainer.lightning_module
+            model.untoggle_optimizer(opt_idx)
+
     @contextmanager
     def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
         """

tests/core/test_lightning_module.py

Lines changed: 2 additions & 2 deletions
@@ -197,7 +197,7 @@ def optimizer_step(
         max_epochs=1,
         default_root_dir=tmpdir,
         limit_train_batches=8,
-        accumulate_grad_batches=1,
+        accumulate_grad_batches=2,
         limit_val_batches=0,
     )
     trainer.fit(model)
@@ -331,7 +331,7 @@ def configure_optimizers(self):
         max_epochs=1,
         default_root_dir=tmpdir,
         limit_train_batches=8,
-        accumulate_grad_batches=1,
+        accumulate_grad_batches=2,
     )
 
     trainer.fit(model)
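A minimal end-to-end sketch of the scenario these tests now exercise: two optimizers plus `accumulate_grad_batches=2`. This is a hypothetical reproduction, not part of the Lightning test suite; the names `BoringDataset` and `ToyGAN` are made up, and the API shown is the 1.x multiple-optimizer interface (`optimizer_idx` in `training_step`).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class BoringDataset(Dataset):
    # hypothetical toy dataset, just enough to drive a few batches
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class ToyGAN(LightningModule):
    # hypothetical module with two optimizers, mirroring the GAN-style setup
    # where toggle/untoggle is used
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(4, 4)
        self.disc = nn.Linear(4, 4)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # with multiple optimizers, Lightning toggles the current optimizer's
        # params on and freezes the others for the duration of this step
        module = self.gen if optimizer_idx == 0 else self.disc
        return module(batch).sum()

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.1),
            torch.optim.SGD(self.disc.parameters(), lr=0.1),
        )


if __name__ == "__main__":
    model = ToyGAN()
    trainer = Trainer(max_epochs=1, limit_train_batches=8, accumulate_grad_batches=2, limit_val_batches=0)
    trainer.fit(model, DataLoader(BoringDataset(), batch_size=2))
    # sanity check: no parameter should be left frozen by a stale toggle
    assert all(p.requires_grad for p in model.parameters())
```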
