Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def add_state(
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
Default is ``False``.
Expand Down Expand Up @@ -244,7 +245,7 @@ def reset(self):
"""
for attr, default in self._defaults.items():
current_val = getattr(self, attr)
if isinstance(current_val, torch.Tensor):
if isinstance(default, torch.Tensor):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))
Expand Down
23 changes: 23 additions & 0 deletions tests/metrics/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ def compute(self):
pass


class DummyList(Metric):
name = "DummyList"

def __init__(self):
super().__init__()
self.add_state("x", list(), dist_reduce_fx=None)

def update(self):
pass

def compute(self):
pass


def test_inherit():
a = Dummy()

Expand Down Expand Up @@ -77,12 +91,21 @@ def test_reset():
class A(Dummy):
pass

class B(DummyList):
pass

a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.reset()
assert a.x == 0

b = B()
assert isinstance(b.x, list) and len(b.x) == 0
b.x = torch.tensor(5)
b.reset()
assert isinstance(b.x, list) and len(b.x) == 0


def test_update():
class A(Dummy):
Expand Down