
Commit ae6d79a

Author: Sean Naren
Merge 969c25e into b9cf122
2 parents: b9cf122 + 969c25e

File tree

2 files changed: +20 -0 lines changed


pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 3 additions & 0 deletions

@@ -33,5 +33,8 @@ def __init__(self) -> None:
         self.scaler = ShardedGradScaler()
 
     def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+        if clip_val <= 0:
+            return
+
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
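In effect, the new guard turns gradient clipping into a no-op whenever the configured value is zero or negative, and only positive values are forwarded to fairscale's OSS.clip_grad_norm. A minimal standalone sketch of the same pattern (the function name is illustrative, not part of the change):

    def clip_gradients_if_positive(optimizer, clip_val, norm_type=2.0):
        # Non-positive values mean "clipping disabled": return without touching gradients.
        if clip_val <= 0:
            return
        # Positive values delegate to the optimizer's own sharded-aware clipping,
        # as the plugin does with fairscale's OSS.clip_grad_norm.
        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)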

tests/plugins/test_sharded_plugin.py

Lines changed: 17 additions & 0 deletions

@@ -1,4 +1,5 @@
 import os
+from unittest import mock
 
 import pytest
 import torch
@@ -11,6 +12,22 @@
 from tests.helpers.runif import RunIf
 
 
+@pytest.mark.parametrize("clip_val", [0, 10])
+@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
+@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm')
+def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    model = BoringModel()
+    trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val)
+    trainer.fit(model)
+    if clip_val > 0:
+        mock_oss_clip_grad_norm.assert_called()
+    else:
+        mock_oss_clip_grad_norm.assert_not_called()
+
+
 @RunIf(fairscale=True)
 @pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 def test_sharded_ddp_choice(tmpdir, accelerator):
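The new test verifies the guard by patching fairscale's OSS.clip_grad_norm and asserting whether it was reached. The same mock-and-assert pattern can be illustrated without a GPU or fairscale installed; everything below (FakeOptimizer, clip_gradients_if_positive) is a hypothetical stand-in, not Lightning or fairscale code:

    from unittest import mock

    class FakeOptimizer:
        def clip_grad_norm(self, clip_val, norm_type=2.0):
            pass

    def clip_gradients_if_positive(optimizer, clip_val, norm_type=2.0):
        # Same guard as the plugin change: skip clipping for non-positive values.
        if clip_val <= 0:
            return
        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)

    opt = FakeOptimizer()
    with mock.patch.object(opt, 'clip_grad_norm') as mocked:
        clip_gradients_if_positive(opt, 0)
        mocked.assert_not_called()  # clip_val == 0: clipping must be skipped
        clip_gradients_if_positive(opt, 10)
        mocked.assert_called_once_with(10, norm_type=2.0)  # clip_val > 0: clipping delegated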
