Closed
Labels: bug, data handling
Description
🐛 Bug
When using a custom batch sampler for prediction, the following code gets executed:
This assumes that the custom batch sampler has the same constructor interface as `BatchSampler`. However, according to the docs, `DataLoader` accepts any `Sampler` (or iterable) that yields batches of indices as its `batch_sampler` argument: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
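A minimal, torch-free sketch of the mismatch follows. It assumes, as the workaround under "Additional context" suggests, that the batch sampler is re-created by calling its class with a sampler plus `batch_size` and `drop_last` keywords. `DatasetOnlyBatchSampler` and `rebuild_batch_sampler` are hypothetical names used for illustration, not Lightning code.

```python
from types import SimpleNamespace


class DatasetOnlyBatchSampler:
    """Hypothetical batch sampler whose constructor only takes a dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        # A single batch containing every index is enough to act as a batch_sampler.
        yield list(range(len(self.dataset.samples)))

    def __len__(self):
        return 1


def rebuild_batch_sampler(batch_sampler, sampler, batch_size, drop_last):
    # Hypothetical helper: re-create the batch sampler assuming it accepts
    # BatchSampler's constructor arguments (sampler, batch_size, drop_last).
    return type(batch_sampler)(sampler, batch_size=batch_size, drop_last=drop_last)


dataset = SimpleNamespace(samples=[{"group_id": "a"}, {"group_id": "b"}])
original = DatasetOnlyBatchSampler(dataset)
print(list(original))  # [[0, 1]]

try:
    rebuild_batch_sampler(original, sampler=range(2), batch_size=1, drop_last=False)
except TypeError as err:
    # The constructor has no `batch_size` parameter, so the rebuild fails.
    print("cannot rebuild:", err)
```

The sampler itself is a valid batch sampler; only the assumption about its constructor breaks.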
To Reproduce
This is the custom sampler I was using:
```python
from collections import defaultdict

from torch.utils.data import Sampler


class MySampler(Sampler):
    def __init__(self, dataset):
        """
        Args:
            dataset (Dataset): PyTorch dataset; expected to expose a `samples`
                list of dicts, each with a "group_id" key.
        """
        self.dataset = dataset
        self.group_ids_to_sample_indices = self.get_group_ids_to_sample_indices()

    def get_group_ids_to_sample_indices(self):
        """Group samples by id.

        Returns:
            dict: group_id -> indices of samples with this id
        """
        group_ids_to_sample_indices = defaultdict(list)
        for i, sample in enumerate(self.dataset.samples):
            group_ids_to_sample_indices[sample["group_id"]].append(i)
        return group_ids_to_sample_indices

    def __iter__(self):
        group_ids = list(self.group_ids_to_sample_indices.keys())
        # Yield a batch of variable size with all samples for one group
        for group_id in group_ids:
            yield self.group_ids_to_sample_indices[group_id]

    def __len__(self):
        return len(self.group_ids_to_sample_indices)
```

Expected behavior
Prediction runs without crashing, just as training, validation and testing already do with this sampler.
Environment
- PyTorch Lightning Version: 1.5.9
- PyTorch Version: 1.10.1
Additional context
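I did the following workaround for now, but it feels very hacky:

```python
class MySampler(Sampler):
    def __init__(self, sampler=None, batch_size=None, drop_last=False, dataset=None):
        # For prediction the class is called with a sampler (exposing `data_source`)
        # instead of `dataset`, so recover the dataset from it in that case.
        if sampler is not None:
            self.dataset = sampler.data_source
        else:
            self.dataset = dataset
        self.group_ids_to_sample_indices = self.get_group_ids_to_sample_indices()

    # get_group_ids_to_sample_indices, __iter__ and __len__ unchanged from above
```

cc @tchaton @rohitgr7 @akihironitta @justusschock @awaelchli @ninginthecloud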
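A fuller sketch of the same workaround, written as a plain class so it runs without torch (`DataLoader` accepts any iterable of index lists as `batch_sampler`; in the issue the class subclasses `torch.utils.data.Sampler`). It assumes, as the workaround does, that the re-created sampler receives a sampler exposing `data_source` positionally plus `batch_size` and `drop_last` keywords. `GroupBatchSampler` is a hypothetical name.

```python
from collections import defaultdict
from types import SimpleNamespace


class GroupBatchSampler:
    """Yield one batch of sample indices per group_id.

    Supports two call styles:
      GroupBatchSampler(dataset=ds)                              # direct use
      GroupBatchSampler(sampler, batch_size=..., drop_last=...)  # BatchSampler-style rebuild
    """

    def __init__(self, sampler=None, batch_size=None, drop_last=False, dataset=None):
        # batch_size and drop_last exist only for signature compatibility with
        # BatchSampler; batches are defined by group_id, so both are ignored.
        if dataset is None:
            if sampler is None:
                raise ValueError("pass either `dataset` or a `sampler` with a `data_source`")
            dataset = sampler.data_source
        self.dataset = dataset
        self.group_ids_to_sample_indices = self.get_group_ids_to_sample_indices()

    def get_group_ids_to_sample_indices(self):
        """Map group_id -> indices of samples with that group_id."""
        group_ids_to_sample_indices = defaultdict(list)
        for i, sample in enumerate(self.dataset.samples):
            group_ids_to_sample_indices[sample["group_id"]].append(i)
        return group_ids_to_sample_indices

    def __iter__(self):
        # One variable-size batch per group, in first-seen order.
        for indices in self.group_ids_to_sample_indices.values():
            yield indices

    def __len__(self):
        return len(self.group_ids_to_sample_indices)


dataset = SimpleNamespace(
    samples=[{"group_id": "a"}, {"group_id": "b"}, {"group_id": "a"}, {"group_id": "c"}]
)
direct = GroupBatchSampler(dataset=dataset)
# Stand-in for a sampler such as torch's SequentialSampler, which keeps its dataset in `data_source`.
rebuilt = GroupBatchSampler(SimpleNamespace(data_source=dataset), batch_size=1, drop_last=False)
print(list(direct))   # [[0, 2], [1], [3]]
print(list(rebuilt))  # [[0, 2], [1], [3]]
```

Ignoring `batch_size` and `drop_last` keeps the per-group batches intact; the trade-off is that the class silently accepts arguments it does not honor.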