set_epoch not called for BatchSampler #13316

@samgelman

Description

🐛 Bug

Lightning takes care of calling set_epoch for custom DataLoader Samplers here. However, the DataLoader might use a custom BatchSampler instead of a Sampler. Lightning does not call set_epoch for custom BatchSamplers. Calling set_epoch is important for proper seeding in distributed environments.
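To illustrate why the missing call matters, here is a minimal, hypothetical batch sampler that mirrors the `torch.utils.data.BatchSampler` interface (pure Python, no torch dependency, names are illustrative). If `set_epoch` is never called, every epoch replays the epoch-0 shuffle order, so in a distributed run all epochs see identical batch orderings:

```python
import random


class ShuffledBatchSampler:
    """Minimal batch sampler that reshuffles indices each epoch (illustrative sketch).

    set_epoch is the hook Lightning calls on Samplers but, per this issue,
    not on BatchSamplers.
    """

    def __init__(self, dataset_size: int, batch_size: int, seed: int = 0):
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Without this call, __iter__ reuses the epoch-0 shuffle forever.
        self.epoch = epoch

    def __iter__(self):
        indices = list(range(self.dataset_size))
        # Seed depends on the epoch so each epoch gets a fresh, reproducible order.
        random.Random(self.seed + self.epoch).shuffle(indices)
        for i in range(0, len(indices), self.batch_size):
            yield indices[i:i + self.batch_size]

    def __len__(self) -> int:
        return (self.dataset_size + self.batch_size - 1) // self.batch_size
```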

Expected behavior

Lightning should call set_epoch for BatchSamplers to match its existing behavior for Samplers.
One option is to use the DataLoader's index_sampler property to retrieve the Sampler or BatchSampler that the DataLoader actually uses; more simply, Lightning could call set_epoch on both the Sampler and the BatchSampler.
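The "call it on both" variant could be sketched as a duck-typed helper like the one below (a hypothetical function, not Lightning's actual implementation; it only assumes the standard `sampler`/`batch_sampler` attributes of a DataLoader and an optional `set_epoch` method):

```python
def call_set_epoch(dataloader, epoch: int) -> None:
    """Sketch of the proposed fix: forward the epoch to the Sampler
    and/or BatchSampler, whichever defines a callable set_epoch."""
    for obj in (getattr(dataloader, "sampler", None),
                getattr(dataloader, "batch_sampler", None)):
        if obj is not None and callable(getattr(obj, "set_epoch", None)):
            obj.set_epoch(epoch)
```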

Additional context

A workaround is to use a Callback such as the following, registered via `Trainer(callbacks=[SetBatchSamplerEpoch()])`:

```python
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch.utils.data import DataLoader


class SetBatchSamplerEpoch(Callback):
    """Sets the epoch on the batch sampler before the dataloader iterator is initialized each training epoch."""

    @staticmethod
    def set_batch_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
        # Only forward the epoch if the batch sampler actually implements set_epoch.
        if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
            dataloader.batch_sampler.set_epoch(epoch)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.train_dataloader is not None:
            if isinstance(trainer.train_dataloader, CombinedLoader):
                # is the train_dataloader always wrapped in a CombinedLoader at this point?
                apply_to_collection(
                    data=trainer.train_dataloader.loaders,
                    dtype=DataLoader,
                    function=self.set_batch_sampler_epoch,
                    epoch=trainer.current_epoch,
                )
            elif isinstance(trainer.train_dataloader, DataLoader):
                # just in case the train_dataloader is not wrapped in a CombinedLoader
                self.set_batch_sampler_epoch(trainer.train_dataloader, trainer.current_epoch)
            else:
                raise TypeError(f"Unexpected type of trainer.train_dataloader: {type(trainer.train_dataloader)}")
```

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

Labels: bug, data handling