
Error in "on_advance_start" when data-loader's sampler is a NumPy array  #13320

@LucaButera

Description


🐛 Bug

When a NumPy array is passed as the sampler of a PyTorch DataLoader, the following check in "on_advance_start"

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None)
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):

raises this exception:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Here getattr returns the array itself, so the "and" chain calls bool() on it, and NumPy refuses to reduce an array with more than one element to a single truth value.
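The failure can be reproduced without Lightning or PyTorch. This sketch assumes, as PyTorch's DataLoader does, that the sampler argument is stored unchanged as the loader's .sampler attribute. The SimpleNamespace object is a hypothetical stand-in for the loader, and implicit_check is an illustrative copy of the quoted condition, not Lightning's actual code.

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical stand-in for a DataLoader whose sampler is a plain NumPy array.
dataloader = SimpleNamespace(sampler=np.array([1, 2, 3, 4]))


def implicit_check(loader):
    # Same shape as the condition in the issue: the sampler is used as a truth value.
    return (
        loader is not None
        and getattr(loader, "sampler", None)
        and callable(getattr(loader.sampler, "set_epoch", None))
    )


error_message = None
try:
    implicit_check(dataloader)
except ValueError as err:
    # "and" calls bool() on the array, which NumPy rejects when size > 1.
    error_message = str(err)

print(error_message)
```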

To Reproduce

import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
import numpy as np

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    # The NumPy array used as a sampler is what triggers the error.
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2, sampler=np.array([1, 2, 3, 4]))
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

run()

Expected behavior

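Training and validation run without raising the error. PyTorch's DataLoader accepts any iterable of indices as a sampler, so the check should not assume that the sampler has an unambiguous truth value.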

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.3
  • Packages:
    • numpy: 1.21.6
    • pyTorch_debug: False
    • pyTorch_version: 1.11.0+cu113
    • pytorch-lightning: 1.6.4
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.13
    • version: #1 SMP Sun Apr 24 10:03:06 PDT 2022

Additional context

If the only purpose of the check is to verify that the sampler exists (i.e., is not None), a simple fix is to make that comparison explicit:

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None) is not None
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):
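A minimal sketch of the explicit "is not None" check, written as a standalone helper so it can be exercised outside Lightning. The names should_set_epoch and EpochAwareSampler, and the SimpleNamespace loaders, are hypothetical and used only for illustration. The condition itself is the one proposed in this issue.

```python
from types import SimpleNamespace

import numpy as np


def should_set_epoch(dataloader):
    """Return True only if the loader has a sampler exposing a callable set_epoch."""
    return (
        dataloader is not None
        # Identity test: never asks the sampler for its truth value.
        and getattr(dataloader, "sampler", None) is not None
        and callable(getattr(dataloader.sampler, "set_epoch", None))
    )


class EpochAwareSampler:
    """Hypothetical sampler with a set_epoch method, like DistributedSampler."""

    def __init__(self):
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch


array_loader = SimpleNamespace(sampler=np.array([1, 2, 3, 4]))
epoch_loader = SimpleNamespace(sampler=EpochAwareSampler())

print(should_set_epoch(array_loader))  # False: arrays have no set_epoch
print(should_set_epoch(epoch_loader))  # True
print(should_set_epoch(None))          # False
```

Because "is not None" never calls bool() on the sampler, the same helper also handles an empty array or a plain list of indices without raising.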

cc @Borda @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj


Labels: bug (Something isn't working), data handling (Generic data-related topic), good first issue (Good for newcomers)
