Skip to content

Commit d072e44

Browse files
authored
Fix dtype inference during gradient norm computation (#14051)
1 parent b4ade23 commit d072e44

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6767
- Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992))
6868

6969

70+
- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))
71+
72+
7073
## [1.7.0] - 2022-08-02
7174

7275
### Added

src/pytorch_lightning/utilities/grads.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator
4141
raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
4242

4343
norms = {
44-
f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type).item()
44+
f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type)
4545
for name, p in module.named_parameters()
4646
if p.grad is not None
4747
}
4848
if norms:
49-
total_norm = torch.tensor(list(norms.values())).norm(norm_type).item()
49+
total_norm = torch.tensor(list(norms.values())).norm(norm_type)
5050
norms[f"grad_{norm_type}_norm_total"] = total_norm
51-
norms = {k: round(v, 4) for k, v in norms.items()}
51+
norms = {k: round(v.item(), 4) for k, v in norms.items()}
5252
return norms

tests/tests_pytorch/utilities/test_grads.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,17 @@ def __init__(self):
7676
def test_grad_norm_invalid_norm_type(norm_type):
7777
with pytest.raises(ValueError, match="`norm_type` must be a positive number or 'inf'"):
7878
grad_norm(Mock(), norm_type)
79+
80+
81+
def test_grad_norm_with_double_dtype():
82+
class Model(nn.Module):
83+
def __init__(self):
84+
super().__init__()
85+
dtype = torch.double
86+
self.param = nn.Parameter(torch.tensor(1.0, dtype=dtype))
87+
# grad norm of this would become infinite
88+
self.param.grad = torch.tensor(1e23, dtype=dtype)
89+
90+
model = Model()
91+
norms = grad_norm(model, 2)
92+
assert all(torch.isfinite(torch.tensor(v)) for v in norms.values()), norms

0 commit comments

Comments (0)