Skip to content
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the order to call for world ranks & the `root_device` property in `TPUSpawnPlugin` ([#7074](https://github.com/PyTorchLightning/pytorch-lightning/pull/7074))


- Fixed metric objects passed directly to `self.log` not being reset correctly ([#7055](https://github.com/PyTorchLightning/pytorch-lightning/pull/7055))


## [1.2.7] - 2021-04-06

### Fixed
Expand Down
21 changes: 10 additions & 11 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,12 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
self[k].reset()
else:
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# reset metric anyway so state does not accumulate
# NOTE: we must compute before resetting just in case the computed value is needed
# later (i.e. if the step metric gets visited first, and then the epoch metric)
# compute for reuse later
self[k].compute()
self[k].reset()

return result

Expand All @@ -319,16 +315,12 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
self[k].reset()
else:
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# reset metric anyway so state does not accumulate
# NOTE: we must compute before resetting just in case the computed value is needed
# later (i.e. if the step metric gets visited first, and then the epoch metric)
# compute for reuse later
self[k].compute()
self[k].reset()

return result

Expand All @@ -348,7 +340,6 @@ def get_forked_metrics(self, add_dataloader_idx=False):
if options['forked']:
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
self[k].reset()
else:
result[dl_key] = self[k]

Expand Down Expand Up @@ -587,6 +578,14 @@ def get_non_metrics_keys(self):
"""
return [k for k, v in self.items() if not isinstance(v, Metric)]

def reset(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def reset(self) -> None:
def reset_metrics(self) -> None:

"""
Call at the end of epoch to reset all metric objects
"""
for k, value in self.items():
if isinstance(value, Metric):
value.reset()


def choose_last(x):
if isinstance(x, (torch.Tensor, list)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ def auto_reduce_results_on_epoch_end(self) -> None:

self.has_reduced = True

def reset(self) -> None:
"""
Call at the end of epoch to reset Result objects
"""
for dl_idx in range(self.num_dataloaders):
epoch_metrics = self._internals[dl_idx] if not self.has_reduced else self._internals_reduced[dl_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In current usage, self.has_reduced should always be True, right?
Would it be better to add assert self.has_reduced == True there and use self._internals_reduced directly?

if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
for opt_idx in list(epoch_metrics):
epoch_metrics[opt_idx].reset()
Comment on lines +242 to +244
Copy link
Contributor

@ananthsub ananthsub Apr 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this check? On the surface, "inside the batch train loop" reads like we're not at the epoch end.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not completely sure what is going on myself, but apparently self._internal_type is equal to ResultStoreType.INSIDE_BATCH_TRAIN_LOOP even when we are at epoch end for all training metrics. self._internal_type is equal to ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP for validation metrics (at least when this reset function is called).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. I think BATCH should be removed.

But I'd also like to clear all this entirely at some point.

else:
epoch_metrics.reset()

def __getitem__(self, key: str) -> Any:
return self._internals.get(key, None)

Expand Down Expand Up @@ -262,6 +274,7 @@ def __init__(self, trainer: 'pl.Trainer') -> None:
_should_warn = trainer.accelerator_connector.is_distributed
_should_warn &= not trainer.training_type_plugin.rpc_enabled
self._should_warn = _should_warn
self._internals = {}

self.reset()

Expand Down Expand Up @@ -442,7 +455,9 @@ def get_epoch_log_metrics(self) -> Dict:
def get_forked_metrics(self) -> Dict:
return self.run_epoch_by_func_name("get_forked_metrics")

def reset(self):
def reset(self) -> None:
for k, value in self._internals.items():
value.reset()
self._internals = {}
self._dataloader_idx: Optional[int] = None
self._split_idx: Optional[int] = None
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,6 @@ def run_evaluation(self, on_epoch=False):
)
self.validating = True

# reset cached results
self.logger_connector.reset()

# prepare dataloaders
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders()

Expand Down Expand Up @@ -759,6 +756,9 @@ def run_evaluation(self, on_epoch=False):
# enable train mode again
self.evaluation_loop.on_evaluation_model_train()

# reset cached results
self.logger_connector.reset()

torch.set_grad_enabled(True)

return eval_loop_results
Expand Down
2 changes: 2 additions & 0 deletions tests/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _ddp_test_fn(rank, worldsize):
assert batch_expected[k] == batch_log[k]

epoch_log = result.get_epoch_log_metrics()
result.reset()

# assert metric state reset to default values
assert metric_a.x == metric_a._defaults['x']
Expand Down Expand Up @@ -127,6 +128,7 @@ def test_result_metric_integration():
assert batch_expected[k] == batch_log[k]

epoch_log = result.get_epoch_log_metrics()
result.reset()

# assert metric state reset to default values
assert metric_a.x == metric_a._defaults['x']
Expand Down
116 changes: 115 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, AveragePrecision

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
Expand Down Expand Up @@ -590,3 +591,116 @@ def validation_step(self, batch, batch_idx):

assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 1
assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 3


def test_metrics_reset(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch.

    Every metric's ``reset`` method is wrapped in a ``mock.Mock`` (with the real
    ``reset`` as side effect) so the test can check two things:

    * ``reset`` has not been called by the time the ``on_*_epoch_end`` hook runs;
    * ``reset`` has been called exactly once after each fit/validate/test run.
    """

    class TestModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

            # One Accuracy and one AveragePrecision metric per stage, each with a
            # spy on ``reset`` that still delegates to the real implementation.
            for stage in ['train', 'val', 'test']:
                acc = Accuracy()
                acc.reset = mock.Mock(side_effect=acc.reset)
                ap = AveragePrecision(num_classes=1, pos_label=1)
                ap.reset = mock.Mock(side_effect=ap.reset)
                self.add_module(f"acc_{stage}", acc)
                self.add_module(f"ap_{stage}", ap)

        def forward(self, x):
            return self.layer(x)

        def _step(self, stage, batch):
            """Shared step: compute the loss, update the stage metrics and log them."""
            labels = (batch.detach().sum(1) > 0).float()  # Fake some targets
            logits = self.forward(batch)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1))
            probs = torch.sigmoid(logits.detach())
            self.log(f"loss/{stage}", loss)

            acc = self._modules[f"acc_{stage}"]
            ap = self._modules[f"ap_{stage}"]

            labels_int = labels.to(torch.long)
            acc(probs, labels_int)
            ap(probs, labels_int)

            # Metric.forward calls reset so reset the mocks here
            acc.reset.reset_mock()
            ap.reset.reset_mock()

            # Log the metric objects themselves (not computed values).
            self.log(f"{stage}/accuracy", acc)
            self.log(f"{stage}/ap", ap)

            return loss

        def training_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('train', batch)

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('val', batch)

        def test_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('test', batch)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def _assert_epoch_end(self, stage):
            """Check that neither stage metric has been reset since the last step."""
            acc = self._modules[f"acc_{stage}"]
            ap = self._modules[f"ap_{stage}"]

            # NOTE: the method name must be spelled exactly ``assert_not_called``;
            # a misspelled name on a Mock is just an auto-created child mock and
            # the call would silently succeed without checking anything.
            acc.reset.assert_not_called()
            ap.reset.assert_not_called()

        def on_train_epoch_end(self, outputs):
            self._assert_epoch_end('train')

        def on_validation_epoch_end(self, outputs):
            self._assert_epoch_end('val')

        def on_test_epoch_end(self, outputs):
            self._assert_epoch_end('test')

    def _assert_called(model, stage):
        """Check each stage metric was reset exactly once, then clear the spies."""
        acc = model._modules[f"acc_{stage}"]
        ap = model._modules[f"ap_{stage}"]

        acc.reset.assert_called_once()
        acc.reset.reset_mock()

        ap.reset.assert_called_once()
        ap.reset.reset_mock()

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        progress_bar_refresh_rate=0,
    )

    trainer.fit(model)
    _assert_called(model, 'train')
    _assert_called(model, 'val')

    trainer.validate(model)
    _assert_called(model, 'val')

    trainer.test(model)
    _assert_called(model, 'test')