
Validation skipped when using custom batch sampler #13007

@6gsn


🐛 Bug

Hi, I'm trying to use a custom batch sampler, implemented as follows:

import numpy as np
from torch.utils.data import BatchSampler


class BalancedBatchSampler(BatchSampler):
    # Yields batches of n_classes_in_batch classes x n_samples_in_batch samples each
    def __init__(
        self, labels, n_classes_in_batch, n_samples_in_batch
    ):
        self.labels = labels
        self.label_set = list(set(self.labels))
        # Map each label to the shuffled indices of the samples carrying it
        self.label_to_indices = {
            label: list(np.where(np.array(self.labels) == label)[0])
            for label in self.label_set
        }
        for l in self.label_set:
            np.random.shuffle(self.label_to_indices[l])
        # Per-class cursor into label_to_indices
        self.used_label_indices_count = {label: 0 for label in self.label_set}

        self.n_dataset = len(self.labels)
        self.n_classes_in_batch = n_classes_in_batch
        self.n_samples_in_batch = n_samples_in_batch

        self.batch_size = self.n_samples_in_batch * self.n_classes_in_batch

    def __iter__(self):
        count = 0

        # `<=` rather than `<`: with `<`, one batch fewer than __len__() is
        # yielded whenever n_dataset is an exact multiple of batch_size
        while count + self.batch_size <= self.n_dataset:
            sampled_classes = np.random.choice(self.label_set, self.n_classes_in_batch, replace=False)
            indices = []

            for class_ in sampled_classes:
                # Take up to the next n_samples_in_batch unused indices of this class
                pad_for_class = self.used_label_indices_count[class_]
                samples_for_class = self.label_to_indices[class_][pad_for_class: pad_for_class + self.n_samples_in_batch]
                indices += samples_for_class
                self.used_label_indices_count[class_] += self.n_samples_in_batch

                # Reshuffle and restart the class once it cannot fill another slot
                if self.used_label_indices_count[class_] + self.n_samples_in_batch > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0

            yield indices
            count += self.batch_size

    def __len__(self):
        # Number of batches (steps) in one epoch; must match what __iter__ yields
        return self.n_dataset // self.batch_size

The sampler works well during the training stage of trainer.fit(); however, the validation stage is skipped.
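One mismatch worth ruling out: in the original BalancedBatchSampler, the while-loop in __iter__ (count + self.batch_size < self.n_dataset) and the formula in __len__ (n_dataset // batch_size) do not always agree. The minimal sketch below keeps only the counting logic and compares the two, for the original `<` condition and for `<=`. LoopOnlySampler and check_sampler_length are hypothetical names; the class sampling is dropped because it does not affect how many batches are produced.

```python
from typing import Iterator, List, Tuple


class LoopOnlySampler:
    """Hypothetical stand-in keeping only the batch-counting logic of
    BalancedBatchSampler; which classes are sampled does not affect the count."""

    def __init__(self, n_dataset: int, batch_size: int, inclusive: bool) -> None:
        self.n_dataset = n_dataset
        self.batch_size = batch_size
        self.inclusive = inclusive  # True: use `<=`; False: the original `<`

    def _has_room(self, count: int) -> bool:
        end = count + self.batch_size
        return end <= self.n_dataset if self.inclusive else end < self.n_dataset

    def __iter__(self) -> Iterator[List[int]]:
        count = 0
        while self._has_room(count):
            yield list(range(count, count + self.batch_size))
            count += self.batch_size

    def __len__(self) -> int:
        # Same formula as BalancedBatchSampler.__len__
        return self.n_dataset // self.batch_size


def check_sampler_length(sampler) -> Tuple[int, int]:
    """Return (len(sampler), number of batches the sampler actually yields)."""
    return len(sampler), sum(1 for _ in sampler)


if __name__ == "__main__":
    for inclusive in (False, True):
        s = LoopOnlySampler(n_dataset=100, batch_size=10, inclusive=inclusive)
        print("<=" if inclusive else "<", check_sampler_length(s))
    # prints:
    # < (10, 9)
    # <= (10, 10)
```

With 100 samples and a batch size of 10, __len__ reports 10 batches but the `<` condition yields only 9. With 105 samples both give 10, so the mismatch only appears when the dataset size is an exact multiple of the batch size.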

To Reproduce

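I confirmed that the behavior does not occur when train_dataloader uses the standard torch.utils.data.sampler.BatchSampler; validation is skipped only when the custom sampler is used for training. The four combinations I tested: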

# standard sampler in train_dataloader + standard sampler in val_dataloader: validation runs
def train_dataloader(self):
    train_sampler = BatchSampler(RandomSampler(self.dataset_train), ...)
    return DataLoader(
        self.dataset_train,
        batch_sampler=train_sampler,
        ...
    )

def val_dataloader(self):
    val_sampler = BatchSampler(RandomSampler(self.dataset_val), ...)
    return DataLoader(
        self.dataset_val,
        batch_sampler=val_sampler,
        ...
    )

# standard sampler in train_dataloader + custom sampler in val_dataloader: validation runs
def train_dataloader(self):
    train_sampler = BatchSampler(...)
    return DataLoader(...)

def val_dataloader(self):
    val_sampler = BalancedBatchSampler(...)
    return DataLoader(...)

# custom sampler in train_dataloader + standard sampler in val_dataloader: validation skipped
def train_dataloader(self):
    train_sampler = BalancedBatchSampler(...)
    return DataLoader(...)

def val_dataloader(self):
    val_sampler = BatchSampler(...)
    return DataLoader(...)

# custom sampler in train_dataloader + custom sampler in val_dataloader: validation skipped
def train_dataloader(self):
    train_sampler = BalancedBatchSampler(...)
    return DataLoader(...)

def val_dataloader(self):
    val_sampler = BalancedBatchSampler(...)
    return DataLoader(...)
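That only the training sampler matters fits a simplified model of the end-of-epoch check. Assume (this is not verified against the Lightning source here) that with the default val_check_interval=1.0, Lightning 1.6 runs validation after the training batch whose 1-based index is a multiple of trainer.val_check_batch, and that this value is derived from len(train_dataloader). Under that assumption, a training sampler that yields fewer batches than it reports would end the epoch before the check is reached. epoch_reaches_val_check is a hypothetical helper, not Lightning code:

```python
def epoch_reaches_val_check(n_yielded: int, n_reported: int) -> bool:
    """Simplified model (assumption, not Lightning's code): with
    val_check_interval=1.0, validation fires after the batch whose 1-based
    index is a multiple of val_check_batch, taken to equal len(train_dataloader)."""
    val_check_batch = n_reported
    return any((batch_idx + 1) % val_check_batch == 0 for batch_idx in range(n_yielded))


# Sampler reports 10 batches and yields 10: the check is reached.
print(epoch_reaches_val_check(n_yielded=10, n_reported=10))  # True
# Sampler reports 10 batches but yields only 9: the epoch ends first.
print(epoch_reaches_val_check(n_yielded=9, n_reported=10))  # False
```

The validation loop, by contrast, presumably just iterates its DataLoader until it is exhausted, so an over-reported length in val_dataloader would go unnoticed.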

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): 1.6.1
  • PyTorch Version (e.g., 1.10): 1.11
  • Python version (e.g., 3.9): 3.8.3
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 10.2
  • How you installed PyTorch (conda, pip, source): conda

Additional context

Unfortunately, I could not figure out why the validation stage is skipped, or what I'm missing in the implementation.
Is there a recommended practice for using a custom batch sampler?
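One option is to drive __iter__ directly from __len__, so that the number of yielded batches cannot drift from the reported length. The following is a minimal, standard-library-only sketch under that idea; LengthDrivenBalancedSampler is a hypothetical name. Unlike BalancedBatchSampler, it draws each class's samples independently for every batch (there is no per-class cursor), so an index can appear in more than one batch within an epoch.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, Iterator, List, Sequence


class LengthDrivenBalancedSampler:
    """Hypothetical balanced batch sampler: each batch holds n_classes classes
    with n_samples indices each. __iter__ loops exactly len(self) times."""

    def __init__(self, labels: Sequence[Hashable], n_classes: int,
                 n_samples: int, seed: int = 0) -> None:
        self.rng = random.Random(seed)
        self.by_label: Dict[Hashable, List[int]] = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_label[label].append(idx)
        # Only classes with at least n_samples examples can fill their share of a batch
        self.classes = [c for c, idxs in self.by_label.items() if len(idxs) >= n_samples]
        if len(self.classes) < n_classes:
            raise ValueError("not enough classes with n_samples examples each")
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.batch_size = n_classes * n_samples
        self.n_dataset = len(labels)

    def __len__(self) -> int:
        # Number of batches per epoch
        return self.n_dataset // self.batch_size

    def __iter__(self) -> Iterator[List[int]]:
        # Loop count comes from __len__, so the two always agree
        for _ in range(len(self)):
            batch: List[int] = []
            for c in self.rng.sample(self.classes, self.n_classes):
                # n_samples distinct indices from class c
                batch.extend(self.rng.sample(self.by_label[c], self.n_samples))
            yield batch
```

Such a sampler would be passed as DataLoader(dataset, batch_sampler=sampler), leaving batch_size, shuffle, sampler and drop_last at their defaults, since PyTorch treats batch_sampler as mutually exclusive with those arguments.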

cc @justusschock @awaelchli @ninginthecloud @rohitgr7


Labels: bug (Something isn't working), data handling (Generic data-related topic)
