diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index fb412c2e71435..5a63a6e92db03 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -285,7 +285,10 @@ def _store(

     def on_train_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
-        for opt_idx, optimizer in trainer.fit_loop.epoch_loop.batch_loop.get_active_optimizers():
+        # import is here to avoid circular imports
+        from pytorch_lightning.loops.utilities import _get_active_optimizers
+
+        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
             num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 41a512cfec57f..d7996aba13ea8 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional

-import numpy as np
 from deprecate import void
 from torch import Tensor
-from torch.optim import Optimizer

 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
 from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
+from pytorch_lightning.loops.utilities import _get_active_optimizers
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -41,7 +40,6 @@ def __init__(self) -> None:
         self.manual_loop = ManualOptimization()

         self._warning_cache: WarningCache = WarningCache()
-        self._optimizer_freq_cumsum: Optional[int] = None
         self._remaining_splits: Optional[List[Any]] = None

     @property
@@ -49,13 +47,6 @@ def done(self) -> bool:
         """Returns if all batch splits have been processed already."""
         return len(self._remaining_splits) == 0

-    @property
-    def optimizer_freq_cumsum(self) -> int:
-        """Returns the cumulated sum of optimizer frequencies."""
-        if self._optimizer_freq_cumsum is None:
-            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
-        return self._optimizer_freq_cumsum
-
     def connect(
         self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
     ) -> None:
@@ -123,7 +114,8 @@ def advance(self, batch, batch_idx):

         if self.trainer.lightning_module.automatic_optimization:
             # in automatic optimization, hand over execution to the OptimizerLoop
-            batch_outputs = self.optimizer_loop.run(split_batch, self.get_active_optimizers(batch_idx), batch_idx)
+            optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
+            batch_outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
             # combine outputs from each optimizer
             for k in range(len(batch_outputs)):
                 self.batch_outputs[k].extend(batch_outputs[k])
@@ -142,10 +134,6 @@ def teardown(self) -> None:
         # release memory
         self._remaining_splits = None

-    def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
-        """Gets the number of active optimizers based on their frequency."""
-        return len(self.get_active_optimizers(batch_idx))
-
     def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         """Splits a single batch into a list of sequence steps for tbptt.

@@ -175,21 +163,3 @@ def _update_running_loss(self, current_loss: Tensor) -> None:

         # reset for next set of accumulated grads
         self.accumulated_loss.reset()
-
-    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
-        """Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
-        only one of the optimizers is active at a time.
-
-        Returns:
-            A list of tuples (opt_idx, optimizer) of currently active optimizers.
-        """
-        if not self.trainer.optimizer_frequencies:
-            # call training_step once per optimizer
-            return list(enumerate(self.trainer.optimizers))
-
-        optimizers_loop_length = self.optimizer_freq_cumsum[-1]
-        current_place_in_loop = batch_idx % optimizers_loop_length
-
-        # find optimzier index by looking for the first {item > current_place} in the cumsum list
-        opt_idx = np.searchsorted(self.optimizer_freq_cumsum, current_place_in_loop, side="right")
-        return [(opt_idx, self.trainer.optimizers[opt_idx])]
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index f829c20e557b1..6210016e68da5 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -18,7 +18,7 @@
 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
 from pytorch_lightning.loops.optimization.closure import OutputResult
-from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
+from pytorch_lightning.loops.utilities import _get_active_optimizers, _prepare_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -97,7 +97,7 @@ def reset(self) -> None:
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

         # track epoch output
-        self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
+        self._epoch_output = [[] for _ in range(self._num_active_optimizers(self.total_batch_idx))]

         if not self.restarting or self._num_training_batches_reached():
             self.batch_progress.reset_on_epoch()
@@ -334,10 +334,13 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -
         """updates the lr schedulers based on the given interval."""
         if interval == "step" and self._should_accumulate():
             return
+        active_optimizers = _get_active_optimizers(
+            self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
+        )
         self.trainer.optimizer_connector.update_learning_rates(
             interval=interval,
             update_plateau_schedulers=update_plateau_schedulers,
-            opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
+            opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
         )

     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
@@ -371,3 +374,7 @@ def _save_loggers_on_train_batch_end(self) -> None:
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()
+
+    def _num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
+        """Gets the number of active optimizers based on their frequency."""
+        return len(_get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx))
diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index 6fdde18fa6bb2..ec42cd156da27 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, Optional, Sequence
+from functools import lru_cache
+from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple

+import numpy as np
 import torch
 from torch.optim import Optimizer

@@ -139,3 +141,30 @@ def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) ->
             yield None
     else:
         yield None
+
+
+@lru_cache(1)
+def _cumulative_optimizer_frequencies(frequencies: Tuple[int]):
+    return np.cumsum(frequencies)
+
+
+def _get_active_optimizers(
+    optimizers: List[Optimizer], frequencies: List[int], batch_idx: Optional[int] = None
+) -> List[Tuple[int, Optimizer]]:
+    """Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only
+    one of the optimizers is active at a time.
+
+    Returns:
+        A list of tuples (opt_idx, optimizer) of currently active optimizers.
+    """
+    if not frequencies:
+        # call training_step once per optimizer
+        return list(enumerate(optimizers))
+
+    freq_cumsum = _cumulative_optimizer_frequencies(tuple(frequencies))
+    optimizers_loop_length = freq_cumsum[-1]
+    current_place_in_loop = batch_idx % optimizers_loop_length
+
+    # find optimizer index by looking for the first {item > current_place} in the cumsum list
+    opt_idx = np.searchsorted(freq_cumsum, current_place_in_loop, side="right")
+    return [(opt_idx, optimizers[opt_idx])]
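
For reference, a minimal usage sketch of the _get_active_optimizers helper added above. This is not part of the patch; the SGD instances and the [2, 1] frequency list are made-up examples. It shows how the helper cycles through the optimizers according to their frequencies and falls back to all optimizers when no frequencies are configured.

# Illustrative only: exercise the new helper with toy optimizers.
import torch
from torch.optim import SGD

from pytorch_lightning.loops.utilities import _get_active_optimizers

param = torch.nn.Parameter(torch.zeros(1))
optimizers = [SGD([param], lr=0.1), SGD([param], lr=0.01)]  # hypothetical optimizers
frequencies = [2, 1]  # first optimizer for 2 batches, then the second for 1, repeating

for batch_idx in range(6):
    active = _get_active_optimizers(optimizers, frequencies, batch_idx)
    # prints: 0 [0] / 1 [0] / 2 [1] / 3 [0] / 4 [0] / 5 [1]
    print(batch_idx, [int(opt_idx) for opt_idx, _ in active])

# without frequencies, every optimizer is active on every batch
assert len(_get_active_optimizers(optimizers, [], batch_idx=0)) == 2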