@@ -299,10 +299,14 @@ Under the hood, Lightning does the following:

     for epoch in epochs:
         for batch in data:
-            loss = model.training_step(batch, batch_idx, ...)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+
+            def closure():
+                loss = model.training_step(batch, batch_idx, ...)
+                optimizer.zero_grad()
+                loss.backward()
+                return loss
+
+            optimizer.step(closure)

     for lr_scheduler in lr_schedulers:
         lr_scheduler.step()
@@ -314,14 +318,22 @@ In the case of multiple optimizers, Lightning does the following:

     for epoch in epochs:
         for batch in data:
             for opt in optimizers:
-                loss = model.training_step(batch, batch_idx, optimizer_idx)
-                opt.zero_grad()
-                loss.backward()
-                opt.step()
+
+                def closure():
+                    loss = model.training_step(batch, batch_idx, optimizer_idx)
+                    opt.zero_grad()
+                    loss.backward()
+                    return loss
+
+                opt.step(closure)

     for lr_scheduler in lr_schedulers:
         lr_scheduler.step()

+As can be seen in the code snippet above, Lightning defines a closure with ``training_step``, ``zero_grad``
+and ``backward`` for the optimizer to execute. This mechanism is in place to support optimizers which operate on the
+output of the closure (e.g. the loss) or need to call the closure several times (e.g. :class:`~torch.optim.LBFGS`).
+
 .. warning::
     Before 1.2.2, Lightning internally called ``backward``, ``step`` and ``zero_grad`` in this order.
     From 1.2.2, the order is changed to ``zero_grad``, ``backward`` and ``step``.