
Commit c7f30a2

[doc] Fix closure in manual optimization (#6374)
* Fix manual optimization docs
* Fix typo. Thanks @import-antigravity
1 parent 2708c39 commit c7f30a2

File tree: 1 file changed, +10 -14 lines changed

docs/source/common/optimizers.rst

Lines changed: 10 additions & 14 deletions
@@ -40,10 +40,9 @@ to manually manage the optimization process. To do so, do the following:
         loss = self.compute_loss(batch)
         self.manual_backward(loss)

-
 .. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc..

-.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertize.
+.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertise.

 .. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do as such.

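For context, the warning above boils down to calling ``zero_grad`` yourself. A minimal sketch, assuming manual optimization is enabled by switching ``automatic_optimization`` off; the module name, layer, and loss below are made up for illustration and are not part of the docs page:

    import torch
    from pytorch_lightning import LightningModule  # import path assumed for this era of the docs


    class ManualOptModule(LightningModule):  # hypothetical module name
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)
            # Assumption: manual optimization is enabled by turning automatic optimization off.
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            opt.zero_grad()  # from 1.2 on, zeroing gradients is the user's responsibility
            loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros(batch.size(0), 1))
            self.manual_backward(loss)  # instead of loss.backward(), so precision/accelerator logic still applies
            opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)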
@@ -65,8 +64,7 @@ to manually manage the optimization process. To do so, do the following:
         opt.step()
         opt.zero_grad()

-
-.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs <https://pytorch.org/docs/stable/optim.html#optimizer-step-closure>`_.
+.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward``, ``zero_grad`` and ``backward`` of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs <https://pytorch.org/docs/stable/optim.html#optimizer-step-closure>`_.

 Here is the same example as above using a ``closure``.

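The closure-requiring optimizers the tip points to include plain PyTorch's ``LBFGS``, which re-evaluates the closure internally. A minimal plain-PyTorch sketch of that pattern; the model, data, and loss below are made up for illustration:

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(64, 10), torch.randn(64, 1)


    def closure():
        # The optimizer may call this several times per step, so it must
        # zero the gradients, recompute the loss, backpropagate, and return the loss.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss


    optimizer.step(closure)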
@@ -78,20 +76,20 @@ Here is the same example as above using a ``closure``.
     def training_step(self, batch, batch_idx):
         opt = self.optimizers()

-        def forward_and_backward():
+        def closure():
+            # Only zero_grad on the first batch to accumulate gradients
+            is_first_batch_to_accumulate = batch_idx % 2 == 0
+            if is_first_batch_to_accumulate:
+                opt.zero_grad()
+
             loss = self.compute_loss(batch)
             self.manual_backward(loss)
+            return loss

-        opt.step(closure=forward_and_backward)
-
-        # accumulate gradient batches
-        if batch_idx % 2 == 0:
-            opt.zero_grad()
-
+        opt.step(closure=closure)

 .. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good pratice to call ``zero_grad`` before ``manual_backward``.

-
 .. testcode:: python

     import torch
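Assembled from the hunk above, the post-commit closure example reads roughly as follows. The class scaffolding around ``training_step`` (``compute_loss``, ``configure_optimizers``, the manual-optimization switch) is not part of the diff and is sketched here only so the snippet is self-contained:

    import torch
    from pytorch_lightning import LightningModule  # import path assumed


    class ClosureExample(LightningModule):  # hypothetical name for this sketch
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)
            # Assumption: manual optimization is enabled by turning automatic optimization off.
            self.automatic_optimization = False

        def compute_loss(self, batch):
            # Stand-in for the ``compute_loss`` helper the docs page assumes.
            return self.layer(batch).sum()

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()

            def closure():
                # Only zero_grad on the first batch to accumulate gradients
                is_first_batch_to_accumulate = batch_idx % 2 == 0
                if is_first_batch_to_accumulate:
                    opt.zero_grad()

                loss = self.compute_loss(batch)
                self.manual_backward(loss)
                return loss

            opt.step(closure=closure)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)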
@@ -174,10 +172,8 @@ Setting ``sync_grad`` to ``False`` will block this synchronization and improve y

 Here is an example for advanced use-case.

-
 .. testcode:: python

-
     # Scenario for a GAN with gradient accumulation every 2 batches and optimized for multiple gpus.

     class SimpleGAN(LightningModule):
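The ``SimpleGAN`` example continues beyond this hunk. As a rough, hypothetical sketch of the two-optimizer, accumulate-every-2-batches pattern it describes: the loss helpers below are invented, and the model toggling and multi-GPU ``sync_grad`` handling mentioned elsewhere on the page are omitted. This fragment is meant to live inside such a LightningModule with manual optimization enabled:

    def training_step(self, batch, batch_idx):
        # Assumption: configure_optimizers() returns the generator optimizer
        # first and the discriminator optimizer second.
        g_opt, d_opt = self.optimizers()

        # Accumulate gradients over 2 batches: only step/zero on every second batch.
        do_step = (batch_idx + 1) % 2 == 0

        g_loss = self.generator_loss(batch)      # hypothetical helper
        self.manual_backward(g_loss)
        if do_step:
            g_opt.step()
            g_opt.zero_grad()

        d_loss = self.discriminator_loss(batch)  # hypothetical helper
        self.manual_backward(d_loss)
        if do_step:
            d_opt.step()
            d_opt.zero_grad()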
