
Commit b6de017

SeanNaren authored and tchaton committed
[Fix] Call clip gradients if clip val greater than 0 (#6330)
* Call clip gradients if clip val greater than 0
* format
* Format
* Move to top of file
1 parent c8f4b1e commit b6de017

File tree

3 files changed: +27 −1 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Resolve memory leak for evaluation ([#6326](https://github.com/PyTorchLightning/pytorch-lightning/pull/6326)
 
 
+- Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330)
+
+
 ## [1.2.2] - 2021-03-02
 
 ### Added

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 4 additions & 1 deletion
@@ -31,6 +31,9 @@ def __init__(self):
         super().__init__()
         self.scaler = ShardedGradScaler()
 
-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+        if clip_val <= 0:
+            return
+
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
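
For readers who want the patched method in context, here is a minimal sketch of how the plugin reads after this change. Only the clip_gradients body is taken verbatim from the diff above; the class name, base class, and imports are reconstructed from the module path and are assumptions rather than an exact copy of the file.

# Sketch of the plugin after this commit; class/base/import names are assumed
# from the module path, only clip_gradients is taken from the diff itself.
from typing import Union, cast

from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin


class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
    """Mixed-precision plugin for sharded (fairscale OSS) training."""

    def __init__(self) -> None:
        super().__init__()
        self.scaler = ShardedGradScaler()

    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        # Guard added in this commit: skip clipping when gradient_clip_val is 0 (the default).
        if clip_val <= 0:
            return

        # Sharded training wraps the optimizer in fairscale's OSS, which provides
        # its own distributed-aware clip_grad_norm.
        optimizer = cast(OSS, optimizer)
        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)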

tests/plugins/test_sharded_plugin.py

Lines changed: 20 additions & 0 deletions
@@ -1,5 +1,6 @@
 import os
 import platform
+from unittest import mock
 
 import pytest
 import torch
@@ -12,6 +13,25 @@
 from tests.helpers.boring_model import BoringModel
 
 
+@pytest.mark.parametrize("clip_val", [0, 10])
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm')
+def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    model = BoringModel()
+    trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val)
+    trainer.fit(model)
+    if clip_val > 0:
+        mock_oss_clip_grad_norm.assert_called()
+    else:
+        mock_oss_clip_grad_norm.assert_not_called()
+
+
+@RunIf(fairscale=True)
 @pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_sharded_ddp_choice(tmpdir, accelerator):
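
At the user level, the guard means a Trainer left at the default gradient_clip_val=0 never routes through OSS.clip_grad_norm under the sharded native-AMP plugin, while any positive value still clips. The sketch below mirrors the Trainer arguments from the new test above; it is illustrative only and assumes a machine that satisfies the test's skip conditions (one CUDA GPU, fairscale, and native AMP available).

# Illustrative only: reuses the Trainer arguments from the test added in this commit.
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

# clip_val <= 0: the plugin's clip_gradients() returns early, so OSS.clip_grad_norm is never called.
trainer_no_clip = Trainer(
    accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=0
)
trainer_no_clip.fit(BoringModel())

# clip_val > 0: gradients are clipped through fairscale's OSS.clip_grad_norm as before.
trainer_clip = Trainer(
    accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=10
)
trainer_clip.fit(BoringModel())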

0 commit comments
