diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 52e788956aaef..6954adcbef164 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -65,6 +65,9 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
         deepspeed_engine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        model.backward(tensor, *args, **kwargs)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 50c527f5f407d..e298569996274 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -15,6 +15,8 @@
 from typing import Any, Callable, Dict, Generator, Union
 
 import torch
+from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
@@ -68,6 +70,11 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        if not self.is_bfloat16:
+            tensor = self.scaler.scale(tensor)
+        super()._run_backward(tensor, model, *args, **kwargs)
+
     def pre_optimizer_step(
         self,
         model: "pl.LightningModule",
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index c81a474faad34..9ec127886396c 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -76,7 +76,7 @@ def backward(
         if model is not None and isinstance(model, pl.LightningModule):
             model.backward(closure_loss, optimizer, *args, **kwargs)
         else:
-            closure_loss.backward(*args, **kwargs)
+            self._run_backward(closure_loss, *args, **kwargs)
 
     def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
         """Run after precision plugin executes backward.
@@ -90,6 +90,13 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         model.trainer.call_hook("on_after_backward")
         return closure_loss
 
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        """Lightning-independent backward logic.
+
+        Currently only used by Lightning Lite. Subject to further refactors.
+        """
+        tensor.backward(*args, **kwargs)
+
     def pre_optimizer_step(
         self,
         model: "pl.LightningModule",
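For context, the new `_run_backward` hook splits the raw autograd call away from the LightningModule-specific `model.backward(...)` path, so a caller that only holds a plain `nn.Module` (as Lightning Lite does) can still route backward through the precision plugin, with each plugin free to adjust the tensor first (e.g. loss scaling for native AMP, delegation to the engine for DeepSpeed). The sketch below is illustrative only and does not reproduce Lightning's actual classes: `BasePrecision` and `NativeAMPPrecision` are hypothetical stand-ins that mirror the override pattern in this diff.

# Minimal, self-contained sketch of the override pattern introduced above.
# BasePrecision / NativeAMPPrecision are hypothetical names, not Lightning classes.
from typing import Any, Optional

import torch
from torch import Tensor
from torch.nn import Module


class BasePrecision:
    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        # Lightning-independent: plain autograd, no LightningModule hooks involved.
        tensor.backward(*args, **kwargs)


class NativeAMPPrecision(BasePrecision):
    def __init__(self, use_bfloat16: bool = False) -> None:
        self.is_bfloat16 = use_bfloat16
        # bf16 needs no loss scaling; fp16 uses a GradScaler.
        self.scaler = torch.cuda.amp.GradScaler(enabled=not use_bfloat16)

    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        if not self.is_bfloat16:
            tensor = self.scaler.scale(tensor)
        super()._run_backward(tensor, model, *args, **kwargs)


# Usage from a loop that only has a plain nn.Module (the Lite-style path):
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
precision = BasePrecision()

loss = model(torch.randn(2, 4)).sum()
precision._run_backward(loss, model)  # equivalent to loss.backward() here
optimizer.step()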