Skip to content

Commit 4993c38

Browse files
authored
Merge branch 'master' into tests/examples
2 parents bc9d37b + 465ec75 commit 4993c38

File tree

4 files changed

+63
-15
lines changed

4 files changed

+63
-15
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3333
- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
3434

3535

36+
- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
37+
38+
3639
### Changed
3740

3841
- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
@@ -49,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4952

5053
- Fixed feature-lack in hpc load ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))
5154

55+
56+
- Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
57+
58+
5259
## [1.0.5] - 2020-11-03
5360

5461
### Added

docs/source/metrics.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
131131
It is highly recommended to re-initialize the metric per mode as
132132
shown in the examples above.
133133

134+
.. note::
135+
136+
By default, metric states add their internal state to the model's ``state_dict``.
137+
To change this after initializing the metric, the method ``.persistent(mode)`` can
138+
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
139+
134140
*********************
135141
Implementing a Metric
136142
*********************

pytorch_lightning/metrics/metric.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ def __init__(
8181
self._forward_cache = None
8282

8383
# initialize state
84-
self._reductions = {}
8584
self._defaults = {}
85+
self._persistent = {}
86+
self._reductions = {}
8687

8788
def add_state(
8889
self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
@@ -138,16 +139,10 @@ def add_state(
138139
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
139140
)
140141

141-
if isinstance(default, torch.Tensor):
142-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
143-
# persistent keyword is only supported in torch >= 1.6.0
144-
self.register_buffer(name, default, persistent=persistent)
145-
else:
146-
self.register_buffer(name, default)
147-
else:
148-
setattr(self, name, default)
142+
setattr(self, name, default)
149143

150144
self._defaults[name] = deepcopy(default)
145+
self._persistent[name] = persistent
151146
self._reductions[name] = dist_reduce_fx
152147

153148
@torch.jit.unused
@@ -265,3 +260,36 @@ def __setstate__(self, state):
265260
self.__dict__.update(state)
266261
self.update = self._wrap_update(self.update)
267262
self.compute = self._wrap_compute(self.compute)
263+
264+
def _apply(self, fn):
265+
""" Override the _apply method so that metric states are also moved
266+
to the correct device when `.to`, `.cuda`, etc methods are called
267+
"""
268+
self = super()._apply(fn)
269+
# Also apply fn to metric states
270+
for key in self._defaults.keys():
271+
current_val = getattr(self, key)
272+
if isinstance(current_val, torch.Tensor):
273+
setattr(self, key, fn(current_val))
274+
elif isinstance(current_val, Sequence):
275+
setattr(self, key, [fn(cur_v) for cur_v in current_val])
276+
else:
277+
raise TypeError('Expected metric state to be either a torch.Tensor '
278+
f'or a list of torch.Tensor, but encountered {current_val}')
279+
return self
280+
281+
def persistent(self, mode: bool = True):
282+
""" Method for post-init to change whether metric states should be saved to
283+
its state_dict
284+
"""
285+
for key in self._persistent.keys():
286+
self._persistent[key] = mode
287+
288+
def state_dict(self, *args, **kwargs):
289+
# Register metric states to be part of the state_dict
290+
state_dict = super().state_dict()
291+
for key in self._defaults.keys():
292+
if self._persistent[key]:
293+
current_val = getattr(self, key)
294+
state_dict.update({key: current_val})
295+
return state_dict

tests/metrics/test_metric_lightning.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
import os
1+
import pytest
22

33
import torch
4+
45
from pytorch_lightning import Trainer
56
from pytorch_lightning.metrics import Metric
67
from tests.base.boring_model import BoringModel
8+
import tests.base.develop_utils as tutils
79

810

911
class SumMetric(Metric):
@@ -54,15 +56,19 @@ def test_metric_lightning_log(tmpdir):
5456
class TestModel(BoringModel):
5557
def __init__(self):
5658
super().__init__()
57-
self.metric = SumMetric()
59+
self.metric_step = SumMetric()
60+
self.metric_epoch = SumMetric()
5861
self.sum = 0.0
5962

6063
def training_step(self, batch, batch_idx):
6164
x = batch
62-
self.metric(x.sum())
65+
self.metric_step(x.sum())
6366
self.sum += x.sum()
64-
self.log("sum", self.metric, on_epoch=True, on_step=False)
65-
return self.step(x)
67+
self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
68+
return {'loss': self.step(x), 'data': x}
69+
70+
def training_epoch_end(self, outs):
71+
self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
6672

6773
model = TestModel()
6874
model.val_dataloader = None
@@ -78,7 +84,8 @@ def training_step(self, batch, batch_idx):
7884
trainer.fit(model)
7985

8086
logged = trainer.logged_metrics
81-
assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
87+
assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
88+
assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
8289

8390

8491
def test_scriptable(tmpdir):

0 commit comments

Comments
 (0)