Commit 06f6593

ananthsub, tchaton, rohitgr7 and mergify[bot] authored and committed
Fix toggle optimizer (#5775)
* Update lightning.py
* update changelog
* add a 3 optimizer test
* resolve flake8
* remove extra code
* typo
* resolve typo
* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 6d7c01b commit 06f6593

File tree: 3 files changed (+254, -13 lines)

CHANGELOG.md

Lines changed: 13 additions & 0 deletions

@@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+### Changed
+
+### Deprecated
+
+### Removed
+
+### Fixed
+
+- Fixed `toggle_optimizer` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))
 
 ## [unreleased.Features] - YYYY-MM-DD

pytorch_lightning/core/lightning.py

Lines changed: 15 additions & 13 deletions

@@ -1196,22 +1196,24 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
             optimizer: Current optimizer used in training_loop
             optimizer_idx: Current optimizer idx in training_loop
         """
+
+        # Iterate over all optimizer parameters to preserve their `requires_grad` information
+        # in case these are pre-defined during `configure_optimizers`
         param_requires_grad_state = {}
-        # make sure current optimizer is latest to be iterated over.
-        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
-        num_optimizers = len(optimizers) - 1
-        for opt_idx, opt in enumerate(optimizers):
+        for opt in self.optimizers(use_pl_optimizer=False):
             for group in opt.param_groups:
                 for param in group['params']:
-                    if num_optimizers == opt_idx:
-                        # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
-                        if param in param_requires_grad_state:
-                            param.requires_grad = param_requires_grad_state[param]
-                    else:
-                        # save requires_grad for later restoration
-                        param_requires_grad_state[param] = param.requires_grad
-                        param.requires_grad = False
-
+                    # If a param already appears in param_requires_grad_state, continue
+                    if param in param_requires_grad_state:
+                        continue
+                    param_requires_grad_state[param] = param.requires_grad
+                    param.requires_grad = False
+
+        # Then iterate over the current optimizer's parameters and set their `requires_grad`
+        # properties accordingly
+        for group in optimizer.param_groups:
+            for param in group['params']:
+                param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state
 
     def untoggle_optimizer(self, optimizer_idx: int):
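
As a quick aside for readers skimming the diff: the block below is a minimal, framework-free sketch of the two-pass logic introduced above, written against plain `torch.optim` objects. The toy layers, the `toggle` helper and its name are illustrative only and are not part of this commit.

```python
from torch import nn
from torch.optim import SGD, Adam

# Toy setup: layer_1 is shared by both optimizers and its bias is pre-frozen,
# mimicking flags set during `configure_optimizers`.
layer_1 = nn.Linear(4, 4)
layer_2 = nn.Linear(4, 2)
layer_1.bias.requires_grad = False

opt_0 = SGD(list(layer_1.parameters()) + list(layer_2.parameters()), lr=0.1)
opt_1 = Adam(layer_1.parameters(), lr=0.1)


def toggle(current, all_optimizers):
    """Sketch of the fixed two-pass toggle logic (illustrative, not the Lightning API)."""
    # First pass: record every parameter's requires_grad exactly once, then freeze it.
    state = {}
    for opt in all_optimizers:
        for group in opt.param_groups:
            for param in group["params"]:
                if param in state:
                    continue
                state[param] = param.requires_grad
                param.requires_grad = False
    # Second pass: restore the current optimizer's own parameters to their saved flags.
    for group in current.param_groups:
        for param in group["params"]:
            param.requires_grad = state[param]
    return state


saved = toggle(opt_1, [opt_0, opt_1])
assert layer_1.weight.requires_grad is True    # trained by the toggled optimizer
assert layer_1.bias.requires_grad is False     # pre-frozen flag is preserved
assert layer_2.weight.requires_grad is False   # belongs only to the other optimizer
```

Recording each flag only once in the first pass is what keeps a pre-frozen, shared parameter frozen after toggling, which is the behaviour the new tests below assert.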

tests/core/test_lightning_module.py

Lines changed: 226 additions & 0 deletions

@@ -14,6 +14,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
+from torch import nn
 from torch.optim import Adam, SGD
 
 from pytorch_lightning import Trainer
@@ -184,3 +185,228 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
     )
 
     trainer.fit(model)
+
+
+def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx=None):
+            return super().training_step(batch, batch_idx)
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set some weights to False to check untoggle works as expected.
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+        def configure_optimizers(self):
+            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
+            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
+            return [optimizer, optimizer_2]
+
+        def optimizer_step(
+            self,
+            current_epoch,
+            batch_nb,
+            optimizer,
+            optimizer_idx,
+            closure,
+            on_tpu=False,
+            using_native_amp=False,
+            using_lbfgs=False
+        ):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+
+            optimizer.step(closure=closure)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+        limit_val_batches=0,
+    )
+
+    results = trainer.fit(model)
+    assert results
+
+
+def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            self.layer_3 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set some weights to False to check untoggle works as expected.
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+            self.layer_3[1].weight.requires_grad = False
+            self.layer_3[5].weight.requires_grad = False
+
+        def optimizer_step(
+            self,
+            current_epoch,
+            batch_nb,
+            optimizer,
+            optimizer_idx,
+            closure,
+            on_tpu=False,
+            using_native_amp=False,
+            using_lbfgs=False
+        ):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+
+                assert self.layer_3[1].weight.requires_grad is False
+                assert self.layer_3[3].weight.requires_grad is False
+                assert self.layer_3[5].weight.requires_grad is False
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+
+                assert self.layer_3[1].weight.requires_grad is False
+                assert self.layer_3[3].weight.requires_grad is True
+                assert self.layer_3[5].weight.requires_grad is False
+
+            if optimizer_idx == 2:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+
+                assert self.layer_3[1].weight.requires_grad is False
+                assert self.layer_3[3].weight.requires_grad is True
+                assert self.layer_3[5].weight.requires_grad is False
+
+            optimizer.step(closure=closure)
+
+        def training_step(self, batch, batch_idx, optimizer_idx=None):
+            return super().training_step(batch, batch_idx)
+
+        @staticmethod
+        def combine_generators(gen_1, gen_2):
+            for p in gen_1:
+                yield p
+            for p in gen_2:
+                yield p
+
+        def configure_optimizers(self):
+            optimizer_1 = SGD(
+                self.combine_generators(
+                    self.layer_1.parameters(),
+                    self.layer_2.parameters()
+                ),
+                lr=0.1
+            )
+            optimizer_2 = Adam(
+                self.combine_generators(
+                    self.layer_2.parameters(),
+                    self.layer_3.parameters()
+                ),
+                lr=0.1
+            )
+            optimizer_3 = SGD(
+                self.combine_generators(
+                    self.layer_3.parameters(),
+                    self.layer_1.parameters()
+                ),
+                lr=0.1
+            )
+            return [optimizer_1, optimizer_2, optimizer_3]
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+    )
+
+    trainer.fit(model)
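
As a compressed, framework-free restatement of what these two tests assert (the layer names and the inlined freeze/restore loops below are illustrative and not taken from the commit, nor are they the actual `untoggle_optimizer` implementation), the toggle/untoggle round trip looks roughly like this:

```python
from torch import nn
from torch.optim import SGD, Adam

# `shared` appears in both optimizers; its weight is pre-frozen, as in the tests above.
shared = nn.Linear(8, 8)
head = nn.Linear(8, 2)
shared.weight.requires_grad = False

opt_a = SGD(list(shared.parameters()) + list(head.parameters()), lr=0.1)
opt_b = Adam(shared.parameters(), lr=0.1)

# Toggle opt_b: record each flag once, freeze everything, then re-enable opt_b's params.
saved = {}
for opt in (opt_a, opt_b):
    for group in opt.param_groups:
        for p in group["params"]:
            saved.setdefault(p, p.requires_grad)
            p.requires_grad = False
for group in opt_b.param_groups:
    for p in group["params"]:
        p.requires_grad = saved[p]

assert shared.weight.requires_grad is False  # pre-frozen flag survives the toggle
assert shared.bias.requires_grad is True     # shared parameter trained by opt_b
assert head.weight.requires_grad is False    # exclusive to opt_a, frozen while opt_b steps

# Untoggle: put every recorded flag back; this is what the saved state is kept around for.
for p, flag in saved.items():
    p.requires_grad = flag

assert head.weight.requires_grad is True
assert shared.weight.requires_grad is False
```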
