diff --git a/CHANGELOG.md b/CHANGELOG.md
index 306500c3e6f42..9ef3267812a80 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -190,6 +190,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Refactored `EpochResultStore` ([#5522](https://github.com/PyTorchLightning/pytorch-lightning/pull/5522))

+- `LightningOptimizer` manual optimization is more flexible and exposes `toggle_model` ([#5771](https://github.com/PyTorchLightning/pytorch-lightning/pull/5771))
+
+
+
### Deprecated

- Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst
index f71df4f8903ba..3f7cd7f224a97 100644
--- a/docs/source/common/optimizers.rst
+++ b/docs/source/common/optimizers.rst
@@ -21,46 +21,117 @@ Manual optimization
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable
to manually manage the optimization process. To do so, do the following:

-* Disable automatic optimization in Trainer: Trainer(automatic_optimization=False)
+* Override your LightningModule ``automatic_optimization`` property to return ``False``
* Drop or ignore the optimizer_idx argument
-* Use `self.manual_backward(loss)` instead of `loss.backward()` to automatically scale your loss
+* Use `self.manual_backward(loss)` instead of `loss.backward()`.
+
+.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only the precision and accelerator logic. The user is responsible for ``zero_grad``, gradient accumulation (``accumulate_grad_batches``), model toggling, etc.
+
+.. warning:: Before 1.2, ``optimizer.step`` called ``zero_grad`` internally. From 1.2, this is left to the user's expertise.
+
+.. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do it as shown below.
+
+.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.
+

.. code-block:: python

-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        opt = self.optimizers()
+
+        loss = self.compute_loss(batch)
+        self.manual_backward(loss)
+        opt.step()

-        # 1. ignore optimizer_idx
-        # 2. `use_pl_optimizer=True` means `opt_g` and `opt_d` will be of type `LightningOptimizer`
-        #    `LightningOptimizer` simply wrapped your optimizer and behave the same way !
-        #    When calling `optimizer.step`, `LightningOptimizer` will just handle TPU, AMP, accumulate_grad_batches, etc ... for you.
+        # accumulate gradient batches
+        if batch_idx % 2 == 0:
+            opt.zero_grad()

-        # access your optimizers with `use_pl_optimizer=False` or `optimizer.optimizer` when using use_pl_optimizer=True
-        # use_pl_optimizer=True is the default
-        (opt_g, opt_d) = self.optimizers(use_pl_optimizer=True)

-        # do anything you want
-        loss_a = ...
+.. tip:: It is good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but it makes your code compatible if you switch to an optimizer which requires a closure.

-        # use self.backward which will also handle scaling the loss when using amp
-        self.manual_backward(loss_a, opt_g)
-        opt_g.step()
+Here is the same example as above using a ``closure``.
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        opt = self.optimizers()
+
+        def forward_and_backward():
+            loss = self.compute_loss(batch)
+            self.manual_backward(loss)

-        # do anything you want
-        loss_b = ...
+        opt.step(closure=forward_and_backward)

-        # pass in any args that loss.backward() normally takes
-        self.manual_backward(loss_b, opt_d, retain_graph=True)
-        self.manual_backward(loss_b, opt_d)
-        opt_d.step()
+        # accumulate gradient batches
+        if batch_idx % 2 == 0:
+            opt.zero_grad()
+
+
+.. code-block:: python

-        # log losses
-        self.log('loss_a', loss_a)
-        self.log('loss_b', loss_b)
+    # Scenario for a GAN.

-.. note:: This is only recommended for experts who need ultimate flexibility
+    def training_step(...):
+        opt_gen, opt_dis = self.optimizers()

-Manual optimization does not yet support accumulated gradients but will be live in 1.1.0
+        # compute generator loss
+        loss_gen = self.compute_generator_loss(...)
+
+        # zero_grad needs to be called before backward
+        opt_gen.zero_grad()
+        self.manual_backward(loss_gen)
+        opt_gen.step()
+
+        # compute discriminator loss
+        loss_dis = self.compute_discriminator_loss(...)
+
+        # zero_grad needs to be called before backward
+        opt_dis.zero_grad()
+        self.manual_backward(loss_dis)
+        opt_dis.step()
+
+
+.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a context manager for advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting.
+
+Here is an explanation of what it does:
+
+Consider the current optimizer as A and all other optimizers as B.
+Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``. Their original state will be restored when exiting the context manager.
+
+When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase.
+Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed.
+
+Here is an example of how to use it:
+
+.. code-block:: python
+
+
+    # Scenario for a GAN with gradient accumulation every 2 batches and optimized for multiple GPUs.
+
+    def training_step(self, batch, batch_idx, ...):
+        opt_gen, opt_dis = self.optimizers()
+
+        accumulated_grad_batches = batch_idx % 2 == 0
+
+        # compute generator loss
+        def closure_gen():
+            loss_gen = self.compute_generator_loss(...)
+            self.manual_backward(loss_gen)
+            if accumulated_grad_batches:
+                opt_gen.zero_grad()
+
+        with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
+            opt_gen.step(closure=closure_gen)
+
+        def closure_dis():
+            loss_dis = self.compute_discriminator_loss(...)
+            self.manual_backward(loss_dis)
+            if accumulated_grad_batches:
+                opt_dis.zero_grad()
+
+        with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
+            opt_dis.step(closure=closure_dis)

------

@@ -166,7 +237,7 @@ returned as a dict which can contain the following keywords:
* ``strict`` (optional): if set to ``True`` will enforce that value specified in ``monitor`` is available while trying to call ``scheduler.step()``, and stop training if not found. If ``False`` will only give a warning and continue training (without calling the scheduler).
-* ``name`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the
+* ``name`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the
  learning rate progress, this keyword can be used to specify a specific name the learning rate should be logged as.

.. testcode::
@@ -248,23 +319,6 @@ For example, here step optimizer A every 2 batches and optimizer B every 4 batch
            if batch_nb % 4 == 0 :
                optimizer.step(closure=closure)

-.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step`` which can be used as follow.
-
-.. testcode::
-
-    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
-        optimizer.zero_grad()
-
-    # Alternating schedule for optimizer steps (ie: GANs)
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        # update generator opt every 2 steps
-        if optimizer_idx == 0:
-            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)
-
-        # update discriminator opt every 4 steps
-        if optimizer_idx == 1:
-            optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)
-
Here we add a learning-rate warm up

.. testcode::
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index e348a57b5c103..4f4b10e2730c1 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -238,7 +238,7 @@ def backward(
        self,
        closure_loss: torch.Tensor,
        optimizer: Optimizer,
-        opt_idx: int,
+        optimizer_idx: int,
        should_accumulate: bool,
        *args,
        **kwargs,
@@ -247,17 +247,15 @@ def backward(

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
-            optimizer: the optimizer to do the step later on.
-            opt_idx: the index of the optimizer
+            optimizer: the optimizer to do the step later on
+            optimizer_idx: the index of the optimizer
            should_accumulate: whether to accumulate gradients
        """
-        self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)
+        self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, optimizer_idx)

        output = self.precision_plugin.backward(
-            self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
+            self.lightning_module, closure_loss, optimizer, optimizer_idx, should_accumulate, *args, **kwargs
        )

-        self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, opt_idx)
+        self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, optimizer_idx)

        return output
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index aa7f909d9b682..9c87836b4415a 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1186,7 +1186,7 @@ def configure_optimizers(self):
        """
        rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")

-    def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
+    def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None:
        """
        Call this directly from your training_step when doing optimizations manually.
        
By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you @@ -1207,12 +1207,18 @@ def training_step(...): self.manual_backward(loss, opt_a) opt_a.step() """ + if optimizer is not None: + rank_zero_warn( + "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4", + DeprecationWarning + ) + # make sure we're using manual opt self._verify_is_manual_optimization('manual_backward') # backward self._running_manual_backward = True - self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs) + self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) self._running_manual_backward = False def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 42af0f44e0071..d18abde814aab 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import types +from contextlib import contextmanager from typing import Callable, Optional from weakref import proxy -from torch.optim.optimizer import Optimizer +from torch.optim import Optimizer from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -35,13 +36,7 @@ class LightningOptimizer: the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches """ - def __init__(self, optimizer: Optimizer, accumulate_grad_batches: Optional[int] = None): - - assert accumulate_grad_batches is None or isinstance(accumulate_grad_batches, int) - if isinstance(accumulate_grad_batches, int) and accumulate_grad_batches < 1: - raise MisconfigurationException( - f"accumulate_grad_batches parameters {accumulate_grad_batches} should be >= 1" - ) + def __init__(self, optimizer: Optimizer): self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k != 'step'} @@ -57,7 +52,6 @@ def __init__(self, optimizer: Optimizer, accumulate_grad_batches: Optional[int] self._optimizer = optimizer self._trainer = None - self._accumulate_grad_batches = accumulate_grad_batches self._optimizer_idx = None self._total_optimizer_step_calls = 0 @@ -89,14 +83,6 @@ def param_groups(self): def param_groups(self, param_groups): self._optimizer.param_groups = param_groups - @property - def accumulate_grad_batches(self): - return self._accumulate_grad_batches - - @accumulate_grad_batches.setter - def accumulate_grad_batches(self, accumulate_grad_batches): - self._accumulate_grad_batches = accumulate_grad_batches - def _on_trainer_init(self, trainer): self._trainer = proxy(trainer) for opt_idx, opt in enumerate(trainer.optimizers): @@ -114,17 +100,31 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): optimizer = trainer.lightning_optimizers[opt_idx] return optimizer - def _accumulated_batches_reached(self): - if self.accumulate_grad_batches is None: - return self._trainer.train_loop._accumulated_batches_reached() - return (self._trainer.batch_idx + 1) % self.accumulate_grad_batches == 0 + def _toggle_model(self): + model_ref = self._trainer.get_model() + model_ref.toggle_optimizer(self, self._optimizer_idx) - @property - def _should_accumulate(self): - # checks if backward or backward + optimizer step (via closure) - accumulation_done = self._accumulated_batches_reached() - is_final_batch = 
self._trainer.train_loop._num_training_batches_reached()
-        return not (accumulation_done or is_final_batch)
+    def _untoggle_model(self):
+        model_ref = self._trainer.get_model()
+        model_ref.untoggle_optimizer(self)
+
+    @contextmanager
+    def toggle_model(self, sync_grad: bool = True):
+        """
+        This function is just a helper for advanced users.
+
+        Consider the current optimizer as A and all other optimizers as B.
+        Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to ``False``.
+
+        When performing gradient accumulation, there is no need to perform grad synchronization
+        during the accumulation phase.
+        Setting ``sync_grad`` to ``False`` will block this synchronization and improve performance.
+        """
+        with self._trainer.train_loop.block_ddp_sync_behaviour(not sync_grad):
+            self._toggle_model()
+            try:
+                yield
+            finally:
+                self._untoggle_model()

    def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
        trainer = self._trainer
@@ -134,137 +134,90 @@ def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: st

        with trainer.profiler.profile(profiler_name):
            trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

-        trainer.train_loop.on_before_zero_grad(optimizer)
-
-        model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx, optimizer, self._optimizer_idx)
-
-    def _check_make_optimizer_step(self, make_optimizer_step: Optional[bool]) -> bool:
-        if make_optimizer_step is not None and self._trainer.overriden_optimizer_zero_grad:
-            raise MisconfigurationException(
-                "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed."
-            )
-        if self._trainer.train_loop.automatic_optimization:
-            if self._trainer.overriden_optimizer_step and self._trainer.overriden_optimizer_zero_grad:
-                return True
+        trainer.train_loop.on_before_zero_grad(optimizer)
+        model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx, optimizer, self._optimizer_idx)

-            if make_optimizer_step is None:
-                make_optimizer_step = not self._should_accumulate
-
-        return make_optimizer_step
-
-    def step(self, *args, closure: Optional[Callable] = None, make_optimizer_step: Optional[bool] = None, **kwargs):
+    def step(self, *args, closure: Optional[Callable] = None, **kwargs):
        """
        Call this directly from your training_step when doing optimizations manually.
-        By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
+        By using this we can ensure that all the proper scaling when using 16-bit, accelerators, etc.
+        is done properly for you.

-        .. tip:: In manual mode we still automatically accumulate grad over batches if
-           Trainer(accumulate_grad_batches=x) is set.
+        .. note:: In manual optimization, the user is expected to know when to call ``zero_grad``,
+            perform gradient accumulation, etc. Lightning will only take care of precision and accelerators.

        Args:

            closure: One could provide its own optimizer_closure. Set to None by default.

-            make_optimizer_step: Whether to force an optimizer step. When nothing is provided,
-                we will use `accumulate_grad_batches` for accumulation frequency by default.
-                However, one coud provide True and False based on its own scheduling.
-                Refer to example 2 and 3
-
            args: Any parameters provided to wrapped optimizer.step()

            kwargs: Any parameters provided to wrapped optimizer.step()

        Example::

+            # Scenario for a GAN.
+
            def training_step(...):
-                (opt_a, opt_b) = self.optimizers()
-                loss_a = ...
-                # automatically applies scaling, etc...
-                self.manual_backward(loss_a, opt_a)
-                opt_a.step()
+                opt_gen, opt_dis = self.optimizers()

-        Example::
+                ...

-            def training_step(self, batch, batch_idx):
-                # using Boring Model
-                opt = self.optimizers() # only 1 optimizer
-
-                def compute_loss():
-                    x = batch[0]
-                    x = F.dropout(x, 0.1)
-                    predictions = self(x)
-                    predictions = F.dropout(predictions, 0.1)
-                    loss = self.loss(None, predictions)
-                    return loss
-
-                def closure():
-                    # emulate MC dropout training
-                    num_backward = 1
-                    losses = []
-                    for backward_idx in range(num_backward + 1):
-                        loss = compute_loss()
-                        losses.append(loss)
-                        retain_graph = num_backward!= backward_idx
-                        self.manual_backward(loss, opt, retain_graph=retain_graph)
-                    loss_mean = torch.stack(losses).mean()
-                    loss_std = torch.stack(losses).std()
-                    self.log("train_loss_mean", loss_mean, on_step=True, prog_bar=True, on_epoch=True)
-                    self.log("train_loss_std", loss_std, on_step=True, prog_bar=True, on_epoch=True)
-
-                opt.step(loss, closure=closure)
+                # compute generator loss
+                loss_gen = self.compute_generator_loss(...)
+                # zero_grad needs to be called before backward
+                opt_gen.zero_grad()
+                self.manual_backward(loss_gen)
+                opt_gen.step()

-        Example::
+                # compute discriminator loss
+                loss_dis = self.compute_discriminator_loss(...)
+
+                # zero_grad needs to be called before backward
+                opt_dis.zero_grad()
+                self.manual_backward(loss_dis)
+                opt_dis.step()

-            # Scenario for a gan.
+            # A more advanced scenario for a GAN

-            def training_step(self, batch, batch_idx, optimizer_idx):
+
-                # emulate gans training
+            def training_step(self, batch, batch_idx, ...):

                opt_gen, opt_dis = self.optimizers()

-                # Note: Be careful, don't log on the same key in self.log in both closure
-                # as they will be aggregated together on epoch_end
-
-                def gen_closure():
-                    ... forward and compute loss for generator
-                    loss_gen = ...
-                    self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
-                    self.manual_backward(loss_gen, opt_gen)
-
-                def dis_closure():
-                    ... forward and compute loss for discriminator
-                    loss_dis = ...
-                    self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
-                    self.manual_backward(loss_dis, opt_dis)
-
-                # this will accumulate gradients for 2 batches and then call opt_gen.step()
-                opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0)
-
-                # update discriminator every 4 batches
-                # therefore, no gradient accumulation for discriminator
-                if batch_idx % 4 == 0 :
-                    # Note: Set make_optimizer_step to True or it will use by default
-                    # Trainer(accumulate_grad_batches=x)
-                    opt_dis.step(closure=optimizer_closure, make_optimizer_step=True)
-        """
-        profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
+                ...
+
+                accumulated_grad_batches = batch_idx % 2 == 0
+
+                # compute generator loss
+                def closure_gen():
+                    loss_gen = self.compute_generator_loss(...)
+                    self.manual_backward(loss_gen)
+                    if accumulated_grad_batches:
+                        opt_gen.zero_grad()
+
+                with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
+                    opt_gen.step(closure=closure_gen)
+
+                def closure_dis():
+                    loss_dis = self.compute_discriminator_loss(...)
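+                    # manual_backward handles precision (e.g. 16-bit) scaling for you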
+                    self.manual_backward(loss_dis)
+                    if accumulated_grad_batches:
+                        opt_dis.zero_grad()
+
+                with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
+                    opt_dis.step(closure=closure_dis)
+
+        """
        if closure is None:
+            profiler_name = f"closure_{self._optimizer_idx}"
            closure = do_nothing_closure
        else:
            if not isinstance(closure, types.FunctionType):
                raise MisconfigurationException("When closure is provided, it should be a function")
+            profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"

-        make_optimizer_step = self._check_make_optimizer_step(make_optimizer_step)
-
-        if make_optimizer_step:
-            self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
-            self._total_optimizer_step_calls += 1
-        else:
-            # make sure to call optimizer_closure when accumulating
-            with self._trainer.profiler.profile(f"closure_{self._optimizer_idx}"):
-                with self._trainer.train_loop.block_ddp_sync_behaviour(True):
-                    closure()
+        self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
+        self._total_optimizer_step_calls += 1

    def __repr__(self):
        groups = [{k: round(v, 12) if isinstance(v, float) else v
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 884b05cfd8de2..cc1ad4da5266c 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -64,7 +64,7 @@ def backward(
            should_accumulate: whether to accumulate gradients or not
        """
-        closure_loss = amp.scale_loss(closure_loss, optimizer)
+        closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer)

        # enter apex context
        context = closure_loss
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index e8a6511798664..60c0f5f84626f 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -79,7 +79,6 @@ def pre_optimizer_step(
        if not pl_module.automatic_optimization:
            self.scaler.unscale_(optimizer)
-            pl_module.trainer.call_hook("on_after_backward")

        return False
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 2393c040bcc8f..995c83079992c 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -116,9 +116,6 @@ def broadcast(self, obj: object, src: int = 0) -> object:
        obj = hvd.broadcast_object(obj, src)
        return obj

-    def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
-        optimizer.synchronize()
-
    def model_to_device(self):
        if self.on_gpu:
            torch.cuda.set_device(self.root_device)
@@ -158,3 +155,8 @@ def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] =
        gathered = hvd.allgather(result)
        gathered_result = list(gathered.split(1, dim=0))
        return gathered_result
+
+    def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        # synchronize all horovod optimizers.
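+        # `optimizer` can be None when backward is driven by `manual_backward`,
+        # so conservatively synchronize every optimizer registered on the trainer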
+ for optimizer in self.lightning_module.trainer.optimizers: + optimizer.synchronize() diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 503955ac875ac..6aa32934ad771 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -138,7 +138,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): def test_freeze_unfreeze_function(tmpdir): - """Test freeze properly set requieres_grad on the modules""" + """Test freeze properly sets requires_grad on the modules""" seed_everything(42) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index a67d73a3bb16a..fab6ceccbfd88 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -18,7 +18,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel @@ -31,7 +30,6 @@ class TestModel(BoringModel): def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - # optimizer = LightningOptimizer(self.trainer, optimizer) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) return [optimizer], [lr_scheduler] @@ -98,20 +96,22 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss_1 = self.loss(batch, output) - self.manual_backward(loss_1, opt_1) + self.manual_backward(loss_1) opt_1.step() + opt_1.zero_grad() - def closure(): - output = self.layer(batch) - loss_2 = self.loss(batch, output) - self.manual_backward(loss_2, opt_2) + output = self.layer(batch) + loss_2 = self.loss(batch, output) + self.manual_backward(loss_2) - opt_2.step(closure=closure) + if batch_idx % 2 == 0: + opt_2.step() + opt_2.zero_grad() def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - optimizer_1 = LightningOptimizer(optimizer_1, 4) + optimizer_1 = LightningOptimizer(optimizer_1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) return [optimizer_1, optimizer_2], [lr_scheduler] @@ -128,8 +128,8 @@ def configure_optimizers(self): ) trainer.fit(model) - assert len(mock_sgd_step.mock_calls) == 2 - assert len(mock_adam_step.mock_calls) == 8 + assert len(mock_sgd_step.mock_calls) == 8 + assert len(mock_adam_step.mock_calls) == 4 @patch("torch.optim.Adam.step", autospec=True) @@ -152,20 +152,20 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss_1 = self.loss(batch, output) - self.manual_backward(loss_1, opt_1) + self.manual_backward(loss_1) opt_1.step() def closure(): output = self.layer(batch) loss_2 = self.loss(batch, output) - self.manual_backward(loss_2, opt_2) + self.manual_backward(loss_2) opt_2.step(closure=closure) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - optimizer_1 = LightningOptimizer(optimizer_1, 4) + optimizer_1 = LightningOptimizer(optimizer_1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) return [optimizer_1, optimizer_2], [lr_scheduler] @@ -183,8 +183,8 @@ def configure_optimizers(self): ) trainer.fit(model) - assert len(mock_sgd_step.mock_calls) == 2 - assert 
len(mock_adam_step.mock_calls) == 4
+    assert len(mock_sgd_step.mock_calls) == 8
+    assert len(mock_adam_step.mock_calls) == 8


def test_state(tmpdir):
@@ -239,7 +239,7 @@ def test_state(tmpdir):

def test_lightning_optimizer_automatic_optimization(tmpdir):
    """
-    Test lightning optimize works with make_optimizer_step in automatic_optimization
+    Test that LightningOptimizer works in automatic_optimization
    """

    class TestModel(BoringModel):
@@ -256,15 +256,16 @@ def training_epoch_end(self, outputs):
        def optimizer_step(
            self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
        ):
-            assert optimizer_closure.__name__ == "train_step_and_backward_closure"
-
-            optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 2 == 0)
+            optimizer_closure()
+            if batch_idx % 2 == 0:
+                optimizer.step()
+                optimizer.zero_grad()

        def configure_optimizers(self):
            optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
-            optimizer_1 = LightningOptimizer(optimizer_1, 4)
+            optimizer_1 = LightningOptimizer(optimizer_1)

            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
            return [optimizer_1, optimizer_2], [lr_scheduler]
@@ -286,7 +287,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
    """

    with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
-        patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
+            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:

        class TestModel(BoringModel):
@@ -308,151 +309,6 @@ def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer,
                if batch_idx % 5 == 0:
                    optimizer.zero_grad()

-            def optimizer_step(
-                self,
-                epoch,
-                batch_idx,
-                optimizer,
-                
optimizer_idx, - optimizer_closure, - on_tpu, - using_native_amp, - using_lbfgs, - ): - - assert optimizer_closure.__name__ == "train_step_and_backward_closure" - - if optimizer_idx == 0: - optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 3 == 0) - return - optimizer.step(closure=optimizer_closure) - - def configure_optimizers(self): - optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) - return [optimizer_1, optimizer_2], [lr_scheduler] - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=20, - limit_val_batches=1, - max_epochs=1, - weights_summary=None, - ) - trainer.fit(model) - - assert adam_zero_grad.call_count == 4 - assert sgd_zero_grad.call_count == 10 - - except MisconfigurationException as e: - assert "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed" in str(e) - - -def test_lightning_optimizer_automatic_optimization_make_optimizer_step_2(tmpdir): - """ - Test lightning optimize works with make_optimizer_step in automatic_optimization - """ - - with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \ - patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: - - class TestModel(BoringModel): - - def training_step(self, batch, batch_idx, optimizer_idx=None): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"loss": loss} - - def training_epoch_end(self, outputs): - outputs = sum(outputs, []) - torch.stack([x["loss"] for x in outputs]).mean() - - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - on_tpu, - using_native_amp, - using_lbfgs, - ): - - assert optimizer_closure.__name__ == "train_step_and_backward_closure" - - make_optimizer_step = None - if optimizer_idx == 0: - make_optimizer_step = batch_idx % 4 == 0 - optimizer.step(closure=optimizer_closure, make_optimizer_step=make_optimizer_step) - def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) @@ -469,5 +325,5 @@ def configure_optimizers(self): ) trainer.fit(model) - assert adam_zero_grad.call_count == 20 - assert sgd_zero_grad.call_count == 5 + assert adam_zero_grad.call_count == 4 + assert sgd_zero_grad.call_count == 10 diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index b11108c62e445..749efba426b04 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -189,3 +189,29 @@ def test_v1_4_0_deprecated_lightning_data_parallel(): dp_model = LightningDataParallel(model, device_ids=[0]) assert isinstance(dp_model, torch.nn.DataParallel) assert isinstance(dp_model.module, LightningParallelModule) + + +def test_v1_4_0_deprecated_manual_optimization_optimizer(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, *_, **kwargs): + opt = self.optimizers() + output = self.layer(batch) + loss = self.loss(batch, output) + self.manual_backward(loss, opt) + + @property + def automatic_optimization(self): + return False + + model = TestModel() + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + with pytest.deprecated_call( + match="`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" + ): + 
trainer.fit(model) diff --git a/tests/plugins/test_rpc_sequential_plugin.py b/tests/plugins/test_rpc_sequential_plugin.py index d357161a27747..f1a4743080289 100644 --- a/tests/plugins/test_rpc_sequential_plugin.py +++ b/tests/plugins/test_rpc_sequential_plugin.py @@ -164,6 +164,7 @@ def training_step(self, batch, batch_idx): self.manual_backward(loss, opt) assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0 opt.step() + opt.zero_grad() assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0 def validation_step(self, batch, batch_idx): diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 3ad6e65512585..ff174b5cad648 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -30,3 +30,4 @@ python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler +python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 807c5585ea5bc..5d4a848429f2c 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -345,8 +345,8 @@ def training_step(self, batch, batch_idx, optimizer_idx): # ensure we forward the correct params to the optimizer # without retain_graph we can't do multiple backward passes - self.manual_backward(loss_2, opt_b, retain_graph=True) - self.manual_backward(loss_2, opt_a) + self.manual_backward(loss_2, retain_graph=True) + self.manual_backward(loss_2) assert self.layer.weight.grad is not None opt_b.step() @@ -416,6 +416,7 @@ def training_step(self, batch, batch_idx): self.manual_backward(loss, opt) opt.step() + opt.zero_grad() return loss.detach() if self.detach else loss @@ -547,7 +548,9 @@ def training_step(self, batch, batch_idx): if self.should_update: self.manual_backward(loss, opt) - opt.step(make_optimizer_step=self.should_have_updated) + if self.should_have_updated: + opt.step() + opt.zero_grad() return loss.detach() if self.detach else loss @@ -636,6 +639,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): assert self.layer.weight.grad is not None opt_b.step() + opt_b.zero_grad() def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -718,6 +722,7 @@ def optimizer_closure(): weight_before = self.layer.weight.clone() opt.step(closure=optimizer_closure) + opt.zero_grad() weight_after = self.layer.weight.clone() assert not torch.equal(weight_before, weight_after) @@ -840,7 +845,8 @@ def optimizer_closure(): retain_graph = num_backward != backward_idx # noqa E225 self.manual_backward(loss_1, opt, retain_graph=retain_graph) - opt.step(closure=optimizer_closure, make_optimizer_step=True) + opt.step(closure=optimizer_closure) + opt.zero_grad() def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -910,14 +916,16 @@ def dis_closure(): self.manual_backward(loss_dis, opt_dis) # this will accumulate gradients for 2 batches and then call opt_gen.step() - opt_gen.step(closure=gen_closure, 
make_optimizer_step=(batch_idx % 2 == 0), optim='sgd') + gen_closure() + if batch_idx % 2 == 0: + opt_gen.step(closure=gen_closure, optim='sgd') + opt_gen.zero_grad() # update discriminator every 4 baches # therefore, no gradient accumulation for discriminator if batch_idx % 4 == 0: - # Note: Set make_optimizer_step to True or it will use by default - # Trainer(accumulate_grad_batches=x) - opt_dis.step(closure=dis_closure, make_optimizer_step=True) + opt_dis.step(closure=dis_closure) + opt_dis.zero_grad() def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -1018,14 +1026,14 @@ def dis_closure(): make_manual_backward(loss_ones_gen, opt_dis, make_optimizer_step=make_dis_optimizer_step) # this will accumulate gradients for 2 batches and then call opt_gen.step() - opt_gen.step(closure=gen_closure, make_optimizer_step=make_gen_optimizer_step) + if make_gen_optimizer_step: + opt_gen.step(closure=gen_closure) + opt_gen.zero_grad() # update discriminator every 4 baches # therefore, no gradient accumulation for discriminator if make_dis_optimizer_step: - # Note: Set make_optimizer_step to True or it will use by default - # Trainer(accumulate_grad_batches=x) - opt_dis.step(closure=dis_closure, make_optimizer_step=True) + opt_dis.step(closure=dis_closure) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -1037,11 +1045,11 @@ def configure_optimizers(self): return [optimizer_gen, optimizer_dis] -def train_manual_optimization(tmpdir, accelerator): +def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizationDDPModel): seed_everything(42) - model = TesManualOptimizationDDPModel() + model = model_cls() model_copy = deepcopy(model) model.val_dataloader = None model.training_epoch_end = None @@ -1084,3 +1092,68 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp_spawn(tmpdir """ train_manual_optimization(tmpdir, "ddp_spawn") + + +class TesManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + + # emulate gans training + opt_gen, opt_dis = self.optimizers() + + # Note: Be careful, don't log on the same key in self.log in both closure + # as they will be aggregated together on epoch_end + + world_size = torch_distrib.get_world_size(torch_distrib.group.WORLD) + assert world_size == 2 + + make_gen_optimizer_step = batch_idx % 2 == 1 + make_dis_optimizer_step = batch_idx % 4 == 0 + + def compute_loss(): + x = batch[0] + x = F.dropout(x, 0.1) + predictions = self(x) + predictions = F.dropout(predictions, 0.1) + loss_ones = self.loss_ones(None, predictions) + loss_zeros = self.loss_zeros(None, predictions) + return loss_ones, loss_zeros + + def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True): + self.manual_backward(loss, opt, retain_graph=retain_graph) + if make_optimizer_step: + grad_clone = self.layer.weight.grad.clone() + assert self.manual_sync_grad() + self.layer.weight.grad /= world_size + assert torch.equal(self.layer.weight.grad, grad_clone) + + def gen_closure(): + loss_ones_gen, loss_zeros = compute_loss() + make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step) + make_manual_backward(loss_ones_gen, opt_gen, make_optimizer_step=make_gen_optimizer_step) + + def dis_closure(): + loss_ones_gen, loss_zeros = compute_loss() + make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True, 
make_optimizer_step=make_dis_optimizer_step) + make_manual_backward(loss_ones_gen, opt_dis, make_optimizer_step=make_dis_optimizer_step) + + # this will accumulate gradients for 2 batches and then call opt_gen.step() + with opt_gen.toggle_model(sync_grad=make_gen_optimizer_step): + gen_closure() + if make_gen_optimizer_step: + opt_gen.step() + opt_gen.zero_grad() + + with opt_dis.toggle_model(sync_grad=make_dis_optimizer_step): + dis_closure() + if make_dis_optimizer_step: + opt_dis.step() + opt_dis.zero_grad() + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir): + train_manual_optimization(tmpdir, "ddp", model_cls=TesManualOptimizationDDPModelToggleModel)
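

For reference, the pieces above combine into the following end-to-end pattern: a ``LightningModule`` that opts out of automatic optimization, calls ``manual_backward`` without an optimizer argument, owns its ``zero_grad`` calls, and wraps each step in ``toggle_model``. This is a minimal sketch, not code from this changeset: the ``ManualOptimGAN`` module, its single-layer networks and losses, and the assumption that ``batch`` is a float tensor of shape ``(N, 32)`` are all illustrative.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch import nn

    from pytorch_lightning import LightningModule


    class ManualOptimGAN(LightningModule):

        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(32, 32)
            self.discriminator = nn.Linear(32, 1)

        @property
        def automatic_optimization(self) -> bool:
            return False

        def training_step(self, batch, batch_idx):
            opt_gen, opt_dis = self.optimizers()

            def closure_gen():
                fake = self.generator(batch)
                # the generator wants the discriminator to predict "real" (1)
                target = torch.ones(batch.size(0), 1, device=batch.device)
                loss_gen = F.binary_cross_entropy_with_logits(self.discriminator(fake), target)
                self.manual_backward(loss_gen)  # no optimizer argument; scaling is handled by Lightning

            # zero_grad is the user's responsibility now and must happen before backward
            opt_gen.zero_grad()
            # only the generator parameters keep requires_grad=True inside this block
            with opt_gen.toggle_model():
                opt_gen.step(closure=closure_gen)

            def closure_dis():
                fake = self.generator(batch).detach()
                real_logits = self.discriminator(batch)
                fake_logits = self.discriminator(fake)
                loss_dis = (
                    F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
                    + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
                )
                self.manual_backward(loss_dis)

            opt_dis.zero_grad()
            with opt_dis.toggle_model():
                opt_dis.step(closure=closure_dis)

        def configure_optimizers(self):
            opt_gen = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
            opt_dis = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
            return opt_gen, opt_dis

``toggle_model()`` defaults to ``sync_grad=True``; pass ``sync_grad=False`` only on batches where you accumulate gradients without stepping, as in the DDP test above.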