Closed
Labels: bug, data handling, good first issue
Description
🐛 Bug
When a NumPy array is used as the sampler of a PyTorch DataLoader, the check

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None)
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):

in "on_advance_start" raises the following exception:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
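The failure can be reproduced without Lightning or PyTorch, since it comes from evaluating a multi-element NumPy array in a boolean context, which is what the bare `getattr(dataloader, "sampler", None)` term does inside the `and` chain. The following is a minimal sketch; the helper name `truthiness_error` is hypothetical and exists only for illustration.

```python
import numpy as np


def truthiness_error(obj):
    """Return the message raised by bool(obj), or None if bool(obj) succeeds.

    Hypothetical helper for illustration; not part of any library.
    """
    try:
        bool(obj)
    except ValueError as err:
        return str(err)
    return None


# Same array as the sampler in the reproduction script.
sampler = np.array([1, 2, 3, 4])

# `x and y` calls bool(x) first, so a multi-element array fails here.
message = truthiness_error(sampler)
print(message)

# An identity comparison never calls bool() on the array itself.
print(sampler is not None)
```

The `is not None` comparison only compares object identity, so it works for any sampler type, including ones that do not define an unambiguous truth value.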
To Reproduce
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    # RandomDataset takes (size, num_samples); 64 samples of 32 features each.
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    # A NumPy array passed as the sampler triggers the error.
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2, sampler=np.array([1, 2, 3, 4]))
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


run()
Expected behavior
The error is not raised.
Environment
- CUDA:
  - GPU:
    - Tesla T4
  - available: True
  - version: 11.3
- Packages:
  - numpy: 1.21.6
  - pyTorch_debug: False
  - pyTorch_version: 1.11.0+cu113
  - pytorch-lightning: 1.6.4
  - tqdm: 4.64.0
- System:
  - OS: Linux
  - architecture:
    - 64bit
  - processor: x86_64
  - python: 3.7.13
  - version: #1 SMP Sun Apr 24 10:03:06 PDT 2022
Additional context
If the only thing to check is that the sampler exists and is not None, an easy solution is to change the code that raises the error to

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None) is not None
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):
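The proposed condition can be exercised in isolation, as in the sketch below. It wraps the check in a hypothetical helper, `has_set_epoch`, and uses two hypothetical stand-in classes (`FakeLoader`, `EpochAwareSampler`) in place of a real DataLoader and sampler; none of these names come from PyTorch Lightning.

```python
import numpy as np


class EpochAwareSampler:
    """Hypothetical stand-in for a sampler that exposes set_epoch."""

    def __init__(self):
        self.epoch = None

    def set_epoch(self, epoch):
        self.epoch = epoch


class FakeLoader:
    """Hypothetical minimal object with a `sampler` attribute."""

    def __init__(self, sampler):
        self.sampler = sampler


def has_set_epoch(dataloader):
    """Proposed check: compare the sampler to None instead of testing its truth value."""
    sampler = getattr(dataloader, "sampler", None)
    return (
        dataloader is not None
        and sampler is not None
        and callable(getattr(sampler, "set_epoch", None))
    )


# A NumPy array sampler no longer raises; it simply has no set_epoch method.
print(has_set_epoch(FakeLoader(np.array([1, 2, 3, 4]))))
# A sampler with a callable set_epoch is still detected.
print(has_set_epoch(FakeLoader(EpochAwareSampler())))
# A missing dataloader is handled as before.
print(has_set_epoch(None))
```

Because every term in the chain is now a plain boolean, the result does not depend on how the sampler type implements `__bool__` or `__len__`.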
cc @Borda @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj