diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6bf80b1ccf739..ab74be455ad54 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -64,6 +64,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))

+- Metric `compute()` method will no longer automatically call `reset()` ([#5409](https://github.com/PyTorchLightning/pytorch-lightning/pull/5409/))
+
+
 - Set PyTorch 1.4 as min requirements, also for testing and examples `torchvision>=0.5` and `torchtext>=0.5` ([#5418](https://github.com/PyTorchLightning/pytorch-lightning/pull/5418))


diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 9bdaae188b626..c50d7c2991753 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -20,6 +20,11 @@ The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to t
 serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the
 metric over the provided input.

+.. warning::
+    From v1.2 onward ``compute()`` will no longer automatically call ``reset()``,
+    and it is up to the user to reset metrics between epochs, except in the case where the
+    metric is directly passed to the ``LightningModule``'s ``self.log``.
+
 These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
 distributed mode, the internal state of each metric is synced and reduced across each process, so that the
 logic present in ``.compute()`` is applied to state information from all processes.
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 091f9a789efda..72d5939b67795 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -306,7 +306,6 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
         Gets the metrics to log at the end of epoch
         """
         result = {}
-
         meta = self['meta']
         for k, options in meta.items():
             if k == '_internal':
@@ -320,12 +319,16 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
             if options['logger'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
+                    self[k].reset()
                 else:
                     result[dl_key] = self[k]

             if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
-                # compute metric on epoch anyway so state does not accumulate
+                # reset metric anyway so state does not accumulate
+                # NOTE: we must compute before resetting just in case the computed value is needed
+                # later (i.e. if the step metric gets visited first, and then the epoch metric)
                 self[k].compute()
+                self[k].reset()

         return result

@@ -348,12 +351,16 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
             if options['prog_bar'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
+                    self[k].reset()
                 else:
                     result[dl_key] = self[k]

             if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
-                # compute metric on epoch anyway so state does not accumulate
+                # reset metric anyway so state does not accumulate
+                # NOTE: we must compute before resetting just in case the computed value is needed
+                # later (i.e. if the step metric gets visited first, and then the epoch metric)
                 self[k].compute()
+                self[k].reset()

         return result

@@ -373,6 +380,7 @@ def get_forked_metrics(self, add_dataloader_idx=False):
             if options['forked']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
+                    self[k].reset()
                 else:
                     result[dl_key] = self[k]

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 05b719e8a0610..5c8aaefc2084a 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -126,9 +126,7 @@ def add_state(
             and not isinstance(default, list)  # noqa: W503
             or (isinstance(default, list) and len(default) != 0)  # noqa: W503
         ):
-            raise ValueError(
-                "state variable must be a tensor or any empty list (where you can append tensors)"
-            )
+            raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

         if dist_reduce_fx == "sum":
             dist_reduce_fx = dim_zero_sum
@@ -137,9 +135,7 @@ def add_state(
         elif dist_reduce_fx == "cat":
             dist_reduce_fx = dim_zero_cat
         elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
-            raise ValueError(
-                "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
-            )
+            raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]")

         setattr(self, name, default)

@@ -161,7 +157,7 @@ def forward(self, *args, **kwargs):
         self._to_sync = self.dist_sync_on_step

         # save context before switch
-        self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
+        cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

         # call reset, update, compute, on single batch
         self.reset()
@@ -169,7 +165,7 @@ def forward(self, *args, **kwargs):
         self._forward_cache = self.compute()

         # restore context
-        for attr, val in self._cache.items():
+        for attr, val in cache.items():
             setattr(self, attr, val)
         self._to_sync = True
         self._computed = None
@@ -201,6 +197,7 @@ def _wrap_update(self, update):
         def wrapped_func(*args, **kwargs):
             self._computed = None
             return update(*args, **kwargs)
+
         return wrapped_func

     def _wrap_compute(self, compute):
@@ -211,19 +208,24 @@ def wrapped_func(*args, **kwargs):
                 return self._computed

             dist_sync_fn = self.dist_sync_fn
-            if (
-                dist_sync_fn is None
-                and torch.distributed.is_available()
-                and torch.distributed.is_initialized()
-            ):
+            if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized():
                 # User provided a bool, so we assume DDP if available
                 dist_sync_fn = gather_all_tensors

+            synced = False
             if self._to_sync and dist_sync_fn is not None:
+                # cache prior to syncing
+                cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
+
+                # sync
                 self._sync_dist(dist_sync_fn)
+                synced = True

             self._computed = compute(*args, **kwargs)
-            self.reset()
+            if synced:
+                # if we synced, restore to cache so that we can continue to accumulate un-synced state
+                for attr, val in cache.items():
+                    setattr(self, attr, val)

             return self._computed

@@ -270,8 +272,8 @@ def __setstate__(self, state):
         self.compute = self._wrap_compute(self.compute)

     def _apply(self, fn):
-        """ Overwrite _apply function such that we can also move metric states
-        to the correct device when `.to`, `.cuda`, etc methods are called
+        """Overwrite _apply function such that we can also move metric states
+        to the correct device when `.to`, `.cuda`, etc methods are called
         """
         self = super()._apply(fn)
         # Also apply fn to metric states
@@ -282,13 +284,15 @@ def _apply(self, fn):
             elif isinstance(current_val, Sequence):
                 setattr(self, key, [fn(cur_v) for cur_v in current_val])
             else:
-                raise TypeError('Expected metric state to be either a torch.Tensor'
-                                f'or a list of torch.Tensor, but encountered {current_val}')
+                raise TypeError(
+                    "Expected metric state to be either a torch.Tensor"
+                    f"or a list of torch.Tensor, but encountered {current_val}"
+                )
         return self

     def persistent(self, mode: bool = False):
-        """ Method for post-init to change if metric states should be saved to
-        its state_dict
+        """Method for post-init to change if metric states should be saved to
+        its state_dict
         """
         for key in self._persistent.keys():
             self._persistent[key] = mode
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index c3cafa2365267..16ef7312101d6 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -139,8 +139,8 @@ def compute(self):
     assert a._computed == 1
     a.update(2)
     assert a._computed is None
-    assert a.compute() == 2
-    assert a._computed == 2
+    assert a.compute() == 3
+    assert a._computed == 3

     # called without update, should return cached value
     a._computed = 5
@@ -192,7 +192,7 @@ def test_pickle(tmpdir):
     assert metric_loaded.compute() == 1

     metric_loaded.update(5)
-    assert metric_loaded.compute() == 5
+    assert metric_loaded.compute() == 6

     metric_pickled = cloudpickle.dumps(a)
     metric_loaded = cloudpickle.loads(metric_pickled)
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 2347cc65f8293..0beb0534139ca 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -46,6 +46,7 @@ def training_step(self, batch, batch_idx):
         def training_epoch_end(self, outs):
             assert torch.allclose(self.sum, self.metric.compute())
             self.sum = 0.0
+            self.metric.reset()

     model = TestModel()
     model.val_dataloader = None
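
A minimal sketch of the behavioural change in this patch, using a hypothetical `SumMetric` subclass (not part of the diff) built on `pytorch_lightning.metrics.Metric`: state accumulated by `update()` now survives `compute()` and is only cleared by an explicit `reset()`, which is why the expected values in `tests/metrics/test_metric.py` move from 2 to 3 and from 5 to 6.

```python
import torch
from pytorch_lightning.metrics import Metric


class SumMetric(Metric):
    """Hypothetical toy metric used only to illustrate the new semantics."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total


m = SumMetric()
m.update(torch.tensor(1.0))
assert m.compute() == 1.0
m.update(torch.tensor(2.0))
# Before this change, compute() reset the state, so this would have been 2.0.
assert m.compute() == 3.0
m.reset()  # state is now cleared only when explicitly requested
```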
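And the user-facing consequence described by the `metrics.rst` warning, sketched with a plain loop (the two epochs and random data are made up for illustration; only the placement of `reset()` matters). When a metric object is passed directly to `self.log`, Lightning still resets it for you via the `step_result.py` changes above.

```python
import torch
from pytorch_lightning.metrics import Accuracy

metric = Accuracy()

for epoch in range(2):
    for _ in range(10):
        preds = torch.randint(0, 2, (8,))
        target = torch.randint(0, 2, (8,))
        metric.update(preds, target)  # accumulate state over the epoch

    print(f"epoch {epoch}: accuracy={metric.compute():.3f}")
    metric.reset()  # required from v1.2 on; compute() no longer does this for you
```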