Metrics passed to self.log giving RuntimeError #7052

@ethanwharris

Description

🐛 Bug

Passing a metric object (e.g. AveragePrecision) to self.log raises the following error:

RuntimeError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat.  This usually means that this function requires a non-empty list of Tensors.  Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

at this line:

preds = torch.cat(self.preds, dim=0)
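
This is the error torch.cat raises whenever it is handed an empty list of tensors, which suggests the metric's preds state is still empty when compute() runs, i.e. compute() is reached before any update(). A minimal sketch outside Lightning that reproduces the same message:

import torch

torch.cat([], dim=0)  # raises the same RuntimeError: no tensor arguments to aten::_cat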

Possibly relevant: rolling back the commits gives a different error prior to #6540.

Please reproduce using the BoringModel

To Reproduce

Running this on master (note: both PL and torchmetrics must be on master for the bug to appear):

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchmetrics as metrics


class MNISTModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 1)

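        # Register one metric instance per stage as a submodule, so that
        # training and validation metric state never mixes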
        for stage in [RunningStage.TRAINING, RunningStage.VALIDATING]:
            self.add_module(f"acc_{stage}", metrics.Accuracy())
            self.add_module(
                f"ap_{stage}", metrics.AveragePrecision(num_classes=1, pos_label=1)
            )

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1))).squeeze()

    def _step(self, stage, batch):
        images, labels = batch
        labels = (labels > 5).float()  # Fake some targets
        logits = self.forward(images)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        probs = torch.sigmoid(logits.detach())
        self.log(f"Loss/{stage}", loss)

        labels_int = labels.to(torch.long)
        acc = self._modules[f"acc_{stage}"]
        ap = self._modules[f"ap_{stage}"]
        acc(probs, labels_int)
        ap(probs, labels_int)
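        # Passing the Metric objects themselves to self.log (rather than
        # computed tensor values) is what triggers the RuntimeError above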
        self.log(f"{stage}/accuracy", acc, prog_bar=True)
        self.log(f"{stage}/ap", ap, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx: int, *args, **kwargs) -> torch.Tensor:
        return self._step(RunningStage.TRAINING, batch)

    def validation_step(self, batch, batch_idx: int, *args, **kwargs) -> torch.Tensor:
        return self._step(RunningStage.VALIDATING, batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


print(pl.__version__)
print(metrics.__version__)

# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

val_ds = MNIST('./data', train=False, download=True, transform=transforms.ToTensor())
val_loader = DataLoader(val_ds, batch_size=32)

# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=2,
    progress_bar_refresh_rate=20,
    gpus=1
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader, val_loader)
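
As a side note, logging the computed tensor instead of the Metric object sidesteps the error (a workaround sketch only, not a fix for the underlying bug, and Lightning's epoch-level metric aggregation is then lost):

value = acc(probs, labels_int)  # Metric.__call__ returns the batch-level value
self.log(f"{stage}/accuracy", value, prog_bar=True)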

cc @ananthsub


Labels

bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task)
