@@ -299,10 +299,14 @@ Under the hood, Lightning does the following:
 
 for epoch in epochs:
     for batch in data:
-        loss = model.training_step(batch, batch_idx, ...)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+
+        def closure():
+            loss = model.training_step(batch, batch_idx, ...)
+            optimizer.zero_grad()
+            loss.backward()
+            return loss
+
+        optimizer.step(closure)
 
     for lr_scheduler in lr_schedulers:
         lr_scheduler.step()
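
The loop above mirrors the optional closure argument that every ``torch.optim`` optimizer's ``step`` accepts. As a
minimal plain-PyTorch sketch (not Lightning code; the model, batch and learning rate below are placeholders), a
single closure-driven step looks like this:

.. code-block:: python

    import torch
    from torch import nn

    # placeholder model and batch, purely for illustration
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batch = (torch.randn(4, 10), torch.randn(4, 1))


    def closure():
        x, y = batch
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        return loss


    # every built-in optimizer accepts an optional closure; SGD calls it once,
    # while optimizers such as LBFGS may call it several times per step
    optimizer.step(closure)
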
@@ -314,14 +318,22 @@ In the case of multiple optimizers, Lightning does the following:
 for epoch in epochs:
     for batch in data:
         for opt in optimizers:
-            loss = model.training_step(batch, batch_idx, optimizer_idx)
-            opt.zero_grad()
-            loss.backward()
-            opt.step()
+
+            def closure():
+                loss = model.training_step(batch, batch_idx, optimizer_idx)
+                opt.zero_grad()
+                loss.backward()
+                return loss
+
+            opt.step(closure)
 
     for lr_scheduler in lr_schedulers:
         lr_scheduler.step()
 
+As shown in the code snippet above, Lightning defines a closure with ``training_step``, ``zero_grad``
+and ``backward`` for the optimizer to execute. This mechanism is in place to support optimizers which operate on the
+output of the closure (e.g. the loss) or need to call the closure several times (e.g. :class:`~torch.optim.LBFGS`).
+
 .. warning::
     Before 1.2.2, Lightning internally calls ``backward``, ``step`` and ``zero_grad`` in this order.
     From 1.2.2, the order is changed to ``zero_grad``, ``backward`` and ``step``.
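
Because Lightning hands this closure to ``optimizer.step``, closure-based optimizers such as LBFGS can be used
without extra boilerplate. A minimal sketch, assuming a toy regression module (the class name, layer sizes and
learning rate are illustrative):

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class LitRegressor(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(10, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            # Lightning wraps this call (plus zero_grad/backward) in the closure shown above
            return nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            # LBFGS may evaluate the closure several times per step; the closure
            # mechanism described above is what makes this work in automatic optimization
            return torch.optim.LBFGS(self.parameters(), lr=0.1)
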
@@ -396,8 +408,11 @@ For example, here we step optimizer A every batch and optimizer B every 2 batches.
     # update discriminator every 2 steps
     if optimizer_idx == 1:
         if (batch_idx + 1) % 2 == 0:
-            # the closure (which includes the `training_step`) won't run if the line below isn't executed
+            # the closure (which includes the `training_step`) will be executed by `optimizer.step`
             optimizer.step(closure=optimizer_closure)
+        else:
+            # optional: call the closure by itself to run `training_step` + `backward` without an optimizer step
+            optimizer_closure()
 
     # ...
     # add as many optimizers as you want
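
For context, a hedged sketch of how the fragment above fits into a full ``optimizer_step`` override inside the
LightningModule. The hook signature below follows the 1.2/1.3-era API and may differ in other versions; optimizer
``0`` is assumed to be the generator and optimizer ``1`` the discriminator:

.. code-block:: python

    # inside the LightningModule (signature assumed from the 1.2/1.3-era hook API)
    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu,
        using_native_amp,
        using_lbfgs,
    ):
        # update generator every step
        if optimizer_idx == 0:
            optimizer.step(closure=optimizer_closure)

        # update discriminator every 2 steps
        if optimizer_idx == 1:
            if (batch_idx + 1) % 2 == 0:
                # the closure (which includes `training_step`) is executed by `optimizer.step`
                optimizer.step(closure=optimizer_closure)
            else:
                # optional: run `training_step` + `backward` without stepping the optimizer
                optimizer_closure()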