3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -64,6 +64,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


- Metric `compute()` method will no longer automatically call `reset()` ([#5409](https://github.com/PyTorchLightning/pytorch-lightning/pull/5409/))


- Set PyTorch 1.4 as the minimum requirement, and require `torchvision>=0.5` and `torchtext>=0.5` for testing and examples ([#5418](https://github.com/PyTorchLightning/pytorch-lightning/pull/5418))


5 changes: 5 additions & 0 deletions docs/source/metrics.rst
@@ -20,6 +20,11 @@ The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to t
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.

.. warning::
From v1.2 onward ``compute()`` will no longer automatically call ``reset()``,
and it is up to the user to reset metrics between epochs, except when the
metric is passed directly to ``self.log`` of the ``LightningModule``.
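As a rough sketch of the manual pattern this warning describes (``SumMetric`` is a hypothetical toy metric written here purely for illustration, not part of the library):

import torch
from pytorch_lightning.metrics import Metric

class SumMetric(Metric):
    """Toy metric that accumulates the sum of everything passed to ``update()``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor):
        self.total += x.sum()

    def compute(self):
        return self.total

metric = SumMetric()
for epoch in range(2):
    for batch in (torch.ones(4), torch.ones(4)):
        metric.update(batch)
    print(metric.compute())  # tensor(8.) in both epochs, thanks to the manual reset below
    metric.reset()           # from v1.2 this is the user's responsibility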

These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
logic present in ``.compute()`` is applied to state information from all processes.
14 changes: 11 additions & 3 deletions pytorch_lightning/core/step_result.py
@@ -306,7 +306,6 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
Gets the metrics to log at the end of epoch
"""
result = {}

meta = self['meta']
for k, options in meta.items():
if k == '_internal':
@@ -320,12 +319,16 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
self[k].reset()
else:
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
# reset metric anyway so state does not accumulate
# NOTE: we must compute before resetting just in case the computed value is needed
# later (i.e. if the step metric gets visited first, and then the epoch metric)
self[k].compute()
self[k].reset()
Contributor:
Why is reset() being added here? Calling self[k].compute() and then self[k].reset() makes the compute() useless, no?

Contributor:
compute() will save the computed value to self._computed, so if compute() is called again without any update() call, the cached value will be provided. We still call reset() so that we don't use too much memory by needlessly accumulating tensors over epochs.

Contributor:
This is important if we have a metric that is logged on both step and epoch. Since it will show up twice in the results meta dict, we want to reset the step metric on epoch end, but still allow the epoch metric to be computed.
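A small single-process illustration of the behaviour described in this thread, reusing the hypothetical SumMetric sketched next to the docs warning above (so not standalone): compute() caches its result in self._computed, and reset() then frees the accumulated state without discarding that cached value.

import torch  # SumMetric from the docs sketch above is assumed to be in scope

metric = SumMetric()
metric.update(torch.tensor([1.0, 2.0]))

value = metric.compute()   # tensor(3.), also cached in metric._computed
metric.reset()             # frees the accumulated state so it does not grow across epochs

# no update() has happened since the last compute(), so the cached value is returned
assert metric.compute() == value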


return result

@@ -348,12 +351,16 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
self[k].reset()
else:
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
# reset metric anyway so state does not accumulate
# NOTE: we must compute before resetting just in case the computed value is needed
# later (i.e. if the step metric gets visited first, and then the epoch metric)
self[k].compute()
self[k].reset()

return result

@@ -373,6 +380,7 @@ def get_forked_metrics(self, add_dataloader_idx=False):
if options['forked']:
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
self[k].reset()
else:
result[dl_key] = self[k]

44 changes: 24 additions & 20 deletions pytorch_lightning/metrics/metric.py
@@ -126,9 +126,7 @@ def add_state(
and not isinstance(default, list) # noqa: W503
or (isinstance(default, list) and len(default) != 0) # noqa: W503
):
raise ValueError(
"state variable must be a tensor or any empty list (where you can append tensors)"
)
raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

if dist_reduce_fx == "sum":
dist_reduce_fx = dim_zero_sum
@@ -137,9 +135,7 @@ def add_state(
elif dist_reduce_fx == "cat":
dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
raise ValueError(
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
)
raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]")

setattr(self, name, default)

@@ -161,15 +157,15 @@ def forward(self, *args, **kwargs):
self._to_sync = self.dist_sync_on_step

# save context before switch
self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

# call reset, update, compute, on single batch
self.reset()
self.update(*args, **kwargs)
self._forward_cache = self.compute()

# restore context
for attr, val in self._cache.items():
for attr, val in cache.items():
setattr(self, attr, val)
self._to_sync = True
self._computed = None
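As a rough illustration of the save/restore dance in forward() above (again reusing the hypothetical SumMetric from the docs sketch): calling the metric object directly returns the value for the current batch only, while the accumulated state is preserved for a later compute().

import torch  # SumMetric from the docs sketch above is assumed to be in scope

metric = SumMetric()

batch_value = metric(torch.tensor([1.0, 2.0]))  # forward(): reset -> update -> compute on this batch only
print(batch_value)                              # tensor(3.)

metric.update(torch.tensor([4.0]))
print(metric.compute())                         # tensor(7.): the accumulated state survived forward()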
@@ -201,6 +197,7 @@ def _wrap_update(self, update):
def wrapped_func(*args, **kwargs):
self._computed = None
return update(*args, **kwargs)

return wrapped_func

def _wrap_compute(self, compute):
@@ -211,19 +208,24 @@ def wrapped_func(*args, **kwargs):
return self._computed

dist_sync_fn = self.dist_sync_fn
if (
dist_sync_fn is None
and torch.distributed.is_available()
and torch.distributed.is_initialized()
):
if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized():
# User provided a bool, so we assume DDP if available
dist_sync_fn = gather_all_tensors

synced = False
if self._to_sync and dist_sync_fn is not None:
# cache prior to syncing
cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

# sync
self._sync_dist(dist_sync_fn)
synced = True

self._computed = compute(*args, **kwargs)
self.reset()
if synced:
# if we synced, restore to cache so that we can continue to accumulate un-synced state
for attr, val in cache.items():
setattr(self, attr, val)

return self._computed
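A rough single-process simulation of the cache-and-restore logic added above, with the all-gather replaced by a fake that pretends there are two identical processes. This again reuses the hypothetical SumMetric, and assumes the Metric constructor accepts the dist_sync_fn argument that self.dist_sync_fn above refers to, called with the state tensor and a group keyword as gather_all_tensors would be.

import torch  # SumMetric from the docs sketch above is assumed to be in scope

def fake_gather(tensor, group=None):
    # pretend two identical processes contributed the same local state
    return [tensor, tensor]

metric = SumMetric(dist_sync_fn=fake_gather)
metric.update(torch.tensor([1.0, 2.0]))

print(metric.compute())  # tensor(6.): state is "gathered" from both fake processes and sum-reduced
print(metric.total)      # tensor(3.): the un-synced local state was restored after compute()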

@@ -270,8 +272,8 @@ def __setstate__(self, state):
self.compute = self._wrap_compute(self.compute)

def _apply(self, fn):
""" Overwrite _apply function such that we can also move metric states
to the correct device when `.to`, `.cuda`, etc methods are called
"""Overwrite _apply function such that we can also move metric states
to the correct device when `.to`, `.cuda`, etc methods are called
"""
self = super()._apply(fn)
# Also apply fn to metric states
@@ -282,13 +284,15 @@ def _apply(self, fn):
elif isinstance(current_val, Sequence):
setattr(self, key, [fn(cur_v) for cur_v in current_val])
else:
raise TypeError('Expected metric state to be either a torch.Tensor'
f'or a list of torch.Tensor, but encountered {current_val}')
raise TypeError(
"Expected metric state to be either a torch.Tensor"
f"or a list of torch.Tensor, but encountered {current_val}"
)
return self

def persistent(self, mode: bool = False):
""" Method for post-init to change if metric states should be saved to
its state_dict
"""Method for post-init to change if metric states should be saved to
its state_dict
"""
for key in self._persistent.keys():
self._persistent[key] = mode
6 changes: 3 additions & 3 deletions tests/metrics/test_metric.py
@@ -139,8 +139,8 @@ def compute(self):
assert a._computed == 1
a.update(2)
assert a._computed is None
assert a.compute() == 2
assert a._computed == 2
assert a.compute() == 3
assert a._computed == 3

# called without update, should return cached value
a._computed = 5
@@ -192,7 +192,7 @@ def test_pickle(tmpdir):
assert metric_loaded.compute() == 1

metric_loaded.update(5)
assert metric_loaded.compute() == 5
assert metric_loaded.compute() == 6

metric_pickled = cloudpickle.dumps(a)
metric_loaded = cloudpickle.loads(metric_pickled)
1 change: 1 addition & 0 deletions tests/metrics/test_metric_lightning.py
@@ -46,6 +46,7 @@ def training_step(self, batch, batch_idx):
def training_epoch_end(self, outs):
assert torch.allclose(self.sum, self.metric.compute())
self.sum = 0.0
self.metric.reset()

model = TestModel()
model.val_dataloader = None