Automatically reset metrics on every epoch #6201

@alanhdu

🚀 Feature

PyTorch Lightning should reset metrics on every epoch.

Motivation

In the current docs, there's a big warning:

From v1.2 onward compute() will no longer automatically call reset(), and it is up to the user to reset metrics between epochs, except in the case where the metric is directly passed to LightningModule's self.log.

Obviously, this is a bit of a pain to have to do manually, especially because there doesn't seem to be any hook for epoch_start that a LightningModule can define here.

Pitch

PyTorch Lightning should automatically reset all metrics whenever it starts a new epoch. In terms of implementation, I think this is as simple as:

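import pytorch_lightning as pl

def reset(c):
    # Module.apply calls this on the module and on every registered submodule.
    if isinstance(c, pl.metrics.Metric):
        c.reset()

# `module` is the LightningModule being trained.
module.apply(reset)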

This would be limited to metrics that are correctly identified as children, but since that's already required to get the metric on the right device, I think that's fine:

When properly defined inside a LightningModule, Lightning will automatically move the metrics to the same device as the data. Being properly defined means that the metric is correctly identified as a child module of the model (check .children() attribute of the model). Therefore, metrics cannot be placed in native Python list and dict, as they will not be correctly identified as child modules. Instead of list use ModuleList and instead of dict use ModuleDict.
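To illustrate the limitation, here is a minimal sketch using plain torch. ToyMetric and Model are hypothetical stand-ins (not Lightning classes), and the reset function mirrors the one proposed above. It is only meant to show that Module.apply reaches metrics held in a ModuleList but not ones held in a native Python list:

```python
import torch
from torch import nn


class ToyMetric(nn.Module):
    """Hypothetical stand-in for a metric: accumulates a running total."""

    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.tensor(0.0))

    def update(self, x: float) -> None:
        self.total += x

    def reset(self) -> None:
        self.total.zero_()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered as child modules, so apply() will visit them.
        self.registered = nn.ModuleList([ToyMetric(), ToyMetric()])
        # Plain list: these are NOT child modules and apply() skips them.
        self.unregistered = [ToyMetric()]


def reset_metric(module: nn.Module) -> None:
    if isinstance(module, ToyMetric):
        module.reset()


model = Model()
for metric in list(model.registered) + model.unregistered:
    metric.update(3.0)

# Visits the model and every registered submodule, recursively.
model.apply(reset_metric)

registered_totals = [m.total.item() for m in model.registered]
unregistered_totals = [m.total.item() for m in model.unregistered]
print(registered_totals)    # [0.0, 0.0]
print(unregistered_totals)  # [3.0]
```

The metric in the plain list keeps its accumulated value, which is the same restriction the docs already impose for device placement.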

Alternatives

We could implement some kind of epoch_start hook for LightningModule where the user could nicely implement this. Otherwise, I cannot figure out where to reset each epoch -- I have tried doing it at the end of the epoch (via validation_epoch_end), but that seems to give incorrect results on the first epoch:

This script prints out 400 for the first epoch rather than 100:
import pytorch_lightning as pl
import torch
from torch import nn


class Counter(pl.metrics.Metric):
    def __init__(self):
        super().__init__()
        self.add_state("test", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: int):
        self.test += x

    def compute(self):
        return self.test


class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 1)
        self.metric = Counter()
        self.metric.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.linear.parameters())

    def training_step(self, batch, batch_idx):
        self.metric.update(torch.tensor(float(len(batch))))
        return self.linear(batch).sum()

    def validation_step(self, batch, batch_idx):
        return batch_idx

    def validation_epoch_end(self, outputs):
        print("VALIDATING", self.metric.compute())
        self.metric.reset()


if __name__ == "__main__":
    m = Module()

    datasets = [torch.rand([5]) for __ in range(100)]
    train_loader = torch.utils.data.DataLoader(datasets, batch_size=8)
    val_loader = torch.utils.data.DataLoader(datasets, batch_size=1)

    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=5,
        accelerator="ddp_cpu",
        num_processes=4,
    )
    trainer.fit(m, train_loader, val_loader)

prints out:

  warnings.warn(*args, **kwargs)
Epoch 0:  14%|████████████████▉ | 4/29 [00:00<00:00, 64.38it/s, loss=-3.59, v_num=39VALIDATING tensor(400.) | 0/25 [00:00<?, ?it/s]
VALIDATING tensor(400.)
VALIDATING tensor(400.)
VALIDATING tensor(400.)
Epoch 1:  14%|████████████████▊ | 4/29 [00:00<00:00, 242.12it/s, loss=-3.79, v_num=39VALIDATING tensor(100.) | 0/25 [00:00<?, ?it/s]
VALIDATING tensor(100.)
VALIDATING tensor(100.)
VALIDATING tensor(100.)
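For reference, the expected value of 100 can be checked by hand. This is a small sketch that assumes the 100 samples are split evenly across the 4 processes, which the 4/29 and 0/25 progress counters in the output above seem to confirm (4 training batches plus 25 validation batches per process):

```python
# Assumption: each of the 4 processes sees an equal share of the 100 samples.
num_samples = 100
num_processes = 4
train_batch_size = 8

per_process = num_samples // num_processes  # 25 samples per process

# Batch lengths seen by one process: full batches plus one remainder batch.
batch_lengths = [
    min(train_batch_size, per_process - start)
    for start in range(0, per_process, train_batch_size)
]

# Each Counter accumulates len(batch); the states are then summed across
# processes because of dist_reduce_fx="sum".
per_process_total = sum(batch_lengths)
expected = per_process_total * num_processes

print(batch_lengths)  # [8, 8, 8, 1]
print(expected)       # 100
```

The 400 reported in the first epoch is exactly 4 times this value, i.e. a factor equal to the number of processes.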

Even if this works, I still think it's subpar, because this seems like something that could be handled automatically fairly easily.
