From af2d044a424204a728b305bb1694f2c77bd27d49 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 18 Oct 2021 17:53:38 +0200
Subject: [PATCH 1/5] add run_backward

---
 pytorch_lightning/plugins/precision/deepspeed_precision.py | 3 +++
 pytorch_lightning/plugins/precision/native_amp.py          | 5 +++++
 pytorch_lightning/plugins/precision/precision_plugin.py    | 6 +++++-
 3 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 52e788956aaef..91ab8f17d1750 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -65,6 +65,9 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
         deepspeed_engine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
+    def run_backward(self, tensor, model, *args, **kwargs):
+        model.backward(tensor, *args, **kwargs)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 50c527f5f407d..4e8ade9a26e9f 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -15,6 +15,7 @@
 from typing import Any, Callable, Dict, Generator, Union
 
 import torch
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
@@ -68,6 +69,10 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
+    def run_backward(self, tensor, model, *args, **kwargs):
+        tensor = self.scaler.scale(tensor)
+        super().run_backward(tensor, model, *args, **kwargs)
+
     def pre_optimizer_step(
         self,
         model: "pl.LightningModule",
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index c81a474faad34..6e85a7f35a651 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -76,7 +76,7 @@ def backward(
         if model is not None and isinstance(model, pl.LightningModule):
             model.backward(closure_loss, optimizer, *args, **kwargs)
         else:
-            closure_loss.backward(*args, **kwargs)
+            self.run_backward(closure_loss, *args, **kwargs)
 
     def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
         """Run after precision plugin executes backward.
@@ -90,6 +90,10 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         model.trainer.call_hook("on_after_backward")
         return closure_loss
 
+    def run_backward(self, tensor, model, *args, **kwargs) -> None:
+        """Lightning-independent backward logic."""
+        tensor.backward(*args, **kwargs)
+
     def pre_optimizer_step(
         self,
         model: "pl.LightningModule",

From 112c819646648fd4f44f517d1f9b73895073c015 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 18 Oct 2021 18:42:14 +0200
Subject: [PATCH 2/5] add types

---
 pytorch_lightning/plugins/precision/deepspeed_precision.py | 2 +-
 pytorch_lightning/plugins/precision/native_amp.py          | 5 +++--
 pytorch_lightning/plugins/precision/precision_plugin.py    | 6 +++---
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 91ab8f17d1750..6954adcbef164 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -65,7 +65,7 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
         deepspeed_engine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
-    def run_backward(self, tensor, model, *args, **kwargs):
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
         model.backward(tensor, *args, **kwargs)
 
     def clip_gradients(
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 4e8ade9a26e9f..2fcb2c862f7fe 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -15,6 +15,7 @@
 from typing import Any, Callable, Dict, Generator, Union
 
 import torch
+from torch import Tensor
 from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 import pytorch_lightning as pl
@@ -69,9 +70,9 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
-    def run_backward(self, tensor, model, *args, **kwargs):
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
         tensor = self.scaler.scale(tensor)
-        super().run_backward(tensor, model, *args, **kwargs)
+        super()._run_backward(tensor, model, *args, **kwargs)
 
     def pre_optimizer_step(
         self,
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 6e85a7f35a651..76893b2cad80a 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -76,7 +76,7 @@ def backward(
         if model is not None and isinstance(model, pl.LightningModule):
             model.backward(closure_loss, optimizer, *args, **kwargs)
         else:
-            self.run_backward(closure_loss, *args, **kwargs)
+            self._run_backward(closure_loss, *args, **kwargs)
 
     def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
         """Run after precision plugin executes backward.
@@ -90,8 +90,8 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         model.trainer.call_hook("on_after_backward")
         return closure_loss
 
-    def run_backward(self, tensor, model, *args, **kwargs) -> None:
-        """Lightning-independent backward logic."""
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        """Lightning-independent backward logic. Currently only used by Lightning Lite. Subject to further refactors."""
         tensor.backward(*args, **kwargs)
 
     def pre_optimizer_step(

From d65454b3728de5da4abed09f973cfda03d1d6968 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 19 Oct 2021 01:44:16 +0000
Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/plugins/precision/precision_plugin.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 76893b2cad80a..9ec127886396c 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -91,7 +91,10 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         return closure_loss
 
     def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
-        """Lightning-independent backward logic. Currently only used by Lightning Lite. Subject to further refactors."""
+        """Lightning-independent backward logic.
+
+        Currently only used by Lightning Lite. Subject to further refactors.
+        """
         tensor.backward(*args, **kwargs)
 
     def pre_optimizer_step(

From 29c82275ecdb8490a3fca1692ffd590f136822ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 11:40:01 +0200
Subject: [PATCH 4/5] Update pytorch_lightning/plugins/precision/native_amp.py

Co-authored-by: thomas chaton
---
 pytorch_lightning/plugins/precision/native_amp.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 2fcb2c862f7fe..9ba2fba800a46 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -71,7 +71,8 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
         return super().pre_backward(model, closure_loss)
 
     def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
-        tensor = self.scaler.scale(tensor)
+        if not self.is_bfloat16:
+        tensor = self.scaler.scale(tensor)
         super()._run_backward(tensor, model, *args, **kwargs)
 
     def pre_optimizer_step(

From 06c69d77998541092661d5b69cb9a779a6f11392 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 11:41:07 +0200
Subject: [PATCH 5/5] fix indentation

---
 pytorch_lightning/plugins/precision/native_amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 9ba2fba800a46..e298569996274 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -72,7 +72,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
 
     def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
         if not self.is_bfloat16:
-        tensor = self.scaler.scale(tensor)
+            tensor = self.scaler.scale(tensor)
         super()._run_backward(tensor, model, *args, **kwargs)
 
     def pre_optimizer_step(
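
For context (not part of the patch series itself): the hooks added above give the precision plugins a Lightning-independent backward entry point, intended for Lightning Lite. Below is a minimal sketch of how calling code might route a backward pass through it; the toy linear model, data, and optimizer are illustrative assumptions, and only the `PrecisionPlugin._run_backward` hook and its import path come from the patches above.

# Sketch only, assuming a pytorch_lightning checkout that includes this series.
# The toy model, loss, and optimizer are made up for illustration.
import torch
from torch import nn
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
precision_plugin = PrecisionPlugin()

inputs = torch.randn(8, 4)
loss = model(inputs).pow(2).mean()

# Route backward through the plugin instead of calling loss.backward() directly,
# so subclasses (native AMP, DeepSpeed) can scale the loss or delegate to an engine.
precision_plugin._run_backward(loss, model)
optimizer.step()
optimizer.zero_grad()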