From d713d49cd0e23c6d8700bfc2a6364adf830527b5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 16 Apr 2021 13:54:13 +0200
Subject: [PATCH 1/9] reset

---
 pytorch_lightning/core/step_result.py              | 21 +++++++++----------
 .../logger_connector/epoch_result_store.py         | 15 +++++++++++++
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 7a193662b597b..417503e14616b 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -287,16 +287,12 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
             if options['logger'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
-                    self[k].reset()
                 else:
                     result[dl_key] = self[k]
 
             if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
-                # reset metric anyway so state does not accumulate
-                # NOTE: we must compute before reseting just in case the computed value is needed
-                # later (i.e. if the step metric gets visited first, and then the epoch metric)
+                # compute for reuse later
                 self[k].compute()
-                self[k].reset()
 
         return result
 
@@ -319,16 +315,12 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
             if options['prog_bar'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
-                    self[k].reset()
                 else:
                     result[dl_key] = self[k]
 
             if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
-                # reset metric anyway so state does not accumulate
-                # NOTE: we must compute before reseting just in case the computed value is needed
-                # later (i.e. if the step metric gets visited first, and then the epoch metric)
+                # compute for reuse later
                 self[k].compute()
-                self[k].reset()
 
         return result
 
@@ -348,7 +340,6 @@ def get_forked_metrics(self, add_dataloader_idx=False):
             if options['forked']:
                 if isinstance(self[k], Metric):
                     result[dl_key] = self[k].compute().detach()
-                    self[k].reset()
                 else:
                     result[dl_key] = self[k]
 
@@ -587,6 +578,14 @@ def get_non_metrics_keys(self):
         """
         return [k for k, v in self.items() if not isinstance(v, Metric)]
 
+    def reset(self):
+        """
+        Call at the end of epoch to reset all metric objects
+        """
+        for k, value in self.items():
+            if isinstance(value, Metric):
+                value.reset()
+
 
 def choose_last(x):
     if isinstance(x, (torch.Tensor, list)):
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 594da76192aed..bbd97b2b206e8 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -233,6 +233,18 @@ def auto_reduce_results_on_epoch_end(self) -> None:
 
         self.has_reduced = True
 
+    def reset(self):
+        """
+        Call at the end of epoch to reset Result objects
+        """
+        for dl_idx in range(self.num_dataloaders):
+            epoch_metrics = self._internals[dl_idx]
+            if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
+                for opt_idx in list(epoch_metrics):
+                    epoch_metrics[opt_idx].reset()
+            else:
+                epoch_metrics.reset()
+
     def __getitem__(self, key: str) -> Any:
         return self._internals.get(key, None)
 
@@ -443,6 +455,9 @@ def get_forked_metrics(self) -> Dict:
         return self.run_epoch_by_func_name("get_forked_metrics")
 
     def reset(self):
+        if hasattr(self, '_internals'):
+            for k, value in self._internals.items():
+                value.reset()
         self._internals = {}
         self._dataloader_idx: Optional[int] = None
         self._split_idx: Optional[int] = None
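Context for the patch above: a torchmetrics `Metric` accumulates state across `update()` calls, and `compute()` reads that state without clearing it. The removed lines reset each metric eagerly inside the getters, so the first consumer to read a metric (e.g. the step-level logger) wiped the state before a later consumer (the epoch-level aggregation) could see it. A minimal sketch of that failure mode, using a hypothetical `SumMetric` rather than Lightning's internals:

    import torch
    from torchmetrics import Metric


    class SumMetric(Metric):
        """Toy metric for illustration only: sums everything it sees."""

        def __init__(self):
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x: torch.Tensor) -> None:
            self.total += x.sum()

        def compute(self) -> torch.Tensor:
            return self.total


    metric = SumMetric()
    metric.update(torch.ones(4))

    first = metric.compute()   # tensor(4.) -- e.g. the step-level read
    metric.reset()             # what the old getters did eagerly
    second = metric.compute()  # tensor(0.) -- a later read only sees defaults

The patch therefore removes the eager resets and adds a dedicated `reset()` to both `Result` and the epoch result store, to be called once per epoch after all consumers have read the computed values.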
From 0e6c1417d7918f6b7e16eae1b9c6db6fd2a94ee4 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 16 Apr 2021 15:37:24 +0200
Subject: [PATCH 2/9] fix tests

---
 .../connectors/logger_connector/epoch_result_store.py | 8 ++++----
 pytorch_lightning/trainer/trainer.py                  | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index bbd97b2b206e8..39ff61fb69c7b 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -238,7 +238,7 @@ def reset(self):
         """
         Call at the end of epoch to reset Result objects
         """
         for dl_idx in range(self.num_dataloaders):
-            epoch_metrics = self._internals[dl_idx]
+            epoch_metrics = self._internals[dl_idx] if not self.has_reduced else self._internals_reduced[dl_idx]
             if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
                 for opt_idx in list(epoch_metrics):
                     epoch_metrics[opt_idx].reset()
@@ -274,6 +274,7 @@ def __init__(self, trainer: 'pl.Trainer') -> None:
         _should_warn = trainer.accelerator_connector.is_distributed
         _should_warn &= not trainer.training_type_plugin.rpc_enabled
         self._should_warn = _should_warn
+        self._internals = {}
 
         self.reset()
 
@@ -455,9 +456,8 @@ def get_forked_metrics(self) -> Dict:
         return self.run_epoch_by_func_name("get_forked_metrics")
 
     def reset(self):
-        if hasattr(self, '_internals'):
-            for k, value in self._internals.items():
-                value.reset()
+        for k, value in self._internals.items():
+            value.reset()
         self._internals = {}
         self._dataloader_idx: Optional[int] = None
         self._split_idx: Optional[int] = None
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b1c29ff2c8892..99b12127713c9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -653,9 +653,6 @@ def run_evaluation(self, on_epoch=False):
         )
         self.validating = True
 
-        # reset cached results
-        self.logger_connector.reset()
-
         # prepare dataloaders
         dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders()
 
@@ -747,6 +744,9 @@ def run_evaluation(self, on_epoch=False):
 
         torch.set_grad_enabled(True)
 
+        # reset cached results
+        self.logger_connector.reset()
+
         return eval_loop_results
 
     def track_output_for_epoch_end(self, outputs, output):

From 747b643089d790079ed20235ae38af2711f6e4df Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 16 Apr 2021 15:51:58 +0200
Subject: [PATCH 3/9] fix tests

---
 tests/core/test_metric_result_integration.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 0b797dff0e42f..725655c54136d 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -76,6 +76,7 @@ def _ddp_test_fn(rank, worldsize):
         assert batch_expected[k] == batch_log[k]
 
     epoch_log = result.get_epoch_log_metrics()
+    result.reset()
 
     # assert metric state reset to default values
    assert metric_a.x == metric_a._defaults['x']
@@ -127,6 +128,7 @@ def test_result_metric_integration():
         assert batch_expected[k] == batch_log[k]
 
     epoch_log = result.get_epoch_log_metrics()
+    result.reset()
 
     # assert metric state reset to default values
     assert metric_a.x == metric_a._defaults['x']
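Two things happen above. Patch 2 makes the store's `reset()` follow the data: once `has_reduced` is set, the `Result` objects holding the metrics live in `_internals_reduced` rather than `_internals`, so that is what must be reset; it also moves the trainer's cache reset from the start of `run_evaluation` to the end, after `eval_loop_results` has been assembled. Patch 3 then updates the integration tests to the new contract: metric state only returns to its registered defaults after an explicit `result.reset()`, not as a side effect of `get_epoch_log_metrics()`. The check the tests perform boils down to something like this (reusing the hypothetical `SumMetric` sketch above):

    metric = SumMetric()
    metric.update(torch.ones(3))
    assert metric.compute() == 3.0  # state survives compute()

    metric.reset()                  # explicit, like result.reset() in the tests
    assert torch.equal(metric.total, metric._defaults["total"])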
From 6c8e5c2359dcbe150944f4e73d92fc851087f6c7 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Fri, 16 Apr 2021 15:41:32 +0100
Subject: [PATCH 4/9] Apply suggestions from code review

Co-authored-by: ananthsub
---
 pytorch_lightning/core/step_result.py                         | 2 +-
 .../trainer/connectors/logger_connector/epoch_result_store.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 417503e14616b..f2cdd31ab739f 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -578,7 +578,7 @@ def get_non_metrics_keys(self):
         """
         return [k for k, v in self.items() if not isinstance(v, Metric)]
 
-    def reset(self):
+    def reset(self) -> None:
         """
         Call at the end of epoch to reset all metric objects
         """
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 39ff61fb69c7b..61f66dbed9dfa 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -233,7 +233,7 @@ def auto_reduce_results_on_epoch_end(self) -> None:
 
         self.has_reduced = True
 
-    def reset(self):
+    def reset(self) -> None:
         """
         Call at the end of epoch to reset Result objects
         """
@@ -455,7 +455,7 @@ def get_epoch_log_metrics(self) -> Dict:
     def get_forked_metrics(self) -> Dict:
         return self.run_epoch_by_func_name("get_forked_metrics")
 
-    def reset(self):
+    def reset(self) -> None:
         for k, value in self._internals.items():
             value.reset()
         self._internals = {}

From 71847380b4cc65b3c38a45c43058b0704e2bd62b Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 16 Apr 2021 18:59:27 +0200
Subject: [PATCH 5/9] move logic

---
 pytorch_lightning/trainer/trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 99b12127713c9..c86abe614f7df 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -741,11 +741,11 @@ def run_evaluation(self, on_epoch=False):
 
         # enable train mode again
         self.evaluation_loop.on_evaluation_model_train()
-
-        torch.set_grad_enabled(True)
-
+        
         # reset cached results
         self.logger_connector.reset()
+        
+        torch.set_grad_enabled(True)
 
         return eval_loop_results
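Patch 4 only tightens the signatures with `-> None` annotations. Patch 5 reorders the tail of `run_evaluation` so the cache reset happens before grad is re-enabled (the blank lines it adds carry trailing whitespace, which the "pep8" patch below removes). The invariant throughout is consume-then-clear: cached results may be wiped only after they have been turned into the returned results. A simplified, hypothetical sketch of that ordering (illustrative names, not the real `Trainer`):

    class LoggerConnectorSketch:
        """Stands in for the logger connector's cached-results store."""

        def __init__(self):
            self.cached_results = []

        def reset(self):
            self.cached_results = []


    def run_evaluation(connector: LoggerConnectorSketch):
        connector.cached_results.append({"val_acc": 0.9})   # filled by the eval loop
        eval_loop_results = list(connector.cached_results)  # consume first ...
        connector.reset()                                   # ... then clear
        return eval_loop_results


    assert run_evaluation(LoggerConnectorSketch()) == [{"val_acc": 0.9}]

Resetting before consumption, as the pre-patch placement effectively allowed, would return an empty result list.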
From ee8e6235fd45356aa2cfca1ce5fb004fe9941105 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 16 Apr 2021 19:15:59 +0200
Subject: [PATCH 6/9] chglog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 277fee3463e22..0a62759b0db92 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -279,6 +279,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))
 
+- Fixed metric objects passed directly to `self.log` not being reset correctly ([#7055](https://github.com/PyTorchLightning/pytorch-lightning/pull/7055))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed

From 128e12a61c88f50d7ec90225294ae750ff2161fc Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 16 Apr 2021 19:16:49 +0200
Subject: [PATCH 7/9] pep8

---
 pytorch_lightning/trainer/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c86abe614f7df..7206d6a1022da 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -741,10 +741,10 @@ def run_evaluation(self, on_epoch=False):
 
         # enable train mode again
         self.evaluation_loop.on_evaluation_model_train()
-        
+
         # reset cached results
         self.logger_connector.reset()
-        
+
         torch.set_grad_enabled(True)
 
         return eval_loop_results

From d52cdc608cba06b06b367b4fe05783a32a0f303e Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Fri, 16 Apr 2021 20:45:50 +0100
Subject: [PATCH 8/9] Add test

---
 .../trainer/logging_/test_logger_connector.py | 116 +++++++++++++++++-
 1 file changed, 115 insertions(+), 1 deletion(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index bddde7e77f5a8..5cc365ef90a98 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -22,10 +22,11 @@
 import pytest
 import torch
 from torch.utils.data import DataLoader
+from torchmetrics import Accuracy, AveragePrecision
 
+from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.metrics import Accuracy
 from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
 from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
@@ -590,3 +591,116 @@ def validation_step(self, batch, batch_idx):
 
     assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 1
     assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 3
+
+
+def test_metrics_reset(tmpdir):
+    """Tests that metrics are reset correctly after the end of the train/val/test epoch."""
+
+    class TestModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(32, 1)
+
+            for stage in ['train', 'val', 'test']:
+                acc = Accuracy()
+                acc.reset = mock.Mock(side_effect=acc.reset)
+                ap = AveragePrecision(num_classes=1, pos_label=1)
+                ap.reset = mock.Mock(side_effect=ap.reset)
+                self.add_module(f"acc_{stage}", acc)
+                self.add_module(f"ap_{stage}", ap)
+
+        def forward(self, x):
+            return self.layer(x)
+
+        def _step(self, stage, batch):
+            labels = (batch.detach().sum(1) > 0).float()  # Fake some targets
+            logits = self.forward(batch)
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1))
+            probs = torch.sigmoid(logits.detach())
+            self.log(f"loss/{stage}", loss)
+
+            acc = self._modules[f"acc_{stage}"]
+            ap = self._modules[f"ap_{stage}"]
+
+            labels_int = labels.to(torch.long)
+            acc(probs, labels_int)
+            ap(probs, labels_int)
+
+            # Metric.forward calls reset so reset the mocks here
+            acc.reset.reset_mock()
+            ap.reset.reset_mock()
+
+            self.log(f"{stage}/accuracy", acc)
+            self.log(f"{stage}/ap", ap)
+
+            return loss
+
+        def training_step(self, batch, batch_idx, *args, **kwargs):
+            return self._step('train', batch)
+
+        def validation_step(self, batch, batch_idx, *args, **kwargs):
+            return self._step('val', batch)
+
+        def test_step(self, batch, batch_idx, *args, **kwargs):
+            return self._step('test', batch)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64))
+
+        def val_dataloader(self):
+            return DataLoader(RandomDataset(32, 64))
+
+        def test_dataloader(self):
+            return DataLoader(RandomDataset(32, 64))
+
+        def _assert_epoch_end(self, stage):
+            acc = self._modules[f"acc_{stage}"]
+            ap = self._modules[f"ap_{stage}"]
+
+            acc.reset.assert_not_called()
+            ap.reset.assert_not_called()
+
+        def on_train_epoch_end(self, outputs):
+            self._assert_epoch_end('train')
+
+        def on_validation_epoch_end(self, outputs):
+            self._assert_epoch_end('val')
+
+        def on_test_epoch_end(self, outputs):
+            self._assert_epoch_end('test')
+
+    def _assert_called(model, stage):
+        acc = model._modules[f"acc_{stage}"]
+        ap = model._modules[f"ap_{stage}"]
+
+        acc.reset.assert_called()
+        acc.reset.reset_mock()
+
+        ap.reset.assert_called()
+        ap.reset.reset_mock()
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        max_epochs=1,
+        progress_bar_refresh_rate=0,
+    )
+
+    trainer.fit(model)
+    _assert_called(model, 'train')
+    _assert_called(model, 'val')
+
+    trainer.validate(model)
+    _assert_called(model, 'val')
+
+    trainer.test(model)
+    _assert_called(model, 'test')
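The new test spies on each metric's `reset` by wrapping it in a `mock.Mock` whose `side_effect` is the original bound method: the real reset still runs, but every call is recorded. The same pattern in isolation (a standalone illustration, not the Lightning test itself):

    from unittest import mock


    class Counter:
        def __init__(self):
            self.n = 0

        def reset(self):
            self.n = 0


    c = Counter()
    # The right-hand side binds the real method before the attribute shadows it.
    c.reset = mock.Mock(side_effect=c.reset)

    c.n = 5
    c.reset()
    assert c.n == 0               # the real reset still ran
    c.reset.assert_called_once()  # and the spy recorded the call
    c.reset.reset_mock()          # clear the record, as the test does after Metric.forward

The `reset_mock()` calls inside `_step` are needed because `Metric.forward` itself calls `reset` while computing the batch-level value, which would otherwise pollute the epoch-end assertions.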
From 783cfd37fafa8fe049b276788920acc0fada6ae4 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Fri, 16 Apr 2021 22:38:06 +0100
Subject: [PATCH 9/9] Improve test

---
 tests/trainer/logging_/test_logger_connector.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 5cc365ef90a98..118899e32276e 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -679,10 +679,10 @@ def _assert_called(model, stage):
         acc = model._modules[f"acc_{stage}"]
         ap = model._modules[f"ap_{stage}"]
 
-        acc.reset.assert_called()
+        acc.reset.assert_called_once()
         acc.reset.reset_mock()
 
-        ap.reset.assert_called()
+        ap.reset.assert_called_once()
         ap.reset.reset_mock()
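`assert_called_once()` is strictly stronger than `assert_called()`: it also fails if `reset` ran more than once per epoch, which is exactly the redundant double reset the earlier patches removed from the getters. A quick standalone illustration:

    from unittest import mock

    spy = mock.Mock()
    spy()
    spy()

    spy.assert_called()  # passes: at least one call happened
    try:
        spy.assert_called_once()
    except AssertionError:
        print("a second, redundant reset is caught by assert_called_once()")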