
Distributed sampler puts all data on all machines in DDP #6423

@timothybrooks

Description

The Lightning documentation says it automatically adds the appropriate (distributed) sampler for all training backends. However, in DDP mode the sampler puts the entire dataset on every machine instead of partitioning it into equally sized, disjoint subsets, which is the expected behavior and what the torch DistributedSampler does. As a user, I expect the automatic sampler to default to the same behavior as DistributedSampler when in DDP mode. If there is some desire to use the entire dataset on all machines, I recommend adding a trainer flag for this, such as full_dataset_sampler=False.
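
For reference, the torch DistributedSampler gives each rank a disjoint shard of roughly len(dataset) / num_replicas indices. A minimal standalone sketch of that expected behavior (passing num_replicas/rank explicitly so no process group is needed):

import torch
from torch.utils.data import TensorDataset, DistributedSampler

dataset = TensorDataset(torch.arange(16_000))

# Each rank draws a disjoint shard of len(dataset) / num_replicas indices.
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    indices = list(sampler)
    print(rank, len(indices))  # -> 8000 indices per rank, disjoint across ranks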

For a dataset of 16k samples, a per-GPU batch size of 8, and 2 GPUs, the following puts all data on both machines, resulting in 2k training steps per epoch. This does not match the behavior of the torch DistributedSampler and is unexpected.

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

dataset = Dataset(...)
data_loader = DataLoader(dataset, ...)  # no sampler given; Lightning should inject one
model = MyModel(...)
trainer = Trainer(accelerator="ddp", ...)
trainer.fit(model, data_loader)
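
To make the arithmetic explicit (a quick sanity check using the numbers above): with a proper distributed sampler, each of the 2 ranks should draw 16000 / 2 = 8000 samples, i.e. 8000 / 8 = 1000 steps per epoch; 2000 steps per epoch means each rank is iterating the full dataset.

dataset_size = 16_000  # numbers from the report above
per_gpu_batch = 8
num_gpus = 2

expected_steps = dataset_size // (per_gpu_batch * num_gpus)  # 1000 with a proper shard per rank
observed_steps = dataset_size // per_gpu_batch               # 2000 when every rank sees all data
print(expected_steps, observed_steps)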

If we instead implement the train_dataloader() method in the LightningModule and add the torch DistributedSampler manually, the behavior is correct: the dataset is partitioned into a subset per process, resulting in 1k training steps per epoch. The code above, where Lightning adds the sampler itself, should behave the same but does not.

from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, DistributedSampler

class MyModel(LightningModule):
    ...
    def train_dataloader(self):
        dataset = Dataset(...)
        sampler = DistributedSampler(dataset, ...)  # explicit distributed sampler
        data_loader = DataLoader(dataset, sampler=sampler, ...)
        return data_loader
    ...

model = MyModel(...)
trainer = Trainer(accelerator="ddp", ...)
trainer.fit(model)
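
A side note on the manual workaround (general DistributedSampler caveats, not specific to this issue): shuffling must be requested on the sampler rather than the DataLoader, since DataLoader rejects sampler= combined with shuffle=True, and the sampler's set_epoch() needs to be called each epoch for the shuffle order to vary; Lightning appears to do the latter for a sampler it finds on the dataloader. A sketch:

from torch.utils.data import DataLoader, DistributedSampler

class MyModel(LightningModule):
    ...
    def train_dataloader(self):
        dataset = Dataset(...)
        # Shuffle via the sampler: DataLoader rejects sampler= together with shuffle=True.
        sampler = DistributedSampler(dataset, shuffle=True)
        # With no explicit num_replicas/rank, the sampler reads them from the
        # initialized process group, which is set up by the time this is called.
        return DataLoader(dataset, sampler=sampler, ...)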


Labels

bug (Something isn't working), help wanted (Open to be worked on)
