diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 95f7c0a6dcb7e..2f61fdc47397f 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -698,6 +698,12 @@ log_dict .. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict :noindex: +manual_backward +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward + :noindex: + print ~~~~~ @@ -916,7 +922,10 @@ True if using Automatic Mixed Precision (AMP) automatic_optimization ~~~~~~~~~~~~~~~~~~~~~~ -When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used. +When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling +your optimizers. However, we do take care of precision and any accelerators used. + +See :ref:`manual optimization` for details. .. code-block:: python @@ -931,7 +940,9 @@ When set to ``False``, Lightning does not automate the optimization process. Thi self.manual_backward(loss) opt.step() -This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. +This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note +that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. +Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. .. code-block:: python @@ -1086,13 +1097,6 @@ get_progress_bar_dict .. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict :noindex: -manual_backward -~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward - :noindex: - - on_after_backward ~~~~~~~~~~~~~~~~~ diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index 422302ea8987e..d9b8d25911009 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -3,27 +3,39 @@ ************ Optimization ************ - Lightning offers two modes for managing the optimization process: -- automatic optimization (AutoOpt) +- automatic optimization - manual optimization -For the majority of research cases, **automatic optimization** will do the right thing for you and it is what -most users should use. +For the majority of research cases, **automatic optimization** will do the right thing for you and it is what most +users should use. For advanced/expert users who want to do esoteric optimization schedules or techniques, use **manual optimization**. ------- +----- Manual optimization =================== -For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable -to manually manage the optimization process. To do so, do the following: +For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to +manually manage the optimization process. + +This is only recommended for experts who need ultimate flexibility. 
+Lightning will handle only precision and accelerators logic.
+The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.
+
+To manually optimize, do the following:
+
+* Set ``self.automatic_optimization=False`` in your ``LightningModule``'s ``__init__``.
+* Use the following functions and call them manually:
-* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule`` ``__init__`` function
-* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.
+
+  * ``self.optimizers()`` to access your optimizers (one or multiple)
+  * ``optimizer.zero_grad()`` to clear the gradients from the previous training step
+  * ``self.manual_backward(loss)`` instead of ``loss.backward()``
+  * ``optimizer.step()`` to update your model parameters
 
+Here is a minimal example of manual optimization.
+
 .. testcode:: python
 
     from pytorch_lightning import LightningModule
@@ -32,25 +44,37 @@ to manually manage the optimization process. To do so, do the following:
 
         def __init__(self):
             super().__init__()
-            # Important: This property activate ``manual optimization`` for your model
+            # Important: This property activates manual optimization.
             self.automatic_optimization = False
 
         def training_step(batch, batch_idx):
             opt = self.optimizers()
 
+            opt.zero_grad()
             loss = self.compute_loss(batch)
             self.manual_backward(loss)
+            opt.step()
 
-.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc..
+.. warning::
+    Before 1.2, ``optimizer.step()`` was calling ``optimizer.zero_grad()`` internally.
+    From 1.2, it is left to the user's expertise.
 
-.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertise.
+.. tip::
+    Be careful where you call ``optimizer.zero_grad()``, or your model won't converge.
+    It is good practice to call ``optimizer.zero_grad()`` before ``self.manual_backward(loss)``.
 
-.. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do as such.
+-----
 
-.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.
+Gradient accumulation
+---------------------
+You can accumulate gradients over batches similarly to
+:attr:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
+To perform gradient accumulation with one optimizer, you can do it as follows.
 
-.. code-block:: python
+.. testcode:: python
 
+    # accumulate gradients over `n` batches
    def __init__(self):
+        super().__init__()
        self.automatic_optimization = False
 
    def training_step(self, batch, batch_idx):
@@ -59,36 +83,16 @@ to manually manage the optimization process. To do so, do the following:
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
 
-        # accumulate gradient batches
-        if batch_idx % 2 == 0:
+        # accumulate gradients of `n` batches
+        if (batch_idx + 1) % n == 0:
            opt.step()
            opt.zero_grad()
 
-.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward``, ``zero_grad`` and ``backward`` of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure.
See also `the PyTorch docs `_. +----- -Here is the same example as above using a ``closure``. - -.. testcode:: python - - def __init__(self): - self.automatic_optimization = False - - def training_step(self, batch, batch_idx): - opt = self.optimizers() - - def closure(): - # Only zero_grad on the first batch to accumulate gradients - is_first_batch_to_accumulate = batch_idx % 2 == 0 - if is_first_batch_to_accumulate: - opt.zero_grad() - - loss = self.compute_loss(batch) - self.manual_backward(loss) - return loss - - opt.step(closure=closure) - -.. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good pratice to call ``zero_grad`` before ``manual_backward``. +Use multiple optimizers (like GANs) [manual] +-------------------------------------------- +Here is an example training a simple GAN with multiple optimizers. .. testcode:: python @@ -97,13 +101,12 @@ Here is the same example as above using a ``closure``. from pytorch_lightning import LightningModule class SimpleGAN(LightningModule): - def __init__(self): super().__init__() self.G = Generator() self.D = Discriminator() - # Important: This property activate ``manual optimization`` for this model + # Important: This property activates manual optimization. self.automatic_optimization = False def sample_z(self, n) -> Tensor: @@ -115,7 +118,8 @@ Here is the same example as above using a ``closure``. return self.G(z) def training_step(self, batch, batch_idx): - # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html + # Implementation follows the PyTorch tutorial: + # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html g_opt, d_opt = self.optimizers() X, _ = batch @@ -126,10 +130,9 @@ Here is the same example as above using a ``closure``. g_X = self.sample_G(batch_size) - ########################### - # Optimize Discriminator # - ########################### - d_opt.zero_grad() + ########################## + # Optimize Discriminator # + ########################## d_x = self.D(X) errD_real = self.criterion(d_x, real_label) @@ -138,17 +141,17 @@ Here is the same example as above using a ``closure``. errD = (errD_real + errD_fake) + d_opt.zero_grad() self.manual_backward(errD) d_opt.step() - ####################### - # Optimize Generator # - ####################### - g_opt.zero_grad() - + ###################### + # Optimize Generator # + ###################### d_z = self.D(g_X) errG = self.criterion(d_z, real_label) + g_opt.zero_grad() self.manual_backward(errG) g_opt.step() @@ -159,32 +162,98 @@ Here is the same example as above using a ``closure``. d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5) return g_opt, d_opt -.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a ``@context_manager`` for advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting. +----- + +Learning rate scheduling [manual] +--------------------------------- +You can call ``lr_scheduler.step()`` at arbitrary intervals. +Use ``self.lr_schedulers()`` in your :class:`~pytorch_lightning.LightningModule` to access any learning rate schedulers +defined in your :meth:`~pytorch_lightning.LightningModule.configure_optimizers`. + +.. warning:: + * Before 1.3, Lightning automatically called ``lr_scheduler.step()`` in both automatic and manual optimization. From + 1.3, ``lr_scheduler.step()`` is now for the user to call at arbitrary intervals. 
+    * Note that the lr_dict keys, such as ``"frequency"`` and ``"interval"``, will be ignored even if they are provided in
+      your ``configure_optimizers()`` during manual optimization.
+
+Here is an example calling ``lr_scheduler.step()`` every step.
+
+.. testcode:: python
+
+    # step every batch
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        # do forward, backward, and optimization
+        ...
+
+        # single scheduler
+        sch = self.lr_schedulers()
+        sch.step()
+
+        # multiple schedulers
+        sch1, sch2 = self.lr_schedulers()
+        sch1.step()
+        sch2.step()
+
+If you want to call ``lr_scheduler.step()`` every ``n`` steps/epochs, do the following.
+
+.. testcode:: python
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        # do forward, backward, and optimization
+        ...
+
+        sch = self.lr_schedulers()
+
+        # step every `n` batches
+        if (batch_idx + 1) % n == 0:
+            sch.step()
+
+        # step every `n` epochs
+        if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % n == 0:
+            sch.step()
+
+-----
+
+Improve training speed with model toggling
+------------------------------------------
+Toggling models can improve your training speed when performing gradient accumulation with multiple optimizers in a
+distributed setting. Here is an explanation of what it does:
 
-Considering the current optimizer as A and all other optimizers as B.
-Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``. Their original state will be restored when exiting the context manager.
+* Consider the current optimizer as A and all other optimizers as B.
+* Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``.
+* Their original state will be restored when exiting the context manager.
 
 When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase.
 Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed.
 
+:class:`~pytorch_lightning.core.optimizer.LightningOptimizer` provides a
+:meth:`~pytorch_lightning.core.optimizer.LightningOptimizer.toggle_model` function as a
+:func:`contextlib.contextmanager` for advanced users.
+
 Here is an example for advanced use-case.
 
 .. testcode:: python
 
     # Scenario for a GAN with gradient accumulation every 2 batches and optimized for multiple gpus.
-
     class SimpleGAN(LightningModule):
-
-        ...
-
         def __init__(self):
+            super().__init__()
             self.automatic_optimization = False
 
         def training_step(self, batch, batch_idx):
-            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
+            # Implementation follows the PyTorch tutorial:
+            # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
             g_opt, d_opt = self.optimizers()
 
             X, _ = batch
@@ -194,14 +263,18 @@ Here is an example for advanced use-case.
             real_label = torch.ones((batch_size, 1), device=self.device)
             fake_label = torch.zeros((batch_size, 1), device=self.device)
 
-            accumulated_grad_batches = batch_idx % 2 == 0
+            # Sync and clear gradients
+            # at the end of accumulation or
+            # at the end of an epoch.
+            is_last_batch_to_accumulate = \
+                (batch_idx + 1) % 2 == 0 or self.trainer.is_last_batch
 
             g_X = self.sample_G(batch_size)
 
-            ###########################
-            # Optimize Discriminator  #
-            ###########################
-            with d_opt.toggle_model(sync_grad=accumulated_grad_batches):
+            ##########################
+            # Optimize Discriminator #
+            ##########################
+            with d_opt.toggle_model(sync_grad=is_last_batch_to_accumulate):
                 d_x = self.D(X)
                 errD_real = self.criterion(d_x, real_label)
@@ -211,36 +284,88 @@ Here is an example for advanced use-case.
                 errD = (errD_real + errD_fake)
 
                 self.manual_backward(errD)
-                if accumulated_grad_batches:
+                if is_last_batch_to_accumulate:
                     d_opt.step()
                     d_opt.zero_grad()
 
-            #######################
-            #  Optimize Generator #
-            #######################
-            with g_opt.toggle_model(sync_grad=accumulated_grad_batches):
+            ######################
+            # Optimize Generator #
+            ######################
+            with g_opt.toggle_model(sync_grad=is_last_batch_to_accumulate):
                 d_z = self.D(g_X)
                 errG = self.criterion(d_z, real_label)
 
                 self.manual_backward(errG)
-                if accumulated_grad_batches:
+                if is_last_batch_to_accumulate:
                     g_opt.step()
                     g_opt.zero_grad()
 
             self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)
 
+-----
+
+Use closure for LBFGS-like optimizers
+-------------------------------------
+It is a good practice to provide the optimizer with a closure function that performs a ``forward``, ``zero_grad`` and
+``backward`` of your model. It is optional for most optimizers, but makes your code compatible if you switch to an
+optimizer which requires a closure, such as :class:`torch.optim.LBFGS`.
+
+See `the PyTorch docs `_ for more about the closure.
+
+Here is an example using a closure function.
+
+.. testcode:: python
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def configure_optimizers(self):
+        return torch.optim.LBFGS(...)
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+
+        def closure():
+            loss = self.compute_loss(batch)
+            opt.zero_grad()
+            self.manual_backward(loss)
+            return loss
+
+        opt.step(closure=closure)
+
 ------
 
+Access your own optimizer [manual]
+----------------------------------
+``optimizer`` is a :class:`~pytorch_lightning.core.optimizer.LightningOptimizer` object wrapping your own optimizer
+configured in your :meth:`~pytorch_lightning.LightningModule.configure_optimizers`. You can access your own optimizer
+with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to
+support accelerators and precision for you.
+
+.. testcode:: python
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        optimizer = self.optimizers()
+
+        # `optimizer` is a `LightningOptimizer` wrapping the optimizer.
+        # To access it, do the following.
+        # However, it won't work on TPU, AMP, etc...
+        optimizer = optimizer.optimizer
+        ...
+
+-----
+
 Automatic optimization
 ======================
-With Lightning most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()``
+With Lightning, most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()``
 since Lightning automates that for you.
 
-.. warning::
-    Before 1.2.2, ``.zero_grad()`` was called after ``.backward()`` and ``.step()`` internally.
-    From 1.2.2, Lightning calls ``.zero_grad()`` before ``.backward()``.
-
-Under the hood Lightning does the following:
+Under the hood, Lightning does the following:
 
 .. code-block:: python
@@ -269,221 +394,220 @@ In the case of multiple optimizers, Lightning does the following:
 
         for lr_scheduler in lr_schedulers:
             lr_scheduler.step()
 
+.. warning::
+    Before 1.2.2, Lightning internally called ``backward``, ``step`` and ``zero_grad`` in that order.
+    From 1.2.2, the order is changed to ``zero_grad``, ``backward`` and ``step``.
+
+-----
 
 Learning rate scheduling
 ------------------------
-Every optimizer you use can be paired with any `Learning Rate Scheduler `_.
-In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers`` method:
+Every optimizer you use can be paired with any
+`Learning Rate Scheduler `_. In the basic
+use-case, the scheduler(s) should be returned as the second output from the
+:meth:`~pytorch_lightning.LightningModule.configure_optimizers` method:
 
-.. testcode::
+.. testcode:: python
 
     # no LR scheduler
     def configure_optimizers(self):
-       return Adam(...)
+        return Adam(...)
 
     # Adam + LR scheduler
    def configure_optimizers(self):
-       optimizer = Adam(...)
-       scheduler = LambdaLR(optimizer, ...)
-       return [optimizer], [scheduler]
+        optimizer = Adam(...)
+        scheduler = LambdaLR(optimizer, ...)
+        return [optimizer], [scheduler]
 
    # Two optimizers each with a scheduler
    def configure_optimizers(self):
-       optimizer1 = Adam(...)
-       optimizer2 = SGD(...)
-       scheduler1 = LambdaLR(optimizer1, ...)
-       scheduler2 = LambdaLR(optimizer2, ...)
-       return [optimizer1, optimizer2], [scheduler1, scheduler2]
+        optimizer1 = Adam(...)
+        optimizer2 = SGD(...)
+        scheduler1 = LambdaLR(optimizer1, ...)
+        scheduler2 = LambdaLR(optimizer2, ...)
+        return [optimizer1, optimizer2], [scheduler1, scheduler2]
 
-When there are schedulers in which the ``.step()`` method is conditioned on a metric value (for example the
-:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler), Lightning requires that the output
-from ``configure_optimizers`` should be dicts, one for each optimizer, with the keyword ``monitor``
-set to metric that the scheduler should be conditioned on.
+When there are schedulers in which the ``.step()`` method is conditioned on a metric value, such as the
+:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the output from
+:meth:`~pytorch_lightning.LightningModule.configure_optimizers` should be dicts, one for each optimizer, with the
+keyword ``"monitor"`` set to the metric that the scheduler should be conditioned on.
 
 .. testcode::
 
-    # The ReduceLROnPlateau scheduler requires a monitor
-    def configure_optimizers(self):
-        return {
-            'optimizer': Adam(...),
-            'lr_scheduler': ReduceLROnPlateau(optimizer, ...),
-            'monitor': 'metric_to_track'
-        }
-
-    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
-    def configure_optimizers(self):
-        optimizer1 = Adam(...)
-        optimizer2 = SGD(...)
-        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
-        scheduler2 = LambdaLR(optimizer2, ...)
-        return (
-            {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'},
-            {'optimizer': optimizer2, 'lr_scheduler': scheduler2},
-        )
+    # The ReduceLROnPlateau scheduler requires a monitor
+    def configure_optimizers(self):
+        optimizer = Adam(...)
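+        # Note: the ReduceLROnPlateau below must wrap this same optimizer instance returned in the dict.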
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': ReduceLROnPlateau(optimizer, ...),
+            'monitor': 'metric_to_track',
+        }
+
+    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
+    def configure_optimizers(self):
+        optimizer1 = Adam(...)
+        optimizer2 = SGD(...)
+        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
+        scheduler2 = LambdaLR(optimizer2, ...)
+        return (
+            {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'},
+            {'optimizer': optimizer2, 'lr_scheduler': scheduler2},
+        )
 
 .. note::
-    Metrics can be made availble to condition on by simply logging it using ``self.log('metric_to_track', metric_val)``
-    in your lightning module.
+    Metrics can be made available to monitor by simply logging them using ``self.log('metric_to_track', metric_val)`` in
+    your :class:`~pytorch_lightning.LightningModule`.
 
-By default, all schedulers will be called after each epoch ends. To change this behaviour, a scheduler configuration should be
-returned as a dict which can contain the following keywords:
+By default, all schedulers will be called after each epoch ends. To change this behaviour, a scheduler configuration
+should be returned as a dict which can contain the following keywords:
 
-* ``scheduler`` (required): the actual scheduler object
-* ``monitor`` (optional): metric to condition
-* ``interval`` (optional): either ``epoch`` (default) for stepping after each epoch ends or ``step`` for stepping
+* ``"scheduler"`` (required): the actual scheduler object
+* ``"monitor"`` (optional): metric to condition on
+* ``"interval"`` (optional): either ``"epoch"`` (default) for stepping after each epoch ends or ``"step"`` for stepping
   after each optimization step
-* ``frequency`` (optional): how many epochs/steps should pass between calls to ``scheduler.step()``. Default is 1,
+* ``"frequency"`` (optional): how many epochs/steps should pass between calls to ``scheduler.step()``. Default is 1,
   corresponding to updating the learning rate after every epoch/step.
-* ``strict`` (optional): if set to ``True`` will enforce that value specified in ``monitor`` is available while trying
-  to call ``scheduler.step()``, and stop training if not found. If ``False`` will only give a warning and continue training
-  (without calling the scheduler).
-* ``name`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the
-  learning rate progress, this keyword can be used to specify a specific name the learning rate should be logged as.
+* ``"strict"`` (optional): if set to ``True``, will enforce that value specified in ``"monitor"`` is available while
+  trying to call ``scheduler.step()``, and stop training if not found. If ``False``, it will only give a warning and
+  continue training without calling the scheduler.
+* ``"name"`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the
+  learning rate progress, this keyword can be used to specify a name the learning rate should be logged as.
 
-.. testcode::
-
-    # Same as the above example with additional params passed to the first scheduler
-    # In this case the ReduceLROnPlateau will step after every 10 processed batches
-    def configure_optimizers(self):
-        optimizers = [Adam(...), SGD(...)]
-        schedulers = [
-            {
-                'scheduler': ReduceLROnPlateau(optimizers[0], ...),
-                'monitor': 'metric_to_track',
-                'interval': 'step',
-                'frequency': 10,
-                'strict': True,
-            },
-            LambdaLR(optimizers[1], ...)
-        ]
-        return optimizers, schedulers
+.. testcode:: python
 
----------
+    # Same as the above example with additional params passed to the first scheduler
+    # In this case the ReduceLROnPlateau will step after every 10 processed batches
+    def configure_optimizers(self):
+        optimizers = [Adam(...), SGD(...)]
+        schedulers = [
+            {
+                'scheduler': ReduceLROnPlateau(optimizers[0], ...),
+                'monitor': 'metric_to_track',
+                'interval': 'step',
+                'frequency': 10,
+                'strict': True,
+            },
+            LambdaLR(optimizers[1], ...)
+        ]
+        return optimizers, schedulers
+
+-----
 
 Use multiple optimizers (like GANs)
 -----------------------------------
-To use multiple optimizers return two or more optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`
+To use multiple optimizers (optionally with learning rate schedulers), return two or more optimizers from
+:meth:`~pytorch_lightning.core.LightningModule.configure_optimizers`.
 
-.. testcode::
+.. testcode:: python
 
-    # one optimizer
-    def configure_optimizers(self):
-        return Adam(...)
+    # two optimizers, no schedulers
+    def configure_optimizers(self):
+        return Adam(...), SGD(...)
 
-    # two optimizers, no schedulers
-    def configure_optimizers(self):
-        return Adam(...), SGD(...)
+    # two optimizers, one scheduler for adam only
+    def configure_optimizers(self):
+        opt1 = Adam(...)
+        opt2 = SGD(...)
+        optimizers = [opt1, opt2]
+        lr_schedulers = {'scheduler': ReduceLROnPlateau(opt1, ...), 'monitor': 'metric_to_track'}
+        return optimizers, lr_schedulers
 
-    # Two optimizers, one scheduler for adam only
-    def configure_optimizers(self):
-        return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'}
+    # two optimizers, two schedulers
+    def configure_optimizers(self):
+        opt1 = Adam(...)
+        opt2 = SGD(...)
+        return [opt1, opt2], [StepLR(opt1, ...), OneCycleLR(opt2, ...)]
 
-Lightning will call each optimizer sequentially:
+Under the hood, Lightning will call each optimizer sequentially:
 
 .. code-block:: python
 
-    for epoch in epochs:
-        for batch in data:
-            for opt in optimizers:
-                loss = train_step(batch, batch_idx, optimizer_idx)
-                opt.zero_grad()
-                loss.backward()
-                opt.step()
+    for epoch in epochs:
+        for batch in data:
+            for opt in optimizers:
+                loss = train_step(batch, batch_idx, optimizer_idx)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()
 
-        for lr_scheduler in lr_schedulers:
-            lr_scheduler.step()
+        for lr_scheduler in lr_schedulers:
+            lr_scheduler.step()
 
----------
+-----
 
 Step optimizers at arbitrary intervals
 --------------------------------------
 To do more interesting things with your optimizers such as learning rate warm-up or odd scheduling,
-override the :meth:`optimizer_step` function.
+override the :meth:`~pytorch_lightning.LightningModule.optimizer_step` function.
 
-For example, here step optimizer A every 2 batches and optimizer B every 4 batches
+.. warning::
+    If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to
+    the ``optimizer.step()`` function as shown in the examples because ``training_step()``, ``optimizer.zero_grad()``,
+    ``backward()`` are called in the closure function.
 
-.. testcode::
+For example, here we step optimizer A every batch and optimizer B every 2 batches.
 
-    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
-        optimizer.zero_grad()
+.. testcode:: python
 
-    # Alternating schedule for optimizer steps (ie: GANs)
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        # update generator opt every 2 steps
+    # Alternating schedule for optimizer steps (e.g. GANs)
+    def optimizer_step(
+        self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
+        on_tpu=False, using_native_amp=False, using_lbfgs=False,
+    ):
+        # update generator every step
         if optimizer_idx == 0:
-            if batch_nb % 2 == 0 :
-                optimizer.step(closure=closure)
+            optimizer.step(closure=optimizer_closure)
 
-        # update discriminator opt every 4 steps
+        # update discriminator every 2 steps
         if optimizer_idx == 1:
-            if batch_nb % 4 == 0 :
+            if (batch_idx + 1) % 2 == 0:
                 optimizer.step(closure=optimizer_closure)
 
-Here we add a learning-rate warm up
+        # ...
+        # add as many optimizers as you want
 
-.. testcode::
+Here we add a learning rate warm-up.
+
+.. testcode:: python
 
     # learning rate warm-up
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        # warm up lr
+    def optimizer_step(
+        self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
+        on_tpu=False, using_native_amp=False, using_lbfgs=False,
+    ):
+        # warm up lr over the first 500 steps
         if self.trainer.global_step < 500:
             lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
             for pg in optimizer.param_groups:
                 pg['lr'] = lr_scale * self.hparams.learning_rate
 
         # update params
-        optimizer.step(closure=closure)
+        optimizer.step(closure=optimizer_closure)
 
-.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches and much more ...
-
-.. testcode::
-
-    # function hook in LightningModule
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-        optimizer.step(closure=closure)
+-----
 
-.. note:: To access your wrapped Optimizer from ``LightningOptimizer``, do as follow.
+Access your own optimizer
+-------------------------
+``optimizer`` is a :class:`~pytorch_lightning.core.optimizer.LightningOptimizer` object wrapping your own optimizer
+configured in your :meth:`~pytorch_lightning.LightningModule.configure_optimizers`. You can access your own optimizer
+with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to
+support accelerators and precision for you.
 
-.. testcode::
+.. testcode:: python
 
     # function hook in LightningModule
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
-
-        # `optimizer is a ``LightningOptimizer`` wrapping the optimizer.
-        # To access it, do as follow:
-        optimizer = optimizer.optimizer
-
-        # run step. However, it won't work on TPU, AMP, etc...
-        optimizer.step(closure=closure)
-
-
-----------
-
-Using the closure functions for optimization
---------------------------------------------
-
-When using optimization schemes such as LBFGS, the `second_order_closure` needs to be enabled. By default, this function is defined by wrapping the `training_step` and the backward steps as follows
-
-.. warning::
-    Before 1.2.2, ``.zero_grad()`` was called outside the closure internally.
- From 1.2.2, the closure calls ``.zero_grad()`` inside, so there is no need to define your own closure - when using similar optimizers to :class:`torch.optim.LBFGS` which requires reevaluation of the loss with the closure in ``optimizer.step()``. - -.. testcode:: - - def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden): - # Model training step on a given batch - result = pl_module.training_step(split_batch, batch_idx, opt_idx, hidden) - - # Model backward pass - pl_module.backward(result, optimizer, opt_idx) - - # on_after_backward callback - pl_module.on_after_backward(result.training_step_output, batch_idx, result.loss) - - return result - - # This default `second_order_closure` function can be enabled by passing it directly into the `optimizer.step` - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): - # update params - optimizer.step(second_order_closure) + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False, + ): + optimizer.step(closure=optimizer_closure) + + # `optimizer` is a `LightningOptimizer` wrapping the optimizer. + # To access it, do the following. + # However, it won't work on TPU, AMP, etc... + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False, + ): + optimizer = optimizer.optimizer + optimizer.step(closure=optimizer_closure) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b320a9b223840..9830e6ca38fa6 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -260,7 +260,7 @@ def on_predict_end(self) -> None: def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ - Called after optimizer.step() and before optimizer.zero_grad(). + Called after ``training_step()`` and before ``optimizer.zero_grad()``. Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated. @@ -268,10 +268,13 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None: This is where it is called:: for optimizer in optimizers: - optimizer.step() + out = training_step(...) + model.on_before_zero_grad(optimizer) # < ---- called here optimizer.zero_grad() + backward() + Args: optimizer: The optimizer for which grads should be zeroed. """ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7efe88515b37e..54ea9d1bdb77e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1097,28 +1097,22 @@ def configure_optimizers(self): Return: Any of these 6 options. - - Single optimizer. - - List or Tuple - List of optimizers. - - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict). - - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' + - **Single optimizer**. + - **List or Tuple** of optimizers. + - **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers (or + multiple lr_dict). + - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"`` key whose value is a single LR scheduler or lr_dict. - - Tuple of dictionaries as described, with an optional 'frequency' key. - - None - Fit will run without any optimizer. 
+ - **Tuple of dictionaries** as described above, with an optional ``"frequency"`` key. + - **None** - Fit will run without any optimizer. Note: - The 'frequency' value is an int corresponding to the number of sequential batches - optimized with the specific optimizer. It should be given to none or to all of the optimizers. - There is a difference between passing multiple optimizers in a list, - and passing multiple optimizers in dictionaries with a frequency of 1: - In the former case, all optimizers will operate on the given batch in each optimization step. - In the latter, only one optimizer will operate on the given batch at every step. - - The lr_dict is a dictionary which contains the scheduler and its associated configuration. - The default configuration is shown below. + The lr_dict is a dictionary which contains the scheduler and its associated configuration. The default + configuration is shown below. .. code-block:: python - { + lr_dict = { 'scheduler': lr_scheduler, # The LR scheduler instance (required) 'interval': 'epoch', # The unit of the scheduler's step size 'frequency': 1, # The frequency of the scheduler @@ -1128,43 +1122,51 @@ def configure_optimizers(self): 'name': None, # Custom name for LearningRateMonitor to use } - Only the ``scheduler`` key is required, the rest will be set to the defaults above. + Only the ``"scheduler"`` key is required, the rest will be set to the defaults above. + + Note: + The ``"frequency"`` value is an ``int`` corresponding to the number of sequential batches optimized with the + specific optimizer. It should be given to none or to all of the optimizers. + + There is a difference between passing multiple optimizers in a list and passing multiple optimizers in + dictionaries with a frequency of 1: + In the former case, all optimizers will operate on the given batch in each optimization step. + In the latter, only one optimizer will operate on the given batch at every step. 
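+
+                For illustration, the two styles would look as follows (a sketch; ``opt1`` and ``opt2`` stand for any
+                two optimizers created in this method)::
+
+                    # a plain list: both optimizers step on every batch
+                    return [opt1, opt2]
+
+                    # dicts with a frequency of 1: the optimizers alternate batches
+                    return (
+                        {'optimizer': opt1, 'frequency': 1},
+                        {'optimizer': opt2, 'frequency': 1},
+                    )
+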
Examples:: # most cases def configure_optimizers(self): - opt = Adam(self.parameters(), lr=1e-3) - return opt + return Adam(self.parameters(), lr=1e-3) # multiple optimizer case (e.g.: GAN) def configure_optimizers(self): - generator_opt = Adam(self.model_gen.parameters(), lr=0.01) - disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) - return generator_opt, disriminator_opt + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) + return gen_opt, dis_opt # example with learning rate schedulers def configure_optimizers(self): - generator_opt = Adam(self.model_gen.parameters(), lr=0.01) - disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) - discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10) - return [generator_opt, disriminator_opt], [discriminator_sched] + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) + dis_sch = CosineAnnealing(dis_opt, T_max=10) + return [gen_opt, dis_opt], [dis_sch] # example with step-based learning rate schedulers def configure_optimizers(self): gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_disc.parameters(), lr=0.02) - gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99), - 'interval': 'step'} # called after each training step - dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch - return [gen_opt, dis_opt], [gen_sched, dis_sched] + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) + gen_sch = {'scheduler': ExponentialLR(gen_opt, 0.99), + 'interval': 'step'} # called after each training step + dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch + return [gen_opt, dis_opt], [gen_sch, dis_sch] # example with optimizer frequencies # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1 # https://arxiv.org/abs/1704.00028 def configure_optimizers(self): gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_disc.parameters(), lr=0.02) + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) n_critic = 5 return ( {'optimizer': dis_opt, 'frequency': n_critic}, @@ -1172,32 +1174,22 @@ def configure_optimizers(self): ) Note: - Some things to know: - - Lightning calls ``.backward()`` and ``.step()`` on each optimizer - and learning rate scheduler as needed. - - - If you use 16-bit precision (``precision=16``), Lightning will automatically - handle the optimizers for you. - - - If you use multiple optimizers, :meth:`training_step` will have an additional - ``optimizer_idx`` parameter. - - - If you use LBFGS Lightning handles the closure function automatically for you. - - - If you use multiple optimizers, gradients will be calculated only - for the parameters of current optimizer at each training step. - - - If you need to control how often those optimizers step or override the - default ``.step()`` schedule, override the :meth:`optimizer_step` hook. - - - If you only want to call a learning rate scheduler every ``x`` step or epoch, - or want to monitor a custom metric, you can specify these in a lr_dict: + - Lightning calls ``.backward()`` and ``.step()`` on each optimizer and learning rate scheduler as needed. + - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizers. + - If you use multiple optimizers, :meth:`training_step` will have an additional ``optimizer_idx`` parameter. 
+ - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you. + - If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer + at each training step. + - If you need to control how often those optimizers step or override the default ``.step()`` schedule, + override the :meth:`optimizer_step` hook. + - If you only want to call a learning rate scheduler every ``x`` step or epoch, or want to monitor a custom + metric, you can specify these in a lr_dict: .. code-block:: python - { + lr_dict = { 'scheduler': lr_scheduler, 'interval': 'step', # or 'epoch' 'monitor': 'val_f1', @@ -1210,23 +1202,21 @@ def configure_optimizers(self): def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None: """ Call this directly from your training_step when doing optimizations manually. - By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you + By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you. This function forwards all args to the .backward() call as well. - .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set - - .. tip:: In manual mode we still automatically accumulate grad over batches if - Trainer(accumulate_grad_batches=x) is set and you use `optimizer.step()` + See :ref:`manual optimization` for more examples. Example:: def training_step(...): - opt_a, opt_b = self.optimizers() + opt = self.optimizers() loss = ... + opt.zero_grad() # automatically applies scaling, etc... self.manual_backward(loss) - opt_a.step() + opt.step() """ if optimizer is not None: rank_zero_deprecation( @@ -1336,18 +1326,18 @@ def optimizer_step( Warning: If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to ``optimizer.step()`` function as shown in the examples. This ensures that - ``train_step_and_backward_closure`` is called within + ``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`. Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer - optimizer_idx: If you used multiple optimizers this indexes into that list. - optimizer_closure: closure for all optimizers - on_tpu: true if TPU backward is required - using_native_amp: True if using native amp - using_lbfgs: True if the matching optimizer is lbfgs + optimizer_idx: If you used multiple optimizers, this indexes into that list. 
+ optimizer_closure: Closure for all optimizers + on_tpu: ``True`` if TPU backward is required + using_native_amp: ``True`` if using native amp + using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS` Examples:: @@ -1359,22 +1349,18 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, # Alternating schedule for optimizer steps (i.e.: GANs) def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs): - # update generator opt every 2 steps + # update generator opt every step if optimizer_idx == 0: - if batch_idx % 2 == 0 : - optimizer.step(closure=optimizer_closure) - optimizer.zero_grad() + optimizer.step(closure=optimizer_closure) - # update discriminator opt every 4 steps + # update discriminator opt every 2 steps if optimizer_idx == 1: - if batch_idx % 4 == 0 : + if (batch_idx + 1) % 2 == 0 : optimizer.step(closure=optimizer_closure) - optimizer.zero_grad() # ... # add as many optimizers as you want - Here's another example showing how to use this for more advanced things such as learning rate warm-up: @@ -1391,7 +1377,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, # update params optimizer.step(closure=optimizer_closure) - optimizer.zero_grad() """ if not isinstance(optimizer, LightningOptimizer): @@ -1400,6 +1385,26 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer.step(closure=optimizer_closure) def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + """Override this method to change the default behaviour of ``optimizer.zero_grad()``. + + Args: + epoch: Current epoch + batch_idx: Index of current batch + optimizer: A PyTorch optimizer + optimizer_idx: If you used multiple optimizers this indexes into that list. + + Examples:: + + # DEFAULT + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad() + + # Set gradients to `None` instead of zero to improve performance. + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad(set_to_none=True) + + See :meth:`torch.optim.Optimizer.zero_grad` for the explanation of the above example. + """ optimizer.zero_grad() def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 71d9407062001..8dbc41821b24a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -651,9 +651,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: return batch_outputs def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): - """ - wrap the forward step in a closure so second order methods work - """ + """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)