🚀 Feature
PyTorch Lightning should reset metrics on every epoch.
Motivation
In the current docs, there's a big warning:
From v1.2 onward compute() will no longer automatically call reset(), and it is up to the user to reset metrics between epochs, except in the case where the metric is directly passed to LightningModule's self.log.
Obviously, this is a bit of a pain to do manually, especially because there doesn't seem to be any epoch_start hook that a LightningModule can define for this.
Pitch
PyTorch Lightning should automatically reset all metrics whenever it starts a new epoch. Implementation-wise, I think this is as simple as:

def reset(c):
    if isinstance(c, pl.metrics.Metric):
        c.reset()
    return c

module.apply(reset)

This would be limited to metrics that are correctly identified as children, but since that's already required to get the metric onto the right device, I think that's fine (see the small sketch after the quote below):
when properly defined inside a LightningModule, Lightning will automatically move the metrics to the same device as the data. Being properly defined means that the metric is correctly identified as a child module of the model (check the .children() attribute of the model). Therefore, metrics cannot be placed in native python list and dict, as they will not be correctly identified as child modules. Instead of list use ModuleList, and instead of dict use ModuleDict.
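To make that limitation concrete, here is a small illustrative sketch (the Demo module and the metrics in it are made up for this issue, not real code from anywhere): .apply() visits only registered child modules, so a metric stored in a plain Python dict would be skipped by the proposed reset.

import pytorch_lightning as pl
from torch import nn

class Demo(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Registered as child modules -> visited by .apply() (and moved to the right device)
        self.acc = pl.metrics.Accuracy()
        self.more = nn.ModuleDict({"acc2": pl.metrics.Accuracy()})
        # Plain Python dict -> not a child module, so .apply() will miss it
        self.hidden = {"acc3": pl.metrics.Accuracy()}

def reset(c):
    if isinstance(c, pl.metrics.Metric):
        c.reset()

m = Demo()
m.apply(reset)  # resets m.acc and m.more["acc2"], but not m.hidden["acc3"]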
Alternatives
We could implement some kind of epoch_start hook for LightningModule where the user could nicely implement this. Otherwise, I cannot figure out where to reset each epoch -- I have tried doing it at the end of the epoch (via validation_epoch_end), but that seems to give incorrect results on the first epoch.

This script prints out 400 for the first epoch rather than 100:
import pytorch_lightning as pl
import torch
from torch import nn


class Counter(pl.metrics.Metric):
    # Metric that simply sums everything passed to update().
    def __init__(self):
        super().__init__()
        self.add_state("test", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor):
        self.test += x

    def compute(self):
        return self.test


class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 1)
        self.metric = Counter()
        self.metric.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.linear.parameters())

    def training_step(self, batch, batch_idx):
        # accumulate the batch size into the metric
        self.metric.update(torch.tensor(float(len(batch))))
        return self.linear(batch).sum()

    def validation_step(self, batch, batch_idx):
        return batch_idx

    def validation_epoch_end(self, outputs):
        print("VALIDATING", self.metric.compute())
        self.metric.reset()


if __name__ == "__main__":
    m = Module()
    datasets = [torch.rand([5]) for __ in range(100)]
    train_loader = torch.utils.data.DataLoader(datasets, batch_size=8)
    val_loader = torch.utils.data.DataLoader(datasets, batch_size=1)
    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=5,
        accelerator="ddp_cpu",
        num_processes=4,
    )
    trainer.fit(m, train_loader, val_loader)

This prints out:
warnings.warn(*args, **kwargs)
Epoch 0: 14%|████████████████▉ | 4/29 [00:00<00:00, 64.38it/s, loss=-3.59, v_num=39VALIDATING tensor(400.) | 0/25 [00:00<?, ?it/s]
VALIDATING tensor(400.)
VALIDATING tensor(400.)
VALIDATING tensor(400.)
Epoch 1: 14%|████████████████▊ | 4/29 [00:00<00:00, 242.12it/s, loss=-3.79, v_num=39VALIDATING tensor(100.) | 0/25 [00:00<?, ?it/s]
VALIDATING tensor(100.)
VALIDATING tensor(100.)
VALIDATING tensor(100.)
Even if this works, I still think it's subpar, because this seems like something that could be handled automatically fairly easily.
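For reference, here is a possible interim workaround (not the built-in behaviour being requested; the ResetMetricsCallback name is made up, and it assumes the Callback on_train_epoch_start hook is available): perform the reset from a callback at the start of each training epoch instead of at the end of validation.

import pytorch_lightning as pl

class ResetMetricsCallback(pl.Callback):
    # Reset every pl.metrics.Metric child of the LightningModule at the start of each training epoch.
    def on_train_epoch_start(self, trainer, pl_module):
        def reset(m):
            if isinstance(m, pl.metrics.Metric):
                m.reset()
        # .apply() visits the module itself and all registered child modules
        pl_module.apply(reset)

# usage: trainer = pl.Trainer(callbacks=[ResetMetricsCallback()], ...)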