Commit fe9803c

Fix manual backward for DeepSpeed (#13882)

1 parent dbafd6e

File tree (5 files changed: +23 -6 lines changed)

src/pytorch_lightning/plugins/precision/apex_amp.py
src/pytorch_lightning/plugins/precision/deepspeed.py
src/pytorch_lightning/plugins/precision/ipu.py
src/pytorch_lightning/plugins/precision/precision_plugin.py
src/pytorch_lightning/strategies/strategy.py

src/pytorch_lightning/plugins/precision/apex_amp.py
Lines changed: 2 additions & 1 deletion

@@ -59,6 +59,7 @@ def backward(
         model: "pl.LightningModule",
         closure_loss: Tensor,
         optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -71,7 +72,7 @@ def backward(
         """
         opt = optimizer or model.trainer.optimizers
         with amp.scale_loss(closure_loss, opt) as closure_loss:
-            super().backward(model, closure_loss, optimizer, *args, **kwargs)
+            super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
 
     def optimizer_step(
         self,
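For context, a minimal runnable sketch of the pattern this hunk adopts (the Base/ScaledBase classes are hypothetical stand-ins, not Lightning's real ones): once optimizer_idx is a named parameter, the overriding plugin consumes it and re-passes it explicitly instead of letting it ride anonymously inside *args.

    from typing import Any, Optional

    class Base:
        def backward(self, loss: float, optimizer: Optional[object],
                     optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
            print(f"base backward: optimizer_idx={optimizer_idx}, extra={args}")

    class ScaledBase(Base):
        def backward(self, loss: float, optimizer: Optional[object],
                     optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
            # Stand-in for apex's `amp.scale_loss` context: transform the loss,
            # then delegate, forwarding the explicit arguments in order.
            scaled = loss * 2
            super().backward(scaled, optimizer, optimizer_idx, *args, **kwargs)

    ScaledBase().backward(1.0, None, 0)
    # -> base backward: optimizer_idx=0, extra=()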

src/pytorch_lightning/plugins/precision/deepspeed.py
Lines changed: 9 additions & 1 deletion

@@ -62,7 +62,15 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level
 
-    def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
+    def backward(
+        self,
+        model: "pl.LightningModule",
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         if is_overridden("backward", model):
             warning_cache.warn(
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"

src/pytorch_lightning/plugins/precision/ipu.py
Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def __init__(self, precision: int) -> None:
         super().__init__()
         self.precision = precision
 
-    def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None:
+    def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
         if is_overridden("backward", model):
             warning_cache.warn(
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"

src/pytorch_lightning/plugins/precision/precision_plugin.py
Lines changed: 2 additions & 1 deletion

@@ -64,6 +64,7 @@ def backward(
         model: "pl.LightningModule",
         closure_loss: Tensor,
         optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -76,7 +77,7 @@ def backward(
         """
         # do backward pass
         if model is not None and isinstance(model, pl.LightningModule):
-            model.backward(closure_loss, optimizer, *args, **kwargs)
+            model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
         else:
             self._run_backward(closure_loss, *args, **kwargs)
 
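With this dispatch, a user override of LightningModule.backward receives optimizer and optimizer_idx as real positional parameters. A minimal module matching that call (a sketch assuming the 1.7-era hook signature shown in the hunk above, with pytorch_lightning installed):

    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
            # Mirrors the default behaviour; optimizer and optimizer_idx are
            # delivered explicitly by the precision plugin's dispatch above.
            loss.backward(*args, **kwargs)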

src/pytorch_lightning/strategies/strategy.py
Lines changed: 9 additions & 2 deletions

@@ -171,7 +171,14 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return optimizer.state_dict()
 
-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+    def backward(
+        self,
+        closure_loss: Tensor,
+        optimizer: Optional[Optimizer],
+        optimizer_idx: Optional[int],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Tensor:
         """Forwards backward-calls to the precision plugin.
 
         Args:
@@ -181,7 +188,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         assert self.lightning_module is not None
         closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
 
-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
 
         closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
         self.post_backward(closure_loss)
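Putting the hunks together, a condensed, hypothetical rendering of the fixed call chain (the real classes are Strategy, PrecisionPlugin, and the user's LightningModule; names and bodies here are simplified stand-ins):

    from typing import Any, Optional

    class Module:
        def backward(self, loss: float, optimizer: Optional[object],
                     optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
            print(f"module backward: idx={optimizer_idx}")

    class Precision:
        def backward(self, model: Module, loss: float, optimizer: Optional[object],
                     optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> None:
            model.backward(loss, optimizer, optimizer_idx, *args, **kwargs)

    class Strategy:
        def __init__(self) -> None:
            self.module = Module()
            self.precision = Precision()

        def backward(self, loss: float, optimizer: Optional[object],
                     optimizer_idx: Optional[int], *args: Any, **kwargs: Any) -> float:
            # optimizer/optimizer_idx are threaded explicitly at every hop, so a
            # subclass that consumes them can no longer be bypassed via *args.
            self.precision.backward(self.module, loss, optimizer, optimizer_idx,
                                    *args, **kwargs)
            return loss

    Strategy().backward(1.0, None, 0)  # -> module backward: idx=0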
