Commit f8c7548

tchaton authored and Borda committed
[bugfix] Resolve bug with multiple optimizers and toggle. (#5574)
* fix toggle_optimizer
* update doc
* resolve bug
* update
* Update pytorch_lightning/core/lightning.py

  Co-authored-by: Rohit Gupta <[email protected]>

* update on comments
* update on comments
* update

Co-authored-by: Rohit Gupta <[email protected]>

(cherry picked from commit c76cc23)
1 parent: cf13cec · commit: f8c7548

File tree

- CHANGELOG.md
- pytorch_lightning/core/lightning.py
- pytorch_lightning/trainer/training_loop.py

3 files changed: +45 −9 lines

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `toggle_optimizer` to reset `requires_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574))
+
+
 - Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))
 
 
@@ -213,7 +216,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
 - Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505))
 
-
 ## [1.1.3] - 2021-01-05
 
 ### Added

pytorch_lightning/core/lightning.py

Lines changed: 38 additions & 8 deletions
@@ -1190,17 +1190,47 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
 
         Override for your own behavior
 
+        It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset.
+
         Args:
-            optimizer:
-            optimizer_idx:
+            optimizer: Current optimizer used in training_loop
+            optimizer_idx: Current optimizer idx in training_loop
         """
-        # Todo: required argument `optimizer_idx` is not used
-        for param in self.parameters():
-            param.requires_grad = False
+        param_requires_grad_state = {}
+        # make sure current optimizer is latest to be iterated over.
+        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
+        num_optimizers = len(optimizers) - 1
+        for opt_idx, opt in enumerate(optimizers):
+            for group in opt.param_groups:
+                for param in group['params']:
+                    if num_optimizers == opt_idx:
+                        # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
+                        if param in param_requires_grad_state:
+                            param.requires_grad = param_requires_grad_state[param]
+                    else:
+                        # save requires_grad for later restoration
+                        param_requires_grad_state[param] = param.requires_grad
+                        param.requires_grad = False
+
+        self._param_requires_grad_state = param_requires_grad_state
+
+    def untoggle_optimizer(self, optimizer_idx: int):
+        """
+        .. note:: Only called when using multiple optimizers
 
-        for group in optimizer.param_groups:
-            for param in group['params']:
-                param.requires_grad = True
+        Override for your own behavior
+
+        Args:
+            optimizer_idx: Current optimizer idx in training_loop
+        """
+        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+            if optimizer_idx != opt_idx:
+                for group in opt.param_groups:
+                    for param in group['params']:
+                        if param in self._param_requires_grad_state:
+                            param.requires_grad = self._param_requires_grad_state[param]
+        # save memory
+        del self._param_requires_grad_state
 
     def optimizer_step(
         self,
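
The new logic is easiest to see outside the framework. Below is a minimal standalone sketch in plain PyTorch (not Lightning code; the toggle/untoggle helpers and the toy shared/head_a/head_b modules are illustrative assumptions) that mirrors the save-and-restore of requires_grad for a parameter registered in two optimizers:

import torch
from torch import nn

# Toy setup: `shared` appears in both optimizers, each head in only one.
shared = nn.Linear(2, 2)
head_a = nn.Linear(2, 1)
head_b = nn.Linear(2, 1)

opt_a = torch.optim.SGD(list(shared.parameters()) + list(head_a.parameters()), lr=0.1)
opt_b = torch.optim.SGD(list(shared.parameters()) + list(head_b.parameters()), lr=0.1)
optimizers = [opt_a, opt_b]


def toggle(current, all_optimizers):
    # Freeze params owned by the *other* optimizers, remembering their previous
    # requires_grad flags; params also owned by `current` get restored at the end.
    state = {}
    ordered = [opt for opt in all_optimizers if opt is not current] + [current]
    last = len(ordered) - 1
    for idx, opt in enumerate(ordered):
        for group in opt.param_groups:
            for param in group["params"]:
                if idx == last:
                    if param in state:
                        param.requires_grad = state[param]
                else:
                    state[param] = param.requires_grad
                    param.requires_grad = False
    return state


def untoggle(current_idx, all_optimizers, state):
    # Restore the saved flags on the other optimizers' params after the step.
    for idx, opt in enumerate(all_optimizers):
        if idx != current_idx:
            for group in opt.param_groups:
                for param in group["params"]:
                    if param in state:
                        param.requires_grad = state[param]


state = toggle(opt_a, optimizers)
assert all(p.requires_grad for p in shared.parameters())       # shared params stay trainable
assert not any(p.requires_grad for p in head_b.parameters())   # other optimizer's params are frozen
untoggle(0, optimizers, state)
assert all(p.requires_grad for p in head_b.parameters())       # original state restored

The assertions show the intended behavior: while optimizer A is toggled, the shared parameters stay trainable and optimizer B's exclusive parameters are frozen; after untoggling, the original flags are back, which is the reset that the `_param_requires_grad_state` dict above provides.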

pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 0 deletions
@@ -795,6 +795,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                 if self.trainer.terminate_on_nan:
                     self.trainer.detect_nan_tensors(result.loss)
 
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.get_model().untoggle_optimizer(opt_idx)
+
         return result
 
     def backward(self, result, optimizer, opt_idx, *args, **kwargs):
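
For context, here is a hedged sketch of the kind of multi-optimizer LightningModule this code path serves (the model, layer sizes, and losses are illustrative, not taken from this commit). With automatic optimization and more than one optimizer, the training loop toggles the current optimizer around `training_step`, and with this change it also calls `untoggle_optimizer` after the backward pass so the shared backbone's requires_grad flags are restored before the next optimizer runs:

import torch
from torch import nn
import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):
    # Illustrative: a backbone shared by both optimizers, the case the fix targets.

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head_a = nn.Linear(4, 1)
        self.head_b = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, y = batch
        feats = self.backbone(x)
        # By the time we get here, the other optimizer's exclusive params have
        # requires_grad=False; after backward, untoggle_optimizer restores them.
        if optimizer_idx == 0:
            return nn.functional.mse_loss(self.head_a(feats), y)
        return nn.functional.mse_loss(self.head_b(feats), y)

    def configure_optimizers(self):
        # The backbone parameters appear in both optimizers.
        opt_a = torch.optim.SGD(
            list(self.backbone.parameters()) + list(self.head_a.parameters()), lr=0.1
        )
        opt_b = torch.optim.SGD(
            list(self.backbone.parameters()) + list(self.head_b.parameters()), lr=0.1
        )
        return opt_a, opt_b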
