Closed
Labels
bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)
Description
🐛 Bug
Passing a metric (e.g. AveragePrecision) to self.log gives the following error:
RuntimeError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat. This usually means that this function requires a non-empty list of Tensors. Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
at this line:
preds = torch.cat(self.preds, dim=0)
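The message indicates that `self.preds` was an empty list when it was concatenated. The following is a minimal sketch of that failure mode in plain Python, not the torchmetrics implementation: `ToyListMetric` and `compute_after_reset` are hypothetical names, and the idea that the state is cleared before `compute` runs is an assumption about the cause, not something confirmed here.

```python
class ToyListMetric:
    """Hypothetical stand-in for a metric that keeps its state in a list."""

    def __init__(self):
        self.preds = []

    def update(self, batch_preds):
        # Each call appends one batch of predictions to the list state.
        self.preds.append(list(batch_preds))

    def reset(self):
        # Clearing the state leaves nothing to concatenate.
        self.preds = []

    def compute(self):
        if not self.preds:
            # Mimics torch.cat failing on an empty list (assumed analogy).
            raise RuntimeError("no tensor arguments: state list is empty")
        flat = [p for batch in self.preds for p in batch]
        return sum(flat) / len(flat)


def compute_after_reset():
    """Update, reset, then compute: the assumed order that triggers the error."""
    metric = ToyListMetric()
    metric.update([0.2, 0.8])
    metric.reset()
    try:
        return metric.compute()
    except RuntimeError as err:
        return f"RuntimeError: {err}"
```

If the real metric is computed after its state has been reset (or before any update), the same kind of empty-list concatenation would occur.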
Possibly relevant: rolling back to commits before #6540 produces a different error.
Please reproduce using the BoringModel
To Reproduce
Run the following on master (note: both PL master and torchmetrics master are required for the bug to appear):
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchmetrics as metrics
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 1)
        for stage in [RunningStage.TRAINING, RunningStage.VALIDATING]:
            self.add_module(f"acc_{stage}", metrics.Accuracy())
            self.add_module(
                f"ap_{stage}", metrics.AveragePrecision(num_classes=1, pos_label=1)
            )

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1))).squeeze()

    def _step(self, stage, batch):
        images, labels = batch
        labels = (labels > 5).float()  # Fake some targets
        logits = self.forward(images)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        probs = torch.sigmoid(logits.detach())
        self.log(f"Loss/{stage}", loss)
        labels_int = labels.to(torch.long)
        acc = self._modules[f"acc_{stage}"]
        ap = self._modules[f"ap_{stage}"]
        acc(probs, labels_int)
        ap(probs, labels_int)
        self.log(f"{stage}/accuracy", acc, prog_bar=True)
        self.log(f"{stage}/ap", ap, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx: int, *args, **kwargs) -> torch.Tensor:
        return self._step(RunningStage.TRAINING, batch)

    def validation_step(self, batch, batch_idx: int, *args, **kwargs) -> torch.Tensor:
        return self._step(RunningStage.VALIDATING, batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

print(pl.__version__)
print(metrics.__version__)
# Init our model
mnist_model = MNISTModel()
# Init DataLoader from MNIST Dataset
train_ds = MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)
val_ds = MNIST('./data', train=False, download=True, transform=transforms.ToTensor())
val_loader = DataLoader(val_ds, batch_size=32)
# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=2,
    progress_bar_refresh_rate=20,
    gpus=1
)
# Train the model ⚡
trainer.fit(mnist_model, train_loader, val_loader)

cc @ananthsub