Commit a99b744

awaelchli and carmocca authored
Add unit tests for pl.utilities.grads (#9765)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 4dc32ad commit a99b744

File tree

5 files changed (+94 lines, -5 lines)


CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -297,12 +297,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Update the logic to check for accumulation steps with deepspeed ([#9826](https://github.com/PyTorchLightning/pytorch-lightning/pull/9826))


+- `pytorch_lightning.utilities.grads.grad_norm` now raises an exception if parameter `norm_type <= 0` ([#9765](https://github.com/PyTorchLightning/pytorch-lightning/pull/9765))
+
+
+
 - Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896))


 - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))

+
 ### Deprecated

 - Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 4 additions & 2 deletions

@@ -49,9 +49,11 @@ def on_trainer_init(
             )

         # gradient norm tracking
-        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != "inf":
+        if track_grad_norm != -1 and not (
+            (isinstance(track_grad_norm, (int, float)) or track_grad_norm == "inf") and float(track_grad_norm) > 0
+        ):
             raise MisconfigurationException(
-                f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}."
+                f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
             )

         self.trainer._terminate_on_nan = terminate_on_nan
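
In effect, the new check accepts -1 (the Trainer default, tracking disabled), any positive int or float, and the string "inf"; zero, negative values, and non-numeric values such as tensors now raise a MisconfigurationException. A minimal sketch of the behaviour (illustrative usage only, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    Trainer(track_grad_norm=2)       # positive int: track the 2-norm of the gradients
    Trainer(track_grad_norm=3.14)    # positive float: any p-norm is accepted
    Trainer(track_grad_norm="inf")   # infinity norm
    Trainer(track_grad_norm=-1)      # default value, tracking disabled

    try:
        Trainer(track_grad_norm=0)   # rejected by the new check
    except MisconfigurationException as err:
        # `track_grad_norm` must be a positive number or 'inf' (infinity norm). Got 0.
        print(err)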

pytorch_lightning/utilities/grads.py

Lines changed: 3 additions & 0 deletions

@@ -35,6 +35,9 @@ def grad_norm(module: Module, norm_type: Union[float, int, str]) -> Dict[str, float]:
         as a single vector.
     """
     norm_type = float(norm_type)
+    if norm_type <= 0:
+        raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
+
     norms = {
         f"grad_{norm_type}_norm_{name}": p.grad.data.norm(norm_type).item()
         for name, p in module.named_parameters()
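
For context, `grad_norm` returns one entry per parameter plus a total, keyed by the norm type and parameter name; a short usage sketch with made-up gradient values (not part of this commit):

    import torch
    import torch.nn as nn

    from pytorch_lightning.utilities import grad_norm

    # toy module with hand-set gradients, just to show the returned keys
    layer = nn.Linear(2, 1)
    layer.weight.grad = torch.tensor([[3.0, 4.0]])
    layer.bias.grad = torch.tensor([0.0])

    print(grad_norm(layer, norm_type=2))
    # e.g. {'grad_2.0_norm_weight': 5.0, 'grad_2.0_norm_bias': 0.0, 'grad_2.0_norm_total': 5.0}

    grad_norm(layer, norm_type=0)  # with this commit, raises ValueError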

tests/trainer/test_trainer.py

Lines changed: 4 additions & 3 deletions

@@ -911,9 +911,10 @@ def test_invalid_terminate_on_nan(tmpdir):
         Trainer(default_root_dir=tmpdir, terminate_on_nan="False")


-def test_invalid_track_grad_norm(tmpdir):
-    with pytest.raises(MisconfigurationException, match="`track_grad_norm` should be an int, a float"):
-        Trainer(default_root_dir=tmpdir, track_grad_norm="nan")
+@pytest.mark.parametrize("track_grad_norm", [0, torch.tensor(1), "nan"])
+def test_invalid_track_grad_norm(tmpdir, track_grad_norm):
+    with pytest.raises(MisconfigurationException, match="`track_grad_norm` must be a positive number or 'inf'"):
+        Trainer(default_root_dir=tmpdir, track_grad_norm=track_grad_norm)


 @mock.patch("torch.Tensor.backward")

tests/utilities/test_grads.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+import pytest
+import torch
+import torch.nn as nn
+
+from pytorch_lightning.utilities import grad_norm
+
+
+@pytest.mark.parametrize(
+    "norm_type,expected",
+    [
+        (
+            1,
+            {"grad_1.0_norm_param0": 1 + 2 + 3, "grad_1.0_norm_param1": 4 + 5, "grad_1.0_norm_total": 15},
+        ),
+        (
+            2,
+            {
+                "grad_2.0_norm_param0": pow(1 + 4 + 9, 0.5),
+                "grad_2.0_norm_param1": pow(16 + 25, 0.5),
+                "grad_2.0_norm_total": pow(1 + 4 + 9 + 16 + 25, 0.5),
+            },
+        ),
+        (
+            3.14,
+            {
+                "grad_3.14_norm_param0": pow(1 + 2 ** 3.14 + 3 ** 3.14, 1 / 3.14),
+                "grad_3.14_norm_param1": pow(4 ** 3.14 + 5 ** 3.14, 1 / 3.14),
+                "grad_3.14_norm_total": pow(1 + 2 ** 3.14 + 3 ** 3.14 + 4 ** 3.14 + 5 ** 3.14, 1 / 3.14),
+            },
+        ),
+        (
+            "inf",
+            {
+                "grad_inf_norm_param0": max(1, 2, 3),
+                "grad_inf_norm_param1": max(4, 5),
+                "grad_inf_norm_total": max(1, 2, 3, 4, 5),
+            },
+        ),
+    ],
+)
+def test_grad_norm(norm_type, expected):
+    """Test utility function for computing the p-norm of individual parameter groups and norm in total."""
+
+    class Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.param0 = nn.Parameter(torch.rand(3))
+            self.param1 = nn.Parameter(torch.rand(2, 1))
+            self.param0.grad = torch.tensor([-1.0, 2.0, -3.0])
+            self.param1.grad = torch.tensor([[-4.0], [5.0]])
+            # param without grad should not contribute to norm
+            self.param2 = nn.Parameter(torch.rand(1))
+
+    model = Model()
+    norms = grad_norm(model, norm_type)
+    expected = {k: round(v, 4) for k, v in expected.items()}
+    assert norms == expected
+
+
+@pytest.mark.parametrize("norm_type", [-1, 0])
+def test_grad_norm_invalid_norm_type(norm_type):
+    with pytest.raises(ValueError, match="`norm_type` must be a positive number or 'inf'"):
+        grad_norm(Mock(), norm_type)
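
For reference, the expected values in the parametrization above are the plain p-norm of the fixed gradients, ||g||_p = (sum_i |g_i|^p)^(1/p), with the infinity norm being max_i |g_i|; the total entry is the norm over all gradient entries concatenated. A quick standalone check of the 2-norm and infinity-norm cases:

    import torch

    # param0 and param1 gradients from the test, flattened into one vector
    grads = torch.tensor([-1.0, 2.0, -3.0, -4.0, 5.0])

    print(grads.norm(2).item())             # sqrt(1 + 4 + 9 + 16 + 25) ≈ 7.4162
    print(grads.norm(float("inf")).item())  # max(|g_i|) = 5.0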
