
Commit 7f6154f

Authored by: dhkim0225, Borda, tchaton, ananthsub, carmocca
Add Trainer(gradient_clip_algorithm='value'|'norm') (#6123)
* add changelog
* add clip by value
* fix bug in training tricks.rst
* fix bug in trainer.rst
* Update trainer.rst
* Update trainer.rst
* Update CHANGELOG.md
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/plugins/precision/deepspeed_precision.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/utilities/enums.py
  Co-authored-by: Jirka Borovec <[email protected]>
* yapf formatting
* update training tricks
* update based on comment
* update based on comment
* Update pytorch_lightning/trainer/trainer.py
  Co-authored-by: ananthsub <[email protected]>
* update based on comment
* pep8
* mypy
* mypy
* Update docs/source/advanced/training_tricks.rst
  Co-authored-by: thomas chaton <[email protected]>
* Update sharded_native_amp.py
* Update test_sharded_parity.py
* update test codes
* Update test_tpu.py
* Update pytorch_lightning/trainer/connectors/training_trick_connector.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* Update test_trainer.py
* Update enums.py
* Update enums.py
* add super-class initialization to precision plugins.
* add clip_grad horovod cpu test
* add clip_grad horovod cpu test
* use subprocess check_call
* change order of horovod tests
* set max_epochs 2 in horovod test
* remove clip_grad_val test from horovod-cpu
* remove "type: ignore"
* divide clip grad val test in horovod
* update based on comments
* add super-class initialization to precision plugins.
* bugfix
* bugfix
* revert some changes
* revert some changes
* Update tests/models/test_horovod.py
* merge master
* Delete signature test (no point in testing a signature)

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent b7f3a3c commit 7f6154f
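
For orientation, a minimal usage sketch of the new option, based on the documentation change included below (model and training data are out of scope here):

    from pytorch_lightning import Trainer

    # default behaviour: rescale gradients so their total norm does not exceed 0.5
    trainer = Trainer(gradient_clip_val=0.5)  # gradient_clip_algorithm='norm'

    # new in this commit: clamp each gradient element into [-0.5, 0.5] instead
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')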

17 files changed (+222 lines, -49 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


+- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)).
+
+
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

docs/source/advanced/training_tricks.rst

Lines changed: 8 additions & 2 deletions
@@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN.

 Gradient Clipping
 -----------------
-Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient
-norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
+Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
+<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
+If the ``gradient_clip_algorithm`` option (``norm`` by default) is set to ``value``, this will
+`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.

 .. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -39,6 +41,10 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
     # clip gradients with norm above 0.5
     trainer = Trainer(gradient_clip_val=0.5)

+    # clip gradients with value above 0.5
+    # gradient_clip_algorithm types => :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType`
+    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')
+
 ----------

 Stochastic Weight Averaging
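
To make the difference between the two algorithms concrete, here is a small standalone sketch using the PyTorch utilities linked in the docs above; it does not go through Lightning at all:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    p.grad = torch.tensor([0.2, -1.5, 3.0])

    # 'norm' (the default): rescale all gradients together so their 2-norm is at most 0.5
    torch.nn.utils.clip_grad_norm_([p], max_norm=0.5)

    # 'value': clamp every gradient element into [-0.5, 0.5] independently
    p.grad = torch.tensor([0.2, -1.5, 3.0])
    torch.nn.utils.clip_grad_value_([p], clip_value=0.5)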

pytorch_lightning/accelerators/accelerator.py

Lines changed: 8 additions & 4 deletions
@@ -24,7 +24,7 @@
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.enums import AMPType, LightningEnum
+from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

 if TYPE_CHECKING:
     from torch.cuda.amp import GradScaler

@@ -315,10 +315,14 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
         model_ref = self.lightning_module
         model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float],
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
+    ) -> None:
         """clips all the optimizer parameters to the given value"""
-
-        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val)
+        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)

     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

     def __init__(self, amp_level: str = "O2") -> None:
+        super().__init__()
         self.backend = AMPType.APEX
         self.amp_level = amp_level

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 import torch

 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -80,7 +81,7 @@ def clip_gradients(
         model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
-        norm_type: float = 2.0
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """
         DeepSpeed handles clipping gradients via the training type plugin.

pytorch_lightning/plugins/precision/double.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ class DoublePrecisionPlugin(PrecisionPlugin):
     precision: int = 64

     def __init__(self) -> None:
+        super().__init__()
         self.patches: List[_DoublePrecisionPatch] = []

     def connect(

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):

     def __init__(self) -> None:
+        super().__init__()
         if not _NATIVE_AMP_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for native AMP but your PyTorch version does not support it."

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 23 additions & 6 deletions
@@ -17,6 +17,7 @@
 import torch

 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import GradClipAlgorithmType

 if TYPE_CHECKING:
     from torch.nn import Module

@@ -33,6 +34,13 @@ class PrecisionPlugin(Plugin):
     EPSILON: float = 1e-6
     precision: Union[str, int] = 32

+    def __init__(self) -> None:
+        super().__init__()
+        self.clip_grad_funcs = {
+            GradClipAlgorithmType.VALUE: self.clip_grad_by_value,
+            GradClipAlgorithmType.NORM: self.clip_grad_by_norm,
+        }
+
     def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
         """The master params of the model. Returns the plain model params here.
         Maybe different in other precision plugins.

@@ -103,20 +111,29 @@ def clip_gradients(
         model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
-        norm_type: float = 2.0
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        """Clips the gradients to a specific value"""
+        """Clips the gradients"""
         if clip_val is None:
             return

-        grad_clip_val = float(clip_val)
-
-        if grad_clip_val <= 0:
+        clip_val = float(clip_val)
+        if clip_val <= 0:
             return

+        clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm]
+        clip_grad_func(optimizer, clip_val)  # type: ignore
+
+    def clip_grad_by_value(self, optimizer: 'Optimizer', clip_val: Union[int, float]) -> None:
+        """Clip gradients by value"""
         parameters = list(self.master_params(optimizer))
+        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

-        max_norm = grad_clip_val
+    def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+        """Clip gradients by norm"""
+        # TODO: separate TPU case from here
+        parameters = list(self.master_params(optimizer))
+        max_norm = clip_val

         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]
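
The essential mechanism in this file is a small dictionary dispatch built in `__init__`, mapping each algorithm to a bound clipping method. Below is a self-contained sketch of the same pattern; the `ClipAlgo` enum and `Clipper` class are illustrative stand-ins, not Lightning classes:

    from enum import Enum
    from typing import Iterable, Optional

    import torch


    class ClipAlgo(str, Enum):  # stand-in for GradClipAlgorithmType
        VALUE = 'value'
        NORM = 'norm'


    class Clipper:

        def __init__(self) -> None:
            # algorithm -> bound clipping function, mirroring PrecisionPlugin.clip_grad_funcs
            self.clip_grad_funcs = {
                ClipAlgo.VALUE: self.clip_grad_by_value,
                ClipAlgo.NORM: self.clip_grad_by_norm,
            }

        def clip_gradients(self, params: Iterable[torch.Tensor], clip_val: Optional[float], algo: ClipAlgo = ClipAlgo.NORM) -> None:
            # same early exits as the plugin: None or non-positive values disable clipping
            if clip_val is None or float(clip_val) <= 0:
                return
            self.clip_grad_funcs[algo](params, float(clip_val))

        def clip_grad_by_value(self, params: Iterable[torch.Tensor], clip_val: float) -> None:
            torch.nn.utils.clip_grad_value_(params, clip_value=clip_val)

        def clip_grad_by_norm(self, params: Iterable[torch.Tensor], clip_val: float, norm_type: float = 2.0) -> None:
            torch.nn.utils.clip_grad_norm_(params, max_norm=clip_val, norm_type=norm_type)


    # usage: clamp each gradient element into [-0.5, 0.5]
    clipper = Clipper()
    p = torch.nn.Parameter(torch.ones(2))
    p.grad = torch.tensor([2.0, -3.0])
    clipper.clip_gradients([p], clip_val=0.5, algo=ClipAlgo.VALUE)  # p.grad -> [0.5, -0.5]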

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 1 addition & 7 deletions
@@ -23,8 +23,6 @@
 if TYPE_CHECKING:
     from torch.optim import Optimizer

-    from pytorch_lightning.core import LightningModule
-

 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Mixed Precision for Sharded Training

@@ -34,15 +32,11 @@ def __init__(self) -> None:
         super().__init__()
         self.scaler = ShardedGradScaler()

-    def clip_gradients(
+    def clip_grad_by_norm(
         self,
-        model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
         norm_type: float = 2.0
     ) -> None:
-        if clip_val <= 0:
-            return
-
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
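
Because the early-return checks and the algorithm dispatch now live in the base `PrecisionPlugin.clip_gradients`, a plugin that only needs special norm clipping (as this sharded plugin does via fairscale's `OSS.clip_grad_norm`) can override `clip_grad_by_norm` alone. A hedged sketch of that pattern with a made-up plugin name:

    # hypothetical subclass, not part of this changeset
    from typing import Union

    from torch.optim import Optimizer

    from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


    class MyNormOnlyPrecisionPlugin(PrecisionPlugin):

        def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
            # custom norm clipping goes here; 'value' clipping and the
            # clip_val None / <= 0 checks are inherited from PrecisionPlugin
            ...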

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,6 +24,7 @@ def __init__(self, trainer):
     def on_trainer_init(
         self,
         gradient_clip_val,
+        gradient_clip_algorithm,
         track_grad_norm,
         accumulate_grad_batches,
         truncated_bptt_steps,

@@ -32,7 +34,12 @@ def on_trainer_init(
         self.trainer.terminate_on_nan = terminate_on_nan

         # gradient clipping
+        if gradient_clip_algorithm not in list(GradClipAlgorithmType):
+            raise MisconfigurationException(
+                f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
+            )
         self.trainer.gradient_clip_val = gradient_clip_val
+        self.trainer.gradient_clip_algorithm = gradient_clip_algorithm

         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
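
The membership check above accepts a plain string like 'value' because GradClipAlgorithmType is a string-backed enum; that detail lives in enums.py, which is not shown in this excerpt, so treat it as an assumption here. A standalone sketch of the same validation idea:

    from enum import Enum


    class GradClipAlgo(str, Enum):  # illustrative stand-in for GradClipAlgorithmType
        VALUE = 'value'
        NORM = 'norm'


    def validate(algorithm: str) -> None:
        # str-backed enum members compare equal to their raw string values,
        # so `'norm' in list(GradClipAlgo)` evaluates to True
        if algorithm not in list(GradClipAlgo):
            raise ValueError(f"gradient_clip_algorithm should be in {list(GradClipAlgo)}")


    validate('value')   # passes silently
    # validate('clamp') would raise ValueError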
