@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from torch import Tensor
@@ -188,36 +188,42 @@ def __init__(self) -> None:
         self._skip_backward: bool = False
         self._batch_idx: int = 0
         self._optimizers: List[Optimizer] = []
+        self._indices: List[int] = []
         self._hiddens: Optional[Any] = None

+    @property
+    def optimizer_idx(self) -> int:
+        return self._indices[self.optim_progress.optimizer_position]
+
     @property
     def done(self) -> bool:
         """Returns ``True`` when the last optimizer in the sequence has run."""
-        return self.optim_progress.optimizer_idx >= len(self._optimizers)
+        return self.optim_progress.optimizer_position >= len(self._indices)

     def connect(self, **kwargs: "Loop") -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

     def reset(self) -> None:
         if not self.restarting or self.done:
-            self.optim_progress.optimizer_idx = 0
+            self.optim_progress.optimizer_position = 0
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

-    def on_run_start(self, batch: Any, optimizers: List[Optimizer], batch_idx: int) -> None:  # type: ignore[override]
+    def on_run_start(  # type: ignore[override]
+        self, batch: Any, optimizers: List[Tuple[int, Optimizer]], batch_idx: int
+    ) -> None:
         self._batch_idx = batch_idx
-        self._optimizers = optimizers
+        self._indices, self._optimizers = zip(*optimizers)

     def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
         result = self._run_optimization(
             batch,
             self._batch_idx,
-            self._optimizers[self.optim_progress.optimizer_idx],
-            self.optim_progress.optimizer_idx,
+            self._optimizers[self.optim_progress.optimizer_position],
+            self.optimizer_idx,
         )
         if result.loss is not None:
-            self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss())
-
-        self.optim_progress.optimizer_idx += 1
+            self.outputs[self.optimizer_idx].append(result.drop_closure_loss())
+        self.optim_progress.optimizer_position += 1

     def on_run_end(self) -> _OUTPUTS_TYPE:
         outputs, self.outputs = self.outputs, []  # free memory
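The core of this diff is the split between a *position* (how far the loop has advanced through the optimizers handed to it for this batch) and an *index* (where each of those optimizers sits in the trainer's full optimizer list). Below is a minimal, self-contained sketch of that bookkeeping; it is not Lightning code, and the class and attribute names (`_ToyOptimizerLoop`, `position`) are illustrative stand-ins for the loop and `optim_progress.optimizer_position` in the diff.

```python
from typing import List, Tuple

import torch
from torch.optim import SGD, Adam, Optimizer


class _ToyOptimizerLoop:
    def __init__(self) -> None:
        self.position: int = 0          # counterpart of optim_progress.optimizer_position
        self._indices: List[int] = []   # original indices into the full optimizer list
        self._optimizers: List[Optimizer] = []

    def on_run_start(self, optimizers: List[Tuple[int, Optimizer]]) -> None:
        # zip(*pairs) splits [(idx, opt), ...] into two aligned tuples
        self._indices, self._optimizers = zip(*optimizers)  # type: ignore[assignment]

    @property
    def optimizer_idx(self) -> int:
        # map the loop's position back to the optimizer's original index
        return self._indices[self.position]

    def run(self) -> None:
        while self.position < len(self._indices):
            opt = self._optimizers[self.position]
            print(f"position={self.position} -> optimizer_idx={self.optimizer_idx} ({type(opt).__name__})")
            self.position += 1


params = [torch.nn.Parameter(torch.zeros(1))]
# e.g. only the first and third configured optimizers are active for this batch
loop = _ToyOptimizerLoop()
loop.on_run_start([(0, SGD(params, lr=0.1)), (2, Adam(params, lr=0.1))])
loop.run()
# position=0 -> optimizer_idx=0 (SGD)
# position=1 -> optimizer_idx=2 (Adam)
```

With this separation, `done` and the per-step progress only need the position, while outputs and hooks can still be keyed by the optimizer's original index via the new `optimizer_idx` property.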