Skip to content

Commit 4913cbb

Browse files
Fix metric state reset (#5273)
* Fix metric state reset * Fix test * Improve formatting Co-authored-by: Ananya Harsh Jha <[email protected]>
1 parent dabfeca commit 4913cbb

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

pytorch_lightning/metrics/metric.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def add_state(
9494
reset to this value when ``self.reset()`` is called.
9595
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
9696
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
97-
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
97+
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
98+
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
9899
function in this parameter.
99100
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
100101
Default is ``False``.
@@ -244,7 +245,7 @@ def reset(self):
244245
"""
245246
for attr, default in self._defaults.items():
246247
current_val = getattr(self, attr)
247-
if isinstance(current_val, torch.Tensor):
248+
if isinstance(default, torch.Tensor):
248249
setattr(self, attr, deepcopy(default).to(current_val.device))
249250
else:
250251
setattr(self, attr, deepcopy(default))

tests/metrics/test_metric.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ def compute(self):
2626
pass
2727

2828

29+
class DummyList(Metric):
30+
name = "DummyList"
31+
32+
def __init__(self):
33+
super().__init__()
34+
self.add_state("x", list(), dist_reduce_fx=None)
35+
36+
def update(self):
37+
pass
38+
39+
def compute(self):
40+
pass
41+
42+
2943
def test_inherit():
3044
a = Dummy()
3145

@@ -77,12 +91,21 @@ def test_reset():
7791
class A(Dummy):
7892
pass
7993

94+
class B(DummyList):
95+
pass
96+
8097
a = A()
8198
assert a.x == 0
8299
a.x = torch.tensor(5)
83100
a.reset()
84101
assert a.x == 0
85102

103+
b = B()
104+
assert isinstance(b.x, list) and len(b.x) == 0
105+
b.x = torch.tensor(5)
106+
b.reset()
107+
assert isinstance(b.x, list) and len(b.x) == 0
108+
86109

87110
def test_update():
88111
class A(Dummy):

0 commit comments

Comments
 (0)