Commit b91747e

remove backward from training batch loop (#9265)
1 parent: 285db62

4 files changed: +6 −25 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Loop customization:
     * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
     * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
+    * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))

 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -1418,7 +1418,7 @@ def training_step(...):
         self._verify_is_manual_optimization("manual_backward")

         # backward
-        self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, None, None, *args, **kwargs)
+        self.trainer.accelerator.backward(loss, None, None, *args, **kwargs)

     def backward(
         self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
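
Note: the one-line change above is the manual-optimization path. `LightningModule.manual_backward` now hands the loss straight to `trainer.accelerator.backward` instead of routing it through the batch loop. A minimal sketch of how this path is exercised from user code, assuming the pytorch-lightning API at the time of this commit; the class names, shapes, and hyperparameters below are illustrative and not part of the commit:

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    # Toy dataset, for illustration only.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Opt out of automatic optimization so manual_backward is permitted.
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        # With this commit, manual_backward forwards the loss to
        # trainer.accelerator.backward(loss, None, None) directly.
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)
trainer.fit(ManualOptModel(), DataLoader(RandomDataset(), batch_size=8))

Nothing changes from the user's perspective; only the internal call chain behind `manual_backward` is shorter.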

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 0 additions & 20 deletions
@@ -16,7 +16,6 @@
 from typing import Any, Callable, List, Optional, Tuple

 import numpy as np
-import torch
 from deprecate import void
 from torch import Tensor
 from torch.optim import Optimizer
@@ -246,25 +245,6 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
         return splits

-    # TODO: remove this method and update tests
-    def backward(
-        self,
-        loss: Tensor,
-        optimizer: Optional[torch.optim.Optimizer],
-        opt_idx: Optional[int] = None,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Tensor:
-        """Performs the backward step.
-
-        Args:
-            loss: The loss value to back-propagate on
-            optimizer: Current optimizer being used. ``None`` if using manual optimization.
-            opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
-        """
-        self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
-        return loss
-
     def _update_running_loss(self, current_loss: Tensor) -> None:
         """Updates the running loss value with the current value"""
         if self.trainer.lightning_module.automatic_optimization:
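
Note: the removed method was a thin wrapper that delegated to `self.trainer.accelerator.backward(...)` and returned the loss; for automatic optimization the equivalent backward call now lives in the new `OptimizerLoop` (the updated tests below patch `batch_loop.optimizer_loop.backward`). The user-facing customization point is unchanged: the `LightningModule.backward` hook, whose signature appears at the end of the lightning.py hunk above, remains the documented way to change how backpropagation runs. A hedged sketch of overriding that hook; the class name and the retain-graph use case are hypothetical:

import pytorch_lightning as pl


class RetainGraphModel(pl.LightningModule):
    # Other required methods (training_step, configure_optimizers, ...) omitted.
    def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
        # The default hook essentially calls loss.backward(*args, **kwargs);
        # override it when backpropagation needs different arguments,
        # e.g. keeping the autograd graph alive for a second pass.
        loss.backward(retain_graph=True)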

tests/trainer/test_trainer.py

Lines changed: 4 additions & 4 deletions
@@ -910,7 +910,7 @@ def test_gradient_clipping_by_norm(tmpdir, precision):
         gradient_clip_val=1.0,
     )

-    old_backward = trainer.fit_loop.epoch_loop.batch_loop.backward
+    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward

    def backward(*args, **kwargs):
         # test that gradient is clipped correctly
@@ -920,7 +920,7 @@ def backward(*args, **kwargs):
         assert (grad_norm - 1.0).abs() < 0.01, f"Gradient norm != 1.0: {grad_norm}"
         return ret_val

-    trainer.fit_loop.epoch_loop.batch_loop.backward = backward
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward = backward
     trainer.fit(model)


@@ -945,7 +945,7 @@ def test_gradient_clipping_by_value(tmpdir, precision):
         default_root_dir=tmpdir,
     )

-    old_backward = trainer.fit_loop.epoch_loop.batch_loop.backward
+    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward

     def backward(*args, **kwargs):
         # test that gradient is clipped correctly
@@ -958,7 +958,7 @@ def backward(*args, **kwargs):
         ), f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
         return ret_val

-    trainer.fit_loop.epoch_loop.batch_loop.backward = backward
+    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward = backward
     trainer.fit(model)
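
Note: the test changes only move the monkeypatch target from `batch_loop.backward` to `batch_loop.optimizer_loop.backward`. A minimal sketch of that interception pattern outside the test suite, assuming the loop layout introduced by #9191/#9265; the helper name and the norm computation are illustrative:

import torch
import pytorch_lightning as pl


def intercept_backward(trainer: pl.Trainer, model: pl.LightningModule) -> None:
    # After this commit, the backward hook used during automatic optimization
    # lives on the optimizer loop, not on the training batch loop.
    optimizer_loop = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop
    old_backward = optimizer_loop.backward

    def backward(*args, **kwargs):
        ret_val = old_backward(*args, **kwargs)
        # Gradients are populated at this point, so their global L2 norm
        # can be inspected before the optimizer step runs.
        grads = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
        print(f"grad norm after backward: {torch.stack(grads).norm(2):.4f}")
        return ret_val

    optimizer_loop.backward = backward

As in the updated tests, the patch is applied after the `Trainer` is constructed and before `trainer.fit(model)` is called.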
