Commit 02acb21

SkafteNicki, teddykoker and s-rog authored
[Metrics] Disable default reset after compute (#5409)
* reset
* self._cache -> cache (make cache local variable so it is not overwritten)
* pep8
* fix metric result integration
* rm print statements
* better comment
* changelog
* Update docs/source/metrics.rst

Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
1 parent 29bcf30 commit 02acb21

File tree: 6 files changed (+47, -26 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


+- Metric `compute()` method will no longer automatically call `reset()` ([#5409](https://github.com/PyTorchLightning/pytorch-lightning/pull/5409/))
+
+
 - Set PyTorch 1.4 as min requirements, also for testing and examples `torchvision>=0.5` and `torchtext>=0.5` ([#5418](https://github.com/PyTorchLightning/pytorch-lightning/pull/5418))


docs/source/metrics.rst

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,11 @@ The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to t
 serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
 provided input.

+.. warning::
+    From v1.2 onward ``compute()`` will no longer automatically call ``reset()``,
+    and it is up to the user to reset metrics between epochs, except in the case where the
+    metric is directly passed to ``LightningModule``s ``self.log``.
+
 These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
 distributed mode, the internal state of each metric is synced and reduced across each process, so that the
 logic present in ``.compute()`` is applied to state information from all processes.

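The practical consequence of the new warning: once ``compute()`` stops resetting, a metric keeps accumulating state until ``reset()`` is called explicitly. Below is a minimal sketch of the manual pattern in a plain training loop; ``Accuracy`` is the real metric class, but ``num_epochs`` and ``loader`` are illustrative placeholders, not anything defined in this commit.

import torch
from pytorch_lightning.metrics import Accuracy

metric = Accuracy()

for epoch in range(num_epochs):            # num_epochs / loader are assumed to exist
    for preds, target in loader:
        metric.update(preds, target)       # accumulate state over the whole epoch
    acc = metric.compute()                 # aggregate over every batch seen this epoch
    print(f"epoch {epoch}: accuracy={acc.item():.4f}")
    metric.reset()                         # from v1.2 onward this is the caller's job
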

pytorch_lightning/core/step_result.py

Lines changed: 11 additions & 3 deletions
@@ -306,7 +306,6 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
         Gets the metrics to log at the end of epoch
         """
         result = {}
-
         meta = self['meta']
         for k, options in meta.items():
             if k == '_internal':
@@ -320,12 +319,16 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
             if options['logger'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
+                    self[k].reset()
                 else:
                     result[dl_key] = self[k]

             if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
-                # compute metric on epoch anyway so state does not accumulate
+                # reset metric anyway so state does not accumulate
+                # NOTE: we must compute before reseting just in case the computed value is needed
+                # later (i.e. if the step metric gets visited first, and then the epoch metric)
                 self[k].compute()
+                self[k].reset()

         return result

@@ -348,12 +351,16 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
             if options['prog_bar'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
+                    self[k].reset()
                 else:
                     result[dl_key] = self[k]

             if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
-                # compute metric on epoch anyway so state does not accumulate
+                # reset metric anyway so state does not accumulate
+                # NOTE: we must compute before reseting just in case the computed value is needed
+                # later (i.e. if the step metric gets visited first, and then the epoch metric)
                 self[k].compute()
+                self[k].reset()

         return result

@@ -373,6 +380,7 @@ def get_forked_metrics(self, add_dataloader_idx=False):
             if options['forked']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
+                    self[k].reset()
                 else:
                     result[dl_key] = self[k]

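The ordering in these hunks is the subtle part: the epoch value is harvested with ``compute()`` before ``reset()`` wipes the accumulated state, so nothing leaks into the next epoch while the value itself is still captured; resetting first would discard it. A rough stand-alone illustration of the same compute-then-reset pattern follows; the ``metrics`` dict and the tensors are only for demonstration and are not part of the commit.

import torch
from pytorch_lightning.metrics import Accuracy

metrics = {"train_acc": Accuracy()}
metrics["train_acc"].update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))

# epoch end: harvest first, reset second (mirrors the hunks above)
result = {}
for name, metric in metrics.items():
    result[name] = metric.compute().detach()   # the computed value is kept in `result`
    metric.reset()                             # state is cleared so it cannot accumulate across epochs

print(result)
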

pytorch_lightning/metrics/metric.py

Lines changed: 24 additions & 20 deletions
@@ -126,9 +126,7 @@ def add_state(
             and not isinstance(default, list)  # noqa: W503
             or (isinstance(default, list) and len(default) != 0)  # noqa: W503
         ):
-            raise ValueError(
-                "state variable must be a tensor or any empty list (where you can append tensors)"
-            )
+            raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

         if dist_reduce_fx == "sum":
             dist_reduce_fx = dim_zero_sum
@@ -137,9 +135,7 @@ def add_state(
         elif dist_reduce_fx == "cat":
             dist_reduce_fx = dim_zero_cat
         elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
-            raise ValueError(
-                "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
-            )
+            raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]")

         setattr(self, name, default)

@@ -161,15 +157,15 @@ def forward(self, *args, **kwargs):
         self._to_sync = self.dist_sync_on_step

         # save context before switch
-        self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
+        cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

         # call reset, update, compute, on single batch
         self.reset()
         self.update(*args, **kwargs)
         self._forward_cache = self.compute()

         # restore context
-        for attr, val in self._cache.items():
+        for attr, val in cache.items():
             setattr(self, attr, val)
         self._to_sync = True
         self._computed = None
@@ -201,6 +197,7 @@ def _wrap_update(self, update):
         def wrapped_func(*args, **kwargs):
             self._computed = None
             return update(*args, **kwargs)
+
         return wrapped_func

     def _wrap_compute(self, compute):
@@ -211,19 +208,24 @@ def wrapped_func(*args, **kwargs):
                 return self._computed

             dist_sync_fn = self.dist_sync_fn
-            if (
-                dist_sync_fn is None
-                and torch.distributed.is_available()
-                and torch.distributed.is_initialized()
-            ):
+            if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized():
                 # User provided a bool, so we assume DDP if available
                 dist_sync_fn = gather_all_tensors

+            synced = False
             if self._to_sync and dist_sync_fn is not None:
+                # cache prior to syncing
+                cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
+
+                # sync
                 self._sync_dist(dist_sync_fn)
+                synced = True

             self._computed = compute(*args, **kwargs)
-            self.reset()
+            if synced:
+                # if we synced, restore to cache so that we can continue to accumulate un-synced state
+                for attr, val in cache.items():
+                    setattr(self, attr, val)

             return self._computed

@@ -270,8 +272,8 @@ def __setstate__(self, state):
         self.compute = self._wrap_compute(self.compute)

     def _apply(self, fn):
-        """ Overwrite _apply function such that we can also move metric states
-        to the correct device when `.to`, `.cuda`, etc methods are called
+        """Overwrite _apply function such that we can also move metric states
+        to the correct device when `.to`, `.cuda`, etc methods are called
         """
         self = super()._apply(fn)
         # Also apply fn to metric states
@@ -282,13 +284,15 @@ def _apply(self, fn):
             elif isinstance(current_val, Sequence):
                 setattr(self, key, [fn(cur_v) for cur_v in current_val])
             else:
-                raise TypeError('Expected metric state to be either a torch.Tensor'
-                                f'or a list of torch.Tensor, but encountered {current_val}')
+                raise TypeError(
+                    "Expected metric state to be either a torch.Tensor"
+                    f"or a list of torch.Tensor, but encountered {current_val}"
+                )
         return self

     def persistent(self, mode: bool = False):
-        """ Method for post-init to change if metric states should be saved to
-        its state_dict
+        """Method for post-init to change if metric states should be saved to
+        its state_dict
         """
         for key in self._persistent.keys():
             self._persistent[key] = mode

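The ``_wrap_compute`` change above replaces the old unconditional ``reset()`` with a cache-and-restore around the distributed sync: the state is snapshotted before ``_sync_dist`` overwrites it, and restored afterwards so each process can keep accumulating its own un-synced state. Here is a rough sketch of that idea in isolation; the ``compute_synced`` helper and the ``sync_fn`` hook are hypothetical names for illustration, not the library API.

def compute_synced(metric, sync_fn=None):
    """Sketch: compute on (optionally) synced state, then restore the local state."""
    synced = False
    if sync_fn is not None:
        # snapshot the local state before syncing mutates it
        cache = {attr: getattr(metric, attr) for attr in metric._defaults}
        sync_fn(metric)            # e.g. all-gather / reduce the state across processes
        synced = True

    result = metric.compute()      # aggregate over the (possibly synced) state

    if synced:
        # put the local, un-synced state back so per-process accumulation can continue
        for attr, val in cache.items():
            setattr(metric, attr, val)
    return result
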

tests/metrics/test_metric.py

Lines changed: 3 additions & 3 deletions
@@ -139,8 +139,8 @@ def compute(self):
     assert a._computed == 1
     a.update(2)
     assert a._computed is None
-    assert a.compute() == 2
-    assert a._computed == 2
+    assert a.compute() == 3
+    assert a._computed == 3

     # called without update, should return cached value
     a._computed = 5
@@ -192,7 +192,7 @@ def test_pickle(tmpdir):
     assert metric_loaded.compute() == 1

     metric_loaded.update(5)
-    assert metric_loaded.compute() == 5
+    assert metric_loaded.compute() == 6

     metric_pickled = cloudpickle.dumps(a)
     metric_loaded = cloudpickle.loads(metric_pickled)

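The expected values shift because state now survives ``compute()``: the test metric sees both updates, so the second ``compute()`` returns 1 + 2 and the unpickled metric returns 1 + 5. The same behaviour can be reproduced with a small toy metric; ``SumMetric`` below is a hypothetical example, not the class used in the tests.

import torch
from pytorch_lightning.metrics import Metric

class SumMetric(Metric):
    """Toy metric that sums everything passed to update()."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total

m = SumMetric()
m.update(torch.tensor(1.0))
assert m.compute() == 1.0    # first aggregate
m.update(torch.tensor(2.0))
assert m.compute() == 3.0    # state kept accumulating: 1 + 2, not just 2
m.reset()                    # clearing state is now an explicit, separate step
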

tests/metrics/test_metric_lightning.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def training_step(self, batch, batch_idx):
         def training_epoch_end(self, outs):
             assert torch.allclose(self.sum, self.metric.compute())
             self.sum = 0.0
+            self.metric.reset()

     model = TestModel()
     model.val_dataloader = None
