Commit cf13cec

manipopopo authored and Borda committed
Fix Metric.state_dict (#5614)
* Fix Metric.state_dict
* Update CHANGELOG.md
* Update CHANGELOG.md
* Detach tensors in a list if needed

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
(cherry picked from commit e87424a)
1 parent 3f2dcf5 commit cf13cec
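
For context (an aside, not part of the commit): `nn.Module.state_dict` collects a module tree's state by recursing into each child and passing along a shared `destination` dict plus a dotted `prefix`. The previous `Metric.state_dict(*args, **kwargs)` ignored both and built a fresh dict, so metric states were silently dropped whenever a `Metric` lived inside another module. A minimal sketch of that failure mode in plain PyTorch (`Child` and `Parent` are illustrative names, not library code):

```python
import torch
from torch import nn


class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.tensor(0))

    # Broken override in the spirit of the pre-fix Metric.state_dict:
    # it ignores the `destination` and `prefix` the parent passes in,
    # so nothing is written into the parent's shared dict.
    def state_dict(self, *args, **kwargs):
        return super().state_dict()


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()


print(Parent().state_dict())  # OrderedDict() -- 'child.count' is missing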

File tree: 3 files changed, +39 / -4 lines

CHANGELOG.md (3 additions, 0 deletions)

```diff
@@ -170,6 +170,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
 
 
+- Fixed `Metric`'s `state_dict` not included when child modules ([#5614](https://github.com/PyTorchLightning/pytorch-lightning/pull/5614))
+
+
 - Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
 
 
```

pytorch_lightning/metrics/metric.py (16 additions, 4 deletions)

```diff
@@ -298,14 +298,26 @@ def persistent(self, mode: bool = False):
         for key in self._persistent.keys():
             self._persistent[key] = mode
 
-    def state_dict(self, *args, **kwargs):
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        destination = super().state_dict(
+            destination=destination,
+            prefix=prefix,
+            keep_vars=keep_vars
+        )
         # Register metric states to be part of the state_dict
-        state_dict = super().state_dict()
         for key in self._defaults.keys():
             if self._persistent[key]:
                 current_val = getattr(self, key)
-                state_dict.update({key: current_val})
-        return state_dict
+                if not keep_vars:
+                    if torch.is_tensor(current_val):
+                        current_val = current_val.detach()
+                    elif isinstance(current_val, list):
+                        current_val = [
+                            cur_v.detach() if torch.is_tensor(cur_v) else cur_v
+                            for cur_v in current_val
+                        ]
+                destination[prefix + key] = current_val
+        return destination
 
     def _filter_kwargs(self, **kwargs):
         """ filter kwargs such that they match the update signature of the metric """
```

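The essence of the patch: the override now mirrors the `nn.Module.state_dict(destination, prefix, keep_vars)` contract, writing each state into the caller-supplied `destination` under `prefix` and detaching tensors (including tensors inside list states) unless `keep_vars=True`, the same policy `nn.Module` applies to parameters and buffers. A standalone sketch of the pattern, assuming a hypothetical `MetricLike` class rather than the library code:

```python
import torch
from torch import nn


class MetricLike(nn.Module):
    def __init__(self):
        super().__init__()
        # Extra state tracked outside parameters/buffers, akin to Metric.add_state.
        self.total = torch.zeros(1, requires_grad=True)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Let nn.Module fill in parameters/buffers (and create the dict if needed).
        destination = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        val = self.total
        if not keep_vars and torch.is_tensor(val):
            val = val.detach()  # drop autograd history, as the patch does
        destination[prefix + 'total'] = val
        return destination


m = MetricLike()
assert not m.state_dict()['total'].requires_grad            # detached by default
assert m.state_dict(keep_vars=True)['total'].requires_grad  # graph kept on request
```

Matching the parent-class signature is what makes recursion work: when a parent module walks its children it passes `destination` and an extended `prefix`, and this override now honors both instead of returning a private dict.
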
tests/metrics/test_metric.py (20 additions, 0 deletions)

```diff
@@ -6,6 +6,7 @@
 import numpy as np
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning.metrics.metric import Metric, MetricCollection
 
@@ -211,6 +212,25 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict()
 
 
+def test_child_metric_state_dict():
+    """ test that child metric states will be added to parent state dict """
+    class TestModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.metric = Dummy()
+            self.metric.add_state('a', torch.tensor(0), persistent=True)
+            self.metric.add_state('b', [], persistent=True)
+            self.metric.register_buffer('c', torch.tensor(0))
+
+    module = TestModule()
+    expected_state_dict = {
+        'metric.a': torch.tensor(0),
+        'metric.b': [],
+        'metric.c': torch.tensor(0)
+    }
+    assert module.state_dict() == expected_state_dict
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
 def test_device_and_dtype_transfer(tmpdir):
     metric = DummyMetric1()
```
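
The new test also implies a convenient round trip: because child metric states now land in the parent's `state_dict`, a plain `torch.save`/`torch.load` of that dict preserves them. A hypothetical usage sketch, reusing the `TestModule` defined in the diff above:

```python
import torch

module = TestModule()
torch.save(module.state_dict(), "module.ckpt")  # child metric states are persisted
restored = torch.load("module.ckpt")
assert set(restored) == {"metric.a", "metric.b", "metric.c"}
```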
