Commit 0421f08

fix optimizer loop with frequencies (#9507)
1 parent 3b34c89 commit 0421f08
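
Context for the fix (not part of the diff): when optimizers are configured with frequencies, only one optimizer is active for a given batch, so the loop has to distinguish an optimizer's position within the currently active optimizers from its global index in configure_optimizers. The sketch below uses a hypothetical helper, active_optimizer_index, that reproduces the cycling behaviour exercised by the new test added in this commit:

# Hypothetical illustration of how optimizer frequencies cycle (not Lightning code).
# With frequencies (3, 1): batches 0-2 use optimizer 0, batch 3 uses optimizer 1, repeat.
from itertools import accumulate


def active_optimizer_index(batch_idx, frequencies):
    """Return the global index of the optimizer active at ``batch_idx``."""
    cycle = sum(frequencies)
    offsets = list(accumulate(frequencies))  # e.g. (3, 1) -> [3, 4]
    position_in_cycle = batch_idx % cycle
    for idx, offset in enumerate(offsets):
        if position_in_cycle < offset:
            return idx
    raise RuntimeError("unreachable")


assert [active_optimizer_index(b, (3, 1)) for b in range(10)] == [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]
assert [active_optimizer_index(b, (1, 2)) for b in range(10)] == [0, 1, 1, 0, 1, 1, 0, 1, 1, 0]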

File tree

6 files changed: +102 additions, -22 deletions


pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 1 addition & 2 deletions
@@ -123,8 +123,7 @@ def advance(self, batch, batch_idx):

         if self.trainer.lightning_module.automatic_optimization:
             # in automatic optimization, hand over execution to the OptimizerLoop
-            optimizers = [optimizer for _, optimizer in self.get_active_optimizers(batch_idx)]
-            batch_outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
+            batch_outputs = self.optimizer_loop.run(split_batch, self.get_active_optimizers(batch_idx), batch_idx)
             # combine outputs from each optimizer
             for k in range(len(batch_outputs)):
                 self.batch_outputs[k].extend(batch_outputs[k])
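
In words: get_active_optimizers(batch_idx) already returns (optimizer_idx, optimizer) pairs, and the fix forwards those pairs instead of stripping the indices. A tiny sketch with hypothetical placeholder values (strings stand in for the real torch optimizers):

active = [(1, "Adam")]                        # get_active_optimizers(batch_idx=3) with frequencies (3, 1)

# Old behaviour: the global index was dropped before handing over to the optimizer loop.
optimizers_only = [opt for _, opt in active]  # ["Adam"] -- the fact that this is optimizer 1 is lost

# New behaviour: the pairs are forwarded unchanged, so the loop can recover optimizer_idx == 1
# for training_step and the progress tracker.
indices, optimizers = zip(*active)
assert indices == (1,) and optimizers == ("Adam",)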

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 16 additions & 10 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from torch import Tensor
@@ -188,36 +188,42 @@ def __init__(self) -> None:
         self._skip_backward: bool = False
         self._batch_idx: int = 0
         self._optimizers: List[Optimizer] = []
+        self._indices: List[int] = []
         self._hiddens: Optional[Any] = None

+    @property
+    def optimizer_idx(self) -> int:
+        return self._indices[self.optim_progress.optimizer_position]
+
     @property
     def done(self) -> bool:
         """Returns ``True`` when the last optimizer in the sequence has run."""
-        return self.optim_progress.optimizer_idx >= len(self._optimizers)
+        return self.optim_progress.optimizer_position >= len(self._indices)

     def connect(self, **kwargs: "Loop") -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

     def reset(self) -> None:
         if not self.restarting or self.done:
-            self.optim_progress.optimizer_idx = 0
+            self.optim_progress.optimizer_position = 0
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

-    def on_run_start(self, batch: Any, optimizers: List[Optimizer], batch_idx: int) -> None:  # type: ignore[override]
+    def on_run_start(  # type: ignore[override]
+        self, batch: Any, optimizers: List[Tuple[int, Optimizer]], batch_idx: int
+    ) -> None:
         self._batch_idx = batch_idx
-        self._optimizers = optimizers
+        self._indices, self._optimizers = zip(*optimizers)

     def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
         result = self._run_optimization(
             batch,
             self._batch_idx,
-            self._optimizers[self.optim_progress.optimizer_idx],
-            self.optim_progress.optimizer_idx,
+            self._optimizers[self.optim_progress.optimizer_position],
+            self.optimizer_idx,
         )
         if result.loss is not None:
-            self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss())
-
-        self.optim_progress.optimizer_idx += 1
+            self.outputs[self.optimizer_idx].append(result.drop_closure_loss())
+        self.optim_progress.optimizer_position += 1

     def on_run_end(self) -> _OUTPUTS_TYPE:
         outputs, self.outputs = self.outputs, []  # free memory
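
Not part of the diff: the position/index split above can be hard to follow in isolation, so here is a stripped-down, hypothetical sketch of the bookkeeping the loop now performs (the class name PositionSketch is made up; the real logic lives in OptimizerLoop above):

from typing import List, Tuple

import torch
from torch.optim import Optimizer


class PositionSketch:
    """Hypothetical stand-in for OptimizerLoop's new bookkeeping (illustration only)."""

    def __init__(self) -> None:
        self.optimizer_position = 0            # index into the optimizers active for *this* batch
        self._indices: Tuple[int, ...] = ()    # their global indices from configure_optimizers
        self._optimizers: Tuple[Optimizer, ...] = ()

    @property
    def optimizer_idx(self) -> int:
        # Map the position within the active subset back to the global optimizer index.
        return self._indices[self.optimizer_position]

    @property
    def done(self) -> bool:
        return self.optimizer_position >= len(self._indices)

    def on_run_start(self, optimizers: List[Tuple[int, Optimizer]]) -> None:
        # Same trick as the diff: split the (index, optimizer) pairs into parallel tuples.
        self._indices, self._optimizers = zip(*optimizers)

    def advance(self) -> Tuple[Optimizer, int]:
        # Run the optimizer at the current position, reporting its *global* index.
        current = (self._optimizers[self.optimizer_position], self.optimizer_idx)
        self.optimizer_position += 1
        return current


# With frequencies (3, 1) at batch 3, only Adam (global index 1) is active:
params = [torch.nn.Parameter(torch.zeros(1))]
loop = PositionSketch()
loop.on_run_start([(1, torch.optim.Adam(params, lr=0.1))])
_, idx = loop.advance()
assert idx == 1 and loop.done   # the position was 0, but the reported optimizer_idx is 1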

pytorch_lightning/trainer/progress.py

Lines changed: 6 additions & 3 deletions
@@ -209,12 +209,15 @@ class OptimizationProgress(BaseProgress):

     Args:
         optimizer: Tracks optimizer progress.
-        optimizer_idx: The index of the current optimizer. Used to know which optimizer we were using when restarting.
+        optimizer_position: The index of the current optimizer amongst the currently active optimizers.
+            Used to know which optimizer we were using when restarting.
+            Since not all optimizers may be active at a given time, this index is different from the ``optimizer_idx``
+            seen in the optimization loops.
     """

     # TODO: support for multiple optimizers
     optimizer: OptimizerProgress = field(default_factory=OptimizerProgress)
-    optimizer_idx: int = 0
+    optimizer_position: int = 0

     @property
     def optimizer_steps(self) -> int:
@@ -225,4 +228,4 @@ def reset_on_epoch(self) -> None:

     def load_state_dict(self, state_dict: dict) -> None:
         self.optimizer.load_state_dict(state_dict["optimizer"])
-        self.optimizer_idx = state_dict["optimizer_idx"]
+        self.optimizer_position = state_dict["optimizer_position"]

tests/loops/optimization/test_optimizer_loop.py

Lines changed: 72 additions & 0 deletions
@@ -11,9 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
+import pytest
 import torch
+from torch.optim import Adam, SGD

+from pytorch_lightning import Trainer
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
+from tests.helpers import BoringModel


 def test_closure_result_deepcopy():
@@ -37,3 +44,68 @@ def test_closure_result_apply_accumulation():
     closure_loss = torch.tensor(25.0)
     result = ClosureResult.from_training_step_output(closure_loss, 5)
     assert result.loss == 5
+
+
+@pytest.mark.parametrize(
+    "frequencies,expected",
+    [
+        (
+            (3, 1),
+            [
+                (0, "SGD"),
+                (0, "SGD"),
+                (0, "SGD"),
+                (1, "Adam"),
+                (0, "SGD"),
+                (0, "SGD"),
+                (0, "SGD"),
+                (1, "Adam"),
+                (0, "SGD"),
+                (0, "SGD"),
+            ],
+        ),
+        (
+            (1, 2),
+            [
+                (0, "SGD"),
+                (1, "Adam"),
+                (1, "Adam"),
+                (0, "SGD"),
+                (1, "Adam"),
+                (1, "Adam"),
+                (0, "SGD"),
+                (1, "Adam"),
+                (1, "Adam"),
+                (0, "SGD"),
+            ],
+        ),
+    ],
+)
+def test_optimizer_frequencies(tmpdir, frequencies, expected):
+    """Test that the optimizer loop runs optimization for the correct optimizer and optimizer idx when different
+    frequencies are requested."""
+
+    class CurrentModel(BoringModel):
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            opt0 = SGD(self.parameters(), lr=0.1)
+            opt1 = Adam(self.parameters(), lr=0.1)
+            return {"optimizer": opt0, "frequency": frequencies[0]}, {"optimizer": opt1, "frequency": frequencies[1]}
+
+    model = CurrentModel()
+    model.optimizer_step = Mock(wraps=model.optimizer_step)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=10,
+        progress_bar_refresh_rate=0,
+    )
+    trainer.fit(model)
+
+    positional_args = [c[0] for c in model.optimizer_step.call_args_list]
+    pl_optimizer_sequence = [args[2] for args in positional_args]
+    opt_idx_sequence = [args[3] for args in positional_args]
+    assert all(isinstance(opt, LightningOptimizer) for opt in pl_optimizer_sequence)
+    optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence]
+    assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected
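
The args[2]/args[3] indexing in the assertions relies on the positional layout of LightningModule.optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, ...) as it stood at the time of this commit; the layout assumed here is not guaranteed in other versions. A minimal sketch of the same Mock-unpacking pattern with a stand-in function:

from unittest.mock import Mock


def optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs):
    # Stand-in with the positional layout assumed by the test above.
    pass


mocked = Mock(wraps=optimizer_step)
mocked(0, 3, "adam-instance", 1)
(call,) = mocked.call_args_list
assert call[0][2] == "adam-instance"   # args[2] -> the (Lightning-wrapped) optimizer in the test
assert call[0][3] == 1                 # args[3] -> the optimizer_idx compared against ``expected``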

tests/loops/test_loop_state_dict.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def test_loops_state_dict_structure():
                     "current": {"ready": 0, "started": 0, "completed": 0},
                 },
             },
-            "optimizer_idx": 0,
+            "optimizer_position": 0,
         },
         "epoch_loop.val_loop.state_dict": {},
         "epoch_loop.val_loop.dataloader_progress": {

tests/loops/test_loops.py

Lines changed: 6 additions & 6 deletions
@@ -468,7 +468,7 @@ def configure_optimizers_multiple(self):
         "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
-            "optimizer_idx": stop_optimizer,
+            "optimizer_position": stop_optimizer,
             "optimizer": {
                 "step": {
                     "total": {
@@ -611,7 +611,7 @@ def configure_optimizers_multiple(self):
         "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
-            "optimizer_idx": n_optimizers,
+            "optimizer_position": n_optimizers,
             "optimizer": {
                 "step": {
                     "total": {
@@ -697,12 +697,12 @@ def mid_epoch_reset_assertions():

     # resetting from a mid-epoch checkpoint should not change progress counters
     mid_epoch_reset_assertions()
-    assert optimizer_loop.optim_progress.optimizer_idx == 1
+    assert optimizer_loop.optim_progress.optimizer_position == 1
     fit_loop.reset()
     epoch_loop.reset()
     optimizer_loop.reset()
     mid_epoch_reset_assertions()
-    assert optimizer_loop.optim_progress.optimizer_idx == 0
+    assert optimizer_loop.optim_progress.optimizer_position == 0

     # reset state loaded from a checkpoint from the end of an epoch
     end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
@@ -726,7 +726,7 @@ def mid_epoch_reset_assertions():
     assert epoch_loop.batch_progress.current.ready == 4
     assert epoch_loop.batch_progress.current.completed == 4

-    assert optimizer_loop.optim_progress.optimizer_idx == 1
+    assert optimizer_loop.optim_progress.optimizer_position == 1

     # resetting from a end-of-epoch checkpoint should reset the current counters to 0
     fit_loop.reset()
@@ -745,4 +745,4 @@ def mid_epoch_reset_assertions():
     assert epoch_loop.batch_progress.current.ready == 0
     assert epoch_loop.batch_progress.current.completed == 0

-    assert optimizer_loop.optim_progress.optimizer_idx == 0
+    assert optimizer_loop.optim_progress.optimizer_position == 0
