diff --git a/CHANGELOG.md b/CHANGELOG.md index 763bd2248ef2b..f25f6e783bb58 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the order to call for world ranks & the `root_device` property in `TPUSpawnPlugin` ([#7074](https://github.com/PyTorchLightning/pytorch-lightning/pull/7074)) +- Fixed metric objects passed directly to `self.log` not being reset correctly ([#7055](https://github.com/PyTorchLightning/pytorch-lightning/pull/7055)) + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 7a193662b597b..f2cdd31ab739f 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -287,16 +287,12 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: if options['logger'] and options['on_epoch']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() - self[k].reset() else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # reset metric anyway so state does not accumulate - # NOTE: we must compute before reseting just in case the computed value is needed - # later (i.e. if the step metric gets visited first, and then the epoch metric) + # compute for reuse later self[k].compute() - self[k].reset() return result @@ -319,16 +315,12 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False): if options['prog_bar'] and options['on_epoch']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() - self[k].reset() else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # reset metric anyway so state does not accumulate - # NOTE: we must compute before reseting just in case the computed value is needed - # later (i.e. 
if the step metric gets visited first, and then the epoch metric) + # compute for reuse later self[k].compute() - self[k].reset() return result @@ -348,7 +340,6 @@ def get_forked_metrics(self, add_dataloader_idx=False): if options['forked']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() - self[k].reset() else: result[dl_key] = self[k] @@ -587,6 +578,14 @@ def get_non_metrics_keys(self): """ return [k for k, v in self.items() if not isinstance(v, Metric)] + def reset(self) -> None: + """ + Reset every Metric stored in this result; call at the end of an epoch + """ + for value in self.values(): + if isinstance(value, Metric): + value.reset() + def choose_last(x): if isinstance(x, (torch.Tensor, list)): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 594da76192aed..61f66dbed9dfa 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -233,6 +233,18 @@ def auto_reduce_results_on_epoch_end(self) -> None: self.has_reduced = True + def reset(self) -> None: + """ + Reset the stored Result objects of every dataloader; call at the end of an epoch + """ + for dl_idx in range(self.num_dataloaders): + epoch_metrics = self._internals[dl_idx] if not self.has_reduced else self._internals_reduced[dl_idx] + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + for opt_idx in list(epoch_metrics): + epoch_metrics[opt_idx].reset() + else: + epoch_metrics.reset() + def __getitem__(self, key: str) -> Any: return self._internals.get(key, None) @@ -262,6 +274,7 @@ def __init__(self, trainer: 'pl.Trainer') -> None: _should_warn = trainer.accelerator_connector.is_distributed _should_warn &= not trainer.training_type_plugin.rpc_enabled self._should_warn = _should_warn + self._internals = {} self.reset() @@ -442,7 +455,9 @@ def get_epoch_log_metrics(self) 
-> Dict: def get_forked_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_forked_metrics") - def reset(self): + def reset(self) -> None: + for k, value in self._internals.items(): + value.reset() self._internals = {} self._dataloader_idx: Optional[int] = None self._split_idx: Optional[int] = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 62c8f530dca06..772f2dc4aa6f5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -667,9 +667,6 @@ def run_evaluation(self, on_epoch=False): ) self.validating = True - # reset cached results - self.logger_connector.reset() - # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() @@ -759,6 +756,9 @@ def run_evaluation(self, on_epoch=False): # enable train mode again self.evaluation_loop.on_evaluation_model_train() + # reset cached results + self.logger_connector.reset() + torch.set_grad_enabled(True) return eval_loop_results diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..725655c54136d 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -76,6 +76,7 @@ def _ddp_test_fn(rank, worldsize): assert batch_expected[k] == batch_log[k] epoch_log = result.get_epoch_log_metrics() + result.reset() # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'] @@ -127,6 +128,7 @@ def test_result_metric_integration(): assert batch_expected[k] == batch_log[k] epoch_log = result.get_epoch_log_metrics() + result.reset() # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'] diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index bddde7e77f5a8..118899e32276e 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ 
b/tests/trainer/logging_/test_logger_connector.py @@ -22,10 +22,11 @@ import pytest import torch from torch.utils.data import DataLoader +from torchmetrics import Accuracy, AveragePrecision +from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.step_result import Result -from pytorch_lightning.metrics import Accuracy from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder @@ -590,3 +591,116 @@ def validation_step(self, batch, batch_idx): assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 1 assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 3 + + +def test_metrics_reset(tmpdir): + """Tests that metrics are reset correctly after the end of the train/val/test epoch.""" + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 1) + + for stage in ['train', 'val', 'test']: + acc = Accuracy() + acc.reset = mock.Mock(side_effect=acc.reset) + ap = AveragePrecision(num_classes=1, pos_label=1) + ap.reset = mock.Mock(side_effect=ap.reset) + self.add_module(f"acc_{stage}", acc) + self.add_module(f"ap_{stage}", ap) + + def forward(self, x): + return self.layer(x) + + def _step(self, stage, batch): + labels = (batch.detach().sum(1) > 0).float() # Fake some targets + logits = self.forward(batch) + loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) + probs = torch.sigmoid(logits.detach()) + self.log(f"loss/{stage}", loss) + + acc = self._modules[f"acc_{stage}"] + ap = self._modules[f"ap_{stage}"] + + labels_int = labels.to(torch.long) + acc(probs, labels_int) + ap(probs, labels_int) + + # Metric.forward calls reset so reset the mocks here + acc.reset.reset_mock() + ap.reset.reset_mock() + 
+ self.log(f"{stage}/accuracy", acc) + self.log(f"{stage}/ap", ap) + + return loss + + def training_step(self, batch, batch_idx, *args, **kwargs): + return self._step('train', batch) + + def validation_step(self, batch, batch_idx, *args, **kwargs): + return self._step('val', batch) + + def test_step(self, batch, batch_idx, *args, **kwargs): + return self._step('test', batch) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def _assert_epoch_end(self, stage): + acc = self._modules[f"acc_{stage}"] + ap = self._modules[f"ap_{stage}"] + + acc.reset.assert_not_called() + ap.reset.assert_not_called() + + def on_train_epoch_end(self, outputs): + self._assert_epoch_end('train') + + def on_validation_epoch_end(self, outputs): + self._assert_epoch_end('val') + + def on_test_epoch_end(self, outputs): + self._assert_epoch_end('test') + + def _assert_called(model, stage): + acc = model._modules[f"acc_{stage}"] + ap = model._modules[f"ap_{stage}"] + + acc.reset.assert_called_once() + acc.reset.reset_mock() + + ap.reset.assert_called_once() + ap.reset.reset_mock() + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=1, + progress_bar_refresh_rate=0, + ) + + trainer.fit(model) + _assert_called(model, 'train') + _assert_called(model, 'val') + + trainer.validate(model) + _assert_called(model, 'val') + + trainer.test(model) + _assert_called(model, 'test')