Prediction with custom batch sampler #11807

@stecklin

🐛 Bug

When using a custom batch sampler for prediction, the following code gets executed:

https://github.com/PyTorchLightning/pytorch-lightning/blob/ab1c2ff23fd27ab6e5647e390313b569e32b61c5/pytorch_lightning/trainer/data_loading.py#L221-L226

This assumes the custom batch sampler has the same constructor interface as BatchSampler. However, according to the docs, a DataLoader accepts any Sampler as its batch_sampler argument: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
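To illustrate the mismatch without depending on Lightning, here is a minimal pure-Python sketch. It assumes the linked lines re-create the batch sampler by calling its class with BatchSampler-style constructor arguments (`sampler`, `batch_size`, `drop_last`). `GroupBatchSampler` and `rebuild_like_batch_sampler` are hypothetical names used only for this illustration, and `batch_size=1` is a placeholder value.

```python
from collections import defaultdict


class GroupBatchSampler:
    """Hypothetical custom batch sampler whose constructor only takes a dataset."""

    def __init__(self, dataset):
        self.groups = defaultdict(list)
        for i, sample in enumerate(dataset):
            self.groups[sample["group_id"]].append(i)

    def __iter__(self):
        # One variable-size batch of indices per group
        yield from self.groups.values()

    def __len__(self):
        return len(self.groups)


def rebuild_like_batch_sampler(batch_sampler, sampler):
    """Assumed paraphrase of the linked lines, not the actual Lightning code:
    re-create the batch sampler with BatchSampler-style arguments."""
    return type(batch_sampler)(sampler, batch_size=1, drop_last=False)


dataset = [{"group_id": "a"}, {"group_id": "b"}, {"group_id": "a"}]
batch_sampler = GroupBatchSampler(dataset)
batches = list(batch_sampler)

try:
    rebuild_like_batch_sampler(batch_sampler, sampler=range(len(dataset)))
    error = None
except TypeError as exc:
    # The custom constructor does not accept batch_size / drop_last
    error = exc
```

Iterating the sampler directly works, but rebuilding it with BatchSampler-style arguments fails because the custom constructor has a different signature.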

To Reproduce

This is the custom sampler I was using:

from collections import defaultdict

from torch.utils.data import Sampler


class MySampler(Sampler):
    def __init__(self, dataset):
        """
        Args:
            dataset (Dataset): PyTorch dataset.
        """
        self.dataset = dataset
        self.group_ids_to_sample_indices = self.get_group_ids_to_sample_indices()

    def get_group_ids_to_sample_indices(self):
        """ Group samples by id

        Returns:
            dict: group_id -> indices of samples with this id
        """
        group_ids_to_sample_indices = defaultdict(list)
        for i, sample in enumerate(self.dataset.samples):
            group_ids_to_sample_indices[sample["group_id"]].append(i)
        return group_ids_to_sample_indices

    def __iter__(self):
        group_ids = list(self.group_ids_to_sample_indices.keys())
        # Yield a batch of variable size with all samples for one group
        for group_id in group_ids:
            yield self.group_ids_to_sample_indices[group_id]

    def __len__(self):
        return len(self.group_ids_to_sample_indices)

Expected behavior

Prediction runs without crashing. For training, validation and testing it already works.

Environment

  • PyTorch Lightning Version: 1.5.9
  • PyTorch Version: 1.10.1

Additional context

I am using the following workaround for now, but it feels very hacky:

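class MySampler(Sampler):
    def __init__(self, sampler=None, batch_size=None, drop_last=False, dataset=None):
        # Accept BatchSampler-style arguments so the sampler can be re-created
        # during prediction; batch_size and drop_last are ignored.
        if sampler is not None:
            self.dataset = sampler.data_source
        else:
            self.dataset = dataset
        self.group_ids_to_sample_indices = self.get_group_ids_to_sample_indices()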
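As a rough check of the idea behind this workaround, the sketch below builds the sampler both ways (directly from a dataset, and from BatchSampler-style arguments) and compares the resulting batches. It is pure Python so it runs without torch. `ToyDataset`, `StubSampler`, and `GroupSampler` are hypothetical stand-ins, and it assumes the wrapped sampler exposes the dataset as `data_source`, as the workaround does.

```python
from collections import defaultdict


class ToyDataset:
    """Hypothetical dataset exposing `samples`, as the sampler expects."""

    def __init__(self, group_ids):
        self.samples = [{"group_id": g} for g in group_ids]


class StubSampler:
    """Hypothetical stand-in for an index sampler that exposes `data_source`."""

    def __init__(self, data_source):
        self.data_source = data_source


class GroupSampler:
    def __init__(self, sampler=None, batch_size=None, drop_last=False, dataset=None):
        # batch_size and drop_last are accepted only for signature compatibility
        self.dataset = sampler.data_source if sampler is not None else dataset
        if self.dataset is None:
            raise ValueError("pass either `sampler` or `dataset`")
        self.groups = defaultdict(list)
        for i, sample in enumerate(self.dataset.samples):
            self.groups[sample["group_id"]].append(i)

    def __iter__(self):
        # One variable-size batch of indices per group
        yield from self.groups.values()

    def __len__(self):
        return len(self.groups)


dataset = ToyDataset(["x", "y", "x", "y", "z"])
direct = GroupSampler(dataset=dataset)
rebuilt = GroupSampler(StubSampler(dataset), batch_size=1, drop_last=False)
same_batches = list(direct) == list(rebuilt)
```

Both construction paths yield the same group batches, which is why the workaround survives re-instantiation, at the cost of carrying constructor arguments it never uses.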

cc @tchaton @rohitgr7 @akihironitta @justusschock @awaelchli @ninginthecloud

Labels

bug (Something isn't working), data handling (Generic data-related topic)
