From d239dea46a6ddd8802edb7596f7b35159f62b1c3 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 31 Mar 2021 14:12:00 +0530 Subject: [PATCH 1/4] Update clip gradients signature for precision plugins --- pytorch_lightning/accelerators/accelerator.py | 2 +- .../plugins/precision/deepspeed_precision.py | 8 +++++++- pytorch_lightning/plugins/precision/precision_plugin.py | 8 +++++++- pytorch_lightning/plugins/precision/sharded_native_amp.py | 8 +++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 569af875e6c64..048f3365e1753 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -318,7 +318,7 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """clips all the optimizer parameters to the given value""" - self.precision_plugin.clip_gradients(optimizer, clip_val) + self.precision_plugin.clip_gradients(self.model, optimizer, clip_val) def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None: """Hook to do something on the end of an training epoch diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 6bcbb5ad851dc..6a8357229a6e6 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -75,7 +75,13 @@ def backward( return closure_loss - def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + def clip_gradients( + self, + model: 'LightningModule', + optimizer: 'Optimizer', + clip_val: Union[int, float], + norm_type: float = 2.0 + ) -> None: """ DeepSpeed handles clipping gradients via the training type plugin. 
""" diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 7172d82391bd3..19eb4e0cfb21b 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -98,7 +98,13 @@ def pre_optimizer_step( def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None: """Hook to do something after each optimizer step.""" - def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + def clip_gradients( + self, + model: 'LightningModule', + optimizer: 'Optimizer', + clip_val: Union[int, float], + norm_type: float = 2.0 + ) -> None: """Clips the gradients to a specific value""" if clip_val is None: return diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index 39dc01f97df11..23da73550d269 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -32,7 +32,13 @@ def __init__(self) -> None: super().__init__() self.scaler = ShardedGradScaler() - def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + def clip_gradients( + self, + model: 'LightningModule', + optimizer: 'Optimizer', + clip_val: Union[int, float], + norm_type: float = 2.0 + ) -> None: if clip_val <= 0: return From e3f04c63e45b6444b59cda061900cbc30049e4ec Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 31 Mar 2021 14:24:30 +0530 Subject: [PATCH 2/4] Add import for typing --- pytorch_lightning/plugins/precision/sharded_native_amp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index 23da73550d269..b9326b665c00d 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from torch.optim import Optimizer + from pytorch_lightning.core import LightningModule + class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): """Mixed Precision for Sharded Training From a4e5f1ad87c4cf2603a1d6daf26e85e1c26c8385 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 31 Mar 2021 14:39:39 +0530 Subject: [PATCH 3/4] Add test --- tests/plugins/test_precision_plugin.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/plugins/test_precision_plugin.py diff --git a/tests/plugins/test_precision_plugin.py b/tests/plugins/test_precision_plugin.py new file mode 100644 index 0000000000000..fc00f22a6413e --- /dev/null +++ b/tests/plugins/test_precision_plugin.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from inspect import signature + +from pytorch_lightning.plugins.precision import PrecisionPlugin + + +def test_precision_clip_gradients_signature(): + + expected_params_list = ['self', 'model', 'optimizer', 'clip_val', 'norm_type'] + + params = signature(PrecisionPlugin.clip_gradients).parameters + params_list = [param.name for param in params.values()] + + assert params_list == expected_params_list From 3a1a4b92a40bef38b1e4c57b22e1a51225b58b8a Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 31 Mar 2021 14:43:43 +0530 Subject: [PATCH 4/4] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14a3410a96bf3..81846809fbf85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) +- Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
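
Note (not part of the patch series): the hunks above change the `clip_gradients` hook so that the owning `LightningModule` is passed through from the accelerator to the precision plugin. A minimal sketch of how a downstream custom precision plugin might adopt the new signature is below; the subclass name `MyPrecisionPlugin` and its body are illustrative assumptions, only the signature itself comes from this patch.

    from typing import Union

    from torch.optim import Optimizer

    from pytorch_lightning.core import LightningModule
    from pytorch_lightning.plugins.precision import PrecisionPlugin


    class MyPrecisionPlugin(PrecisionPlugin):
        """Illustrative subclass overriding the hook with the updated signature."""

        def clip_gradients(
            self,
            model: LightningModule,
            optimizer: Optimizer,
            clip_val: Union[int, float],
            norm_type: float = 2.0,
        ) -> None:
            # `model` is now forwarded by Accelerator.clip_gradients, so a custom
            # plugin can inspect it (e.g. its parameters) before delegating to the
            # default clipping behaviour.
            super().clip_gradients(model, optimizer, clip_val, norm_type=norm_type)

Plugins that override `clip_gradients` with the old two-argument signature will need to add the leading `model` parameter to stay compatible with the accelerator call shown in the first hunk.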