Commit eaa16c7

docs: explain how Lightning uses closures for automatic optimization (#8551)

1 parent 75e18a5

1 file changed: docs/source/common/optimizers.rst (24 additions, 9 deletions)
@@ -299,10 +299,14 @@ Under the hood, Lightning does the following:

 for epoch in epochs:
     for batch in data:
-        loss = model.training_step(batch, batch_idx, ...)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+
+        def closure():
+            loss = model.training_step(batch, batch_idx, ...)
+            optimizer.zero_grad()
+            loss.backward()
+            return loss
+
+        optimizer.step(closure)

     for lr_scheduler in lr_schedulers:
         lr_scheduler.step()
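The closure pattern introduced in this hunk mirrors the plain ``torch.optim`` API: ``Optimizer.step`` accepts an
optional closure that re-evaluates the model and returns the loss. As an illustration only (not part of this
commit), a minimal plain-PyTorch sketch of the same mechanism:

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch, target = torch.randn(4, 10), torch.randn(4, 1)

    def closure():
        # everything needed to (re-)compute the loss and its gradients lives here
        optimizer.zero_grad()
        loss = F.mse_loss(model(batch), target)
        loss.backward()
        return loss

    # SGD calls the closure once and returns the loss it produced
    loss = optimizer.step(closure)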
@@ -314,14 +318,22 @@ In the case of multiple optimizers, Lightning does the following:

 for epoch in epochs:
     for batch in data:
         for opt in optimizers:
-            loss = model.training_step(batch, batch_idx, optimizer_idx)
-            opt.zero_grad()
-            loss.backward()
-            opt.step()
+
+            def closure():
+                loss = model.training_step(batch, batch_idx, optimizer_idx)
+                opt.zero_grad()
+                loss.backward()
+                return loss
+
+            opt.step(closure)

     for lr_scheduler in lr_schedulers:
         lr_scheduler.step()

+As can be seen in the code snippet above, Lightning defines a closure with ``training_step``, ``zero_grad``
+and ``backward`` for the optimizer to execute. This mechanism is in place to support optimizers which operate on the
+output of the closure (e.g. the loss) or need to call the closure several times (e.g. :class:`~torch.optim.LBFGS`).
+
 .. warning::
     Before 1.2.2, Lightning internally calls ``backward``, ``step`` and ``zero_grad`` in that order.
     From 1.2.2, the order is changed to ``zero_grad``, ``backward`` and ``step``.
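The added paragraph points at :class:`~torch.optim.LBFGS` as the motivating case: LBFGS may evaluate the closure
several times within a single ``step``, which is why the forward pass, ``zero_grad`` and ``backward`` must all live
inside the closure. A short plain-PyTorch sketch (illustration only, not part of this commit):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), max_iter=20)
    batch, target = torch.randn(16, 10), torch.randn(16, 1)

    def closure():
        # called repeatedly by LBFGS during its internal iterations / line search
        optimizer.zero_grad()
        loss = F.mse_loss(model(batch), target)
        loss.backward()
        return loss

    optimizer.step(closure)  # one optimizer step, many closure evaluations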
@@ -396,8 +408,11 @@ For example, here step optimizer A every batch and optimizer B every 2 batches.

 # update discriminator every 2 steps
 if optimizer_idx == 1:
     if (batch_idx + 1) % 2 == 0:
-        # the closure (which includes the `training_step`) won't run if the line below isn't executed
+        # the closure (which includes the `training_step`) will be executed by `optimizer.step`
         optimizer.step(closure=optimizer_closure)
+    else:
+        # optional: call the closure by itself to run `training_step` + `backward` without an optimizer step
+        optimizer_closure()

 # ...
 # add as many optimizers as you want
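For context, the snippet in this hunk sits inside an ``optimizer_step`` override on a ``LightningModule``. A minimal
sketch of such an override is below; the class name is illustrative and the exact hook signature varies between
Lightning versions, so the extra arguments here are an assumption:

    import pytorch_lightning as pl

    class GAN(pl.LightningModule):
        # NOTE: argument list assumed from the 1.x-era hook; check your installed version
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # update generator every step
            if optimizer_idx == 0:
                optimizer.step(closure=optimizer_closure)

            # update discriminator every 2 steps
            if optimizer_idx == 1:
                if (batch_idx + 1) % 2 == 0:
                    # the closure (which includes the `training_step`) is executed by `optimizer.step`
                    optimizer.step(closure=optimizer_closure)
                else:
                    # run `training_step` + `backward` without taking an optimizer step
                    optimizer_closure()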
