🐛 Bug
Hi, I'm trying to use a custom batch sampler, implemented as below:

```python
import numpy as np
from torch.utils.data.sampler import BatchSampler


class BalancedBatchSampler(BatchSampler):
    def __init__(self, labels, n_classes_in_batch, n_samples_in_batch):
        self.labels = labels
        self.label_set = list(set(self.labels))
        # Map each label to the (shuffled) indices of its samples.
        self.label_to_indices = {
            label: list(np.where(np.array(self.labels) == label)[0])
            for label in self.label_set
        }
        for l in self.label_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.label_set}
        self.n_dataset = len(self.labels)
        self.n_classes_in_batch = n_classes_in_batch
        self.n_samples_in_batch = n_samples_in_batch
        self.batch_size = self.n_samples_in_batch * self.n_classes_in_batch

    def __iter__(self):
        count = 0
        while count + self.batch_size < self.n_dataset:
            sampled_classes = np.random.choice(self.label_set, self.n_classes_in_batch, replace=False)
            indices = []
            for class_ in sampled_classes:
                pad_for_class = self.used_label_indices_count[class_]
                samples_for_class = self.label_to_indices[class_][pad_for_class: pad_for_class + self.n_samples_in_batch]
                indices += samples_for_class
                self.used_label_indices_count[class_] += self.n_samples_in_batch
                # Reshuffle a class once its index pool is (nearly) exhausted.
                if self.used_label_indices_count[class_] + self.n_samples_in_batch > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            count += self.batch_size

    def __len__(self):
        # Number of batches per epoch.
        return self.n_dataset // self.batch_size
```

The sampler works well in the training stage of `trainer.fit()`; however, the validation stage is skipped.
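For reference, this is a quick way to sanity-check the sampler on its own, outside of Lightning. The toy labels below (4 classes, 25 samples each) are made up purely for illustration, not my real data:

```python
import numpy as np

labels = np.repeat([0, 1, 2, 3], 25).tolist()

sampler = BalancedBatchSampler(labels, n_classes_in_batch=2, n_samples_in_batch=8)
print(len(sampler))  # 100 // 16 = 6 batches per epoch

for batch in sampler:
    batch_labels = [labels[i] for i in batch]
    # Each batch should contain 2 * 8 = 16 indices covering exactly 2 classes.
    print(len(batch), sorted(set(batch_labels)))
```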
To Reproduce
I confirmed that the behavior does not occur when using torch.utils.data.sampler.BatchSampler, as shown below:

```python
# train_dataloader: standard sampler, val_dataloader: standard sampler -> validation NOT skipped
def train_dataloader(self):
    train_sampler = BatchSampler(RandomSampler(self.dataset_train), ...)
    return DataLoader(
        self.dataset_train,
        batch_sampler=train_sampler,
        ...
    )

def val_dataloader(self):
    val_sampler = BatchSampler(RandomSampler(self.dataset_val), ...)
    return DataLoader(
        self.dataset_val,
        batch_sampler=val_sampler,
        ...
    )


# train_dataloader: standard sampler, val_dataloader: custom sampler -> validation NOT skipped
def train_dataloader(self):
    train_sampler = BatchSampler(...)
    return DataLoader(...)

def val_dataloader(self):
    val_sampler = BalancedBatchSampler(...)
    return DataLoader(...)


# train_dataloader: custom sampler, val_dataloader: standard sampler -> validation skipped
def train_dataloader(self):
    train_sampler = BalancedBatchSampler(...)
    return DataLoader(...)

def val_dataloader(self):
    val_sampler = BatchSampler(...)
    return DataLoader(...)


# train_dataloader: custom sampler, val_dataloader: custom sampler -> validation skipped
def train_dataloader(self):
    train_sampler = BalancedBatchSampler(...)
    return DataLoader(...)

def val_dataloader(self):
    val_sampler = BalancedBatchSampler(...)
    return DataLoader(...)
```

Environment
- PyTorch Lightning Version (e.g., 1.5.0): 1.6.1
- PyTorch Version (e.g., 1.10): 1.11
- Python version (e.g., 3.9): 3.8.3
- OS (e.g., Linux): Linux
- CUDA/cuDNN version: 10.2
- How you installed PyTorch (conda, pip, source): conda
Additional context
Unfortunately, I could not figure out why the validation stage is being skipped, or what I'm missing in my implementation.
Is there a recommended way to use a custom batch sampler with Lightning?
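In case it helps narrow things down, this is roughly how I check what the validation DataLoader reports on its own, using the BalancedBatchSampler defined above. The toy dataset and labels here are stand-ins for my real data, purely for illustration:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real validation data (illustrative only).
val_labels = np.repeat([0, 1, 2, 3], 25).tolist()
dataset_val = TensorDataset(torch.randn(len(val_labels), 8), torch.tensor(val_labels))

val_sampler = BalancedBatchSampler(val_labels, n_classes_in_batch=2, n_samples_in_batch=8)
val_loader = DataLoader(dataset_val, batch_sampler=val_sampler)

# With a batch_sampler, len(DataLoader) equals len(batch_sampler),
# so both should report the same non-zero number of batches.
print(len(val_sampler), len(val_loader))

# The loader also produces batches fine when iterated directly.
for i, (x, y) in enumerate(val_loader):
    print(i, x.shape, y.tolist())
    if i >= 2:
        break
```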