From 21507d9e3e097cad770d61df6cdd6b7a6264abcc Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 7 Jan 2021 20:46:26 +0100 Subject: [PATCH 1/8] reset --- docs/source/metrics.rst | 5 +++++ pytorch_lightning/core/step_result.py | 11 +++++++---- pytorch_lightning/metrics/metric.py | 16 ++++++++++++---- tests/metrics/test_metric.py | 6 +++--- tests/metrics/test_metric_lightning.py | 1 + 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 17ba0ce94766d..d59023d8e462f 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -20,6 +20,11 @@ The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to t serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the provided input. +.. warning:: + From v1.2 and foreward ``compute()`` will no longer automatically also call ``reset()``, + and it is up to the user to reset metrics between epochs, except in the case where the + metric is directly passed to ``LightningModule``s ``self.log``. + These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in distributed mode, the internal state of each metric is synced and reduced across each process, so that the logic present in ``.compute()`` is applied to state information from all processes. diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 091f9a789efda..4d20b3c1a5cd3 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -320,12 +320,13 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: if options['logger'] and options['on_epoch']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() + self[k].reset() else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # compute metric on epoch anyway so state does not accumulate - self[k].compute() + # reset metric anyway so state does not accumulate + self[k].reset() return result @@ -348,12 +349,13 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False): if options['prog_bar'] and options['on_epoch']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() + self[k].reset() else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # compute metric on epoch anyway so state does not accumulate - self[k].compute() + # reset metric anyway so state does not accumulate + self[k].reset() return result @@ -373,6 +375,7 @@ def get_forked_metrics(self, add_dataloader_idx=False): if options['forked']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() + self[k].reset() else: result[dl_key] = self[k] diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 0f61b94c55139..6e93a5eb5a58d 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -212,15 +212,23 @@ def wrapped_func(*args, **kwargs): and torch.distributed.is_initialized()): # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors - + + synced = False if self._to_sync and dist_sync_fn is not None: + # cache prior to syncing + self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + + # sync self._sync_dist(dist_sync_fn) + synced = True self._computed = compute(*args, **kwargs) - self.reset() - + if synced: + # 
ADDED: if we synced, restore to cache + for attr, val in self._cache.items(): + setattr(self, attr, val) + return self._computed - return wrapped_func @abstractmethod diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index d1c8b8c441cc5..e9b3f9c2f0f92 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -117,8 +117,8 @@ def compute(self): assert a._computed == 1 a.update(2) assert a._computed is None - assert a.compute() == 2 - assert a._computed == 2 + assert a.compute() == 3 + assert a._computed == 3 # called without update, should return cached value a._computed = 5 @@ -162,7 +162,7 @@ def test_pickle(tmpdir): assert metric_loaded.compute() == 1 metric_loaded.update(5) - assert metric_loaded.compute() == 5 + assert metric_loaded.compute() == 6 metric_pickled = cloudpickle.dumps(a) metric_loaded = cloudpickle.loads(metric_pickled) diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index ed809c5e8527e..0f148fbec12de 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -34,6 +34,7 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outs): assert torch.allclose(self.sum, self.metric.compute()) self.sum = 0.0 + self.metric.reset() model = TestModel() model.val_dataloader = None From 321924e5158209a67bce5af3b91ede614643a20a Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 7 Jan 2021 17:21:31 -0500 Subject: [PATCH 2/8] self._cache -> cache (make cache local variable so it is not overwritten) --- pytorch_lightning/metrics/metric.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 6e93a5eb5a58d..12f98ad4aba7e 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -157,7 +157,7 @@ def forward(self, *args, **kwargs): self._to_sync = self.dist_sync_on_step # save context before switch - self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} # call reset, update, compute, on single batch self.reset() @@ -165,7 +165,7 @@ def forward(self, *args, **kwargs): self._forward_cache = self.compute() # restore context - for attr, val in self._cache.items(): + for attr, val in cache.items(): setattr(self, attr, val) self._to_sync = True self._computed = None @@ -216,7 +216,7 @@ def wrapped_func(*args, **kwargs): synced = False if self._to_sync and dist_sync_fn is not None: # cache prior to syncing - self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} # sync self._sync_dist(dist_sync_fn) @@ -225,7 +225,7 @@ def wrapped_func(*args, **kwargs): self._computed = compute(*args, **kwargs) if synced: # ADDED: if we synced, restore to cache - for attr, val in self._cache.items(): + for attr, val in cache.items(): setattr(self, attr, val) return self._computed From 9279e24d0b3f490a6277c9f04820872fb15cd4b0 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 7 Jan 2021 17:26:16 -0500 Subject: [PATCH 3/8] pep8 --- pytorch_lightning/metrics/metric.py | 39 ++++++++++++++--------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 12f98ad4aba7e..78eb3465ba875 100644 --- a/pytorch_lightning/metrics/metric.py +++ 
b/pytorch_lightning/metrics/metric.py @@ -57,6 +57,7 @@ class Metric(nn.Module, ABC): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. default: None """ + def __init__( self, compute_on_step: bool = True, @@ -119,12 +120,10 @@ def add_state( """ if ( not isinstance(default, torch.Tensor) - and not isinstance(default, list) # noqa: W503 + and not isinstance(default, list) # noqa: W503 or (isinstance(default, list) and len(default) != 0) # noqa: W503 ): - raise ValueError( - "state variable must be a tensor or any empty list (where you can append tensors)" - ) + raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") if dist_reduce_fx == "sum": dist_reduce_fx = dim_zero_sum @@ -133,9 +132,7 @@ def add_state( elif dist_reduce_fx == "cat": dist_reduce_fx = dim_zero_cat elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): - raise ValueError( - "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]" - ) + raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]") setattr(self, name, default) @@ -197,6 +194,7 @@ def _wrap_update(self, update): def wrapped_func(*args, **kwargs): self._computed = None return update(*args, **kwargs) + return wrapped_func def _wrap_compute(self, compute): @@ -207,28 +205,27 @@ def wrapped_func(*args, **kwargs): return self._computed dist_sync_fn = self.dist_sync_fn - if (dist_sync_fn is None - and torch.distributed.is_available() - and torch.distributed.is_initialized()): + if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors - + synced = False if self._to_sync and dist_sync_fn is not None: # cache prior to syncing cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - + # sync self._sync_dist(dist_sync_fn) synced = True self._computed = compute(*args, **kwargs) if synced: - # ADDED: if we synced, restore to cache + # if we synced, restore to cache so that we can continue to accumulate un-synced state for attr, val in cache.items(): setattr(self, attr, val) - + return self._computed + return wrapped_func @abstractmethod @@ -268,8 +265,8 @@ def __setstate__(self, state): self.compute = self._wrap_compute(self.compute) def _apply(self, fn): - """ Overwrite _apply function such that we can also move metric states - to the correct device when `.to`, `.cuda`, etc methods are called + """Overwrite _apply function such that we can also move metric states + to the correct device when `.to`, `.cuda`, etc methods are called """ self = super()._apply(fn) # Also apply fn to metric states @@ -280,13 +277,15 @@ def _apply(self, fn): elif isinstance(current_val, Sequence): setattr(self, key, [fn(cur_v) for cur_v in current_val]) else: - raise TypeError('Expected metric state to be either a torch.Tensor' - f'or a list of torch.Tensor, but encountered {current_val}') + raise TypeError( + "Expected metric state to be either a torch.Tensor" + f"or a list of torch.Tensor, but encountered {current_val}" + ) return self def persistent(self, mode: bool = False): - """ Method for post-init to change if metric states should be saved to - its state_dict + """Method for post-init to change if metric states should be saved to + its state_dict """ for key in self._persistent.keys(): self._persistent[key] = mode From 
b4be1fcf56ec1c907d7ab2a1cd999f1b01b83619 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 7 Jan 2021 18:40:42 -0500 Subject: [PATCH 4/8] fix metric result integration --- pytorch_lightning/core/step_result.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 4d20b3c1a5cd3..9d16bd36d1281 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -306,9 +306,9 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: Gets the metrics to log at the end of epoch """ result = {} - meta = self['meta'] for k, options in meta.items(): + print(k) if k == '_internal': continue @@ -321,11 +321,15 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() self[k].reset() + print("reset", k) else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # reset metric anyway so state does not accumulate + # NOTE: we must compute before setting just incase the computed value is needed + # before reseting + self[k].compute() self[k].reset() return result @@ -355,6 +359,9 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False): if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # reset metric anyway so state does not accumulate + # NOTE: we must compute before setting just incase the computed value is needed + # before reseting + self[k].compute() self[k].reset() return result From a587d5369854170b80319e497bf0f92fdd62ede8 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 7 Jan 2021 18:45:51 -0500 Subject: [PATCH 5/8] rm print statements --- pytorch_lightning/core/step_result.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 9d16bd36d1281..424fb21897a34 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -308,7 +308,6 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: result = {} meta = self['meta'] for k, options in meta.items(): - print(k) if k == '_internal': continue @@ -321,7 +320,6 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() self[k].reset() - print("reset", k) else: result[dl_key] = self[k] From 78edf9ca6f7d29abc9fc0936cd083339fcbe881a Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 7 Jan 2021 18:51:29 -0500 Subject: [PATCH 6/8] better comment --- pytorch_lightning/core/step_result.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 424fb21897a34..72d5939b67795 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -325,8 +325,8 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # reset metric anyway so state does not accumulate - # NOTE: we must compute before setting just incase the computed value is needed - # before reseting + # NOTE: we must compute before reseting just in case the computed value is needed + # later (i.e. 
if the step metric gets visited first, and then the epoch metric) self[k].compute() self[k].reset() @@ -357,8 +357,8 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False): if k in self and not options['on_epoch'] and isinstance(self[k], Metric): # reset metric anyway so state does not accumulate - # NOTE: we must compute before setting just incase the computed value is needed - # before reseting + # NOTE: we must compute before reseting just in case the computed value is needed + # later (i.e. if the step metric gets visited first, and then the epoch metric) self[k].compute() self[k].reset() From e276e767bde704f81a225220cfc99377508e4500 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 8 Jan 2021 14:52:27 +0100 Subject: [PATCH 7/8] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7808e8d61c83f..6c03b2fb6e65d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- Metric `compute()` method will no longer automatically call `reset()` ([#5409](https://github.com/PyTorchLightning/pytorch-lightning/pull/5409/)) + + ### Deprecated - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) From b7bbc30a783ee44b3cc8b63cd80e7f623f721f25 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 11 Jan 2021 11:35:01 +0100 Subject: [PATCH 8/8] Update docs/source/metrics.rst Co-authored-by: Roger Shieh --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index a8aa47fbd58fc..c50d7c2991753 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -21,7 +21,7 @@ serves the dual purpose of calling ``update()`` on its input and simultaneously provided input. .. warning:: - From v1.2 and foreward ``compute()`` will no longer automatically also call ``reset()``, + From v1.2 onward ``compute()`` will no longer automatically call ``reset()``, and it is up to the user to reset metrics between epochs, except in the case where the metric is directly passed to ``LightningModule``s ``self.log``.
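
The net effect of this patch series, as the final docs warning states, is that ``compute()`` only aggregates the accumulated state and no longer clears it. Below is a minimal usage sketch of the resulting pattern (not part of the patches themselves); it assumes the built-in ``Accuracy`` metric and dummy random data, both purely illustrative:

```python
import torch
from pytorch_lightning.metrics import Accuracy

accuracy = Accuracy()

for epoch in range(2):
    for _ in range(10):
        preds = torch.randint(0, 2, (8,))
        target = torch.randint(0, 2, (8,))
        accuracy.update(preds, target)

    # compute() now only aggregates the accumulated state; it no longer clears it
    print(f"epoch {epoch}: accuracy = {accuracy.compute().item():.3f}")

    # from v1.2, clearing state between epochs is the caller's responsibility
    # (unless the metric object itself is passed to LightningModule's self.log,
    # in which case Lightning resets it at epoch end, as step_result.py now does)
    accuracy.reset()
```

Without the explicit ``reset()`` call, the second epoch's ``compute()`` would include the first epoch's updates as well, which is exactly the accumulation behavior the updated tests (``assert a.compute() == 3``) now exercise.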