From 7984cd8686c71f7472a80a9136c9c08f20bb9d57 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 26 Dec 2020 23:17:06 +0100 Subject: [PATCH 1/3] Fix metric state reset --- pytorch_lightning/metrics/metric.py | 5 +++-- tests/metrics/test_metric.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 0f61b94c55139..a21242c3bdc7e 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -94,7 +94,8 @@ def add_state( reset to this value when ``self.reset()`` is called. dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom + and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction + only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter. persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. Default is ``False``. @@ -244,7 +245,7 @@ def reset(self): """ for attr, default in self._defaults.items(): current_val = getattr(self, attr) - if isinstance(current_val, torch.Tensor): + if isinstance(default, torch.Tensor): setattr(self, attr, deepcopy(default).to(current_val.device)) else: setattr(self, attr, deepcopy(default)) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index d97cd1a176cf2..aaaaa8c5da2cc 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -25,6 +25,18 @@ def update(self): def compute(self): pass +class DummyList(Metric): + name = "DummyList" + + def __init__(self): + super().__init__() + self.add_state("x", list(), dist_reduce_fx=None) + + def update(self): + pass + + def compute(self): + pass def test_inherit(): a = Dummy() @@ -77,12 +89,20 @@ def test_reset(): class A(Dummy): pass + class B(DummyList): + pass + a = A() assert a.x == 0 a.x = torch.tensor(5) a.reset() assert a.x == 0 + b = B() + assert isinstance(b.x, list) and len(b.x) == 0 + b.x.append(torch.tensor(5)) + b.reset() + assert isinstance(b.x, list) and len(b.x) == 0 def test_update(): class A(Dummy): From 8d809373a44018497a203ed57610a6c76f0e38d7 Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 26 Dec 2020 23:19:54 +0100 Subject: [PATCH 2/3] Fix test --- tests/metrics/test_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index aaaaa8c5da2cc..27e89df51f0f0 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -100,7 +100,7 @@ class B(DummyList): b = B() assert isinstance(b.x, list) and len(b.x) == 0 - b.x.append(torch.tensor(5)) + b.x = torch.tensor(5) b.reset() assert isinstance(b.x, list) and len(b.x) == 0 From 473dfde3c52a1f6b953ea510cb294bc6091d740e Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 26 Dec 2020 23:24:51 +0100 Subject: [PATCH 3/3] Improve formatting --- tests/metrics/test_metric.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 27e89df51f0f0..67e85624379a5 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -25,6 +25,7 @@ def update(self): def compute(self): pass + class DummyList(Metric): name = "DummyList" @@ -38,6 +39,7 @@ def update(self): def compute(self): pass + def test_inherit(): a = Dummy() @@ -91,7 +93,7 @@ class A(Dummy): class B(DummyList): pass - + a = A() assert a.x == 0 a.x = torch.tensor(5) @@ -104,6 +106,7 @@ class B(DummyList): b.reset() assert isinstance(b.x, list) and len(b.x) == 0 + def test_update(): class A(Dummy): def update(self, x):