Status: Closed
Labels: bug, data handling
🐛 Bug
Lightning already calls `set_epoch` on custom DataLoader samplers here. However, a DataLoader may use a custom `BatchSampler` instead of a `Sampler`, and Lightning does not call `set_epoch` on custom batch samplers. Calling `set_epoch` is important for proper seeding in distributed environments: samplers such as `DistributedSampler` derive their shuffle order from the epoch, so if it is never updated, every epoch repeats the same order.
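To illustrate why this matters, here is a minimal sketch of a custom batch sampler that follows the same convention as `DistributedSampler`: it seeds its shuffle with `seed + epoch` and gives each rank a disjoint shard. `EpochShuffledBatchSampler` is a hypothetical name used only for this example. It is a plain iterable rather than a `torch` subclass (a DataLoader's `batch_sampler` only has to be an iterable of index lists), so the sketch runs without torch.

```python
import random
from typing import Iterator, List


class EpochShuffledBatchSampler:
    """Hypothetical batch sampler that reshuffles per epoch and shards by rank.

    All ranks share the seed, so they agree on the permutation for a given
    epoch; the permutation only changes when set_epoch() is called.
    """

    def __init__(self, dataset_len: int, batch_size: int, rank: int = 0,
                 num_replicas: int = 1, seed: int = 0) -> None:
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.rank = rank
        self.num_replicas = num_replicas
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Must be called before each epoch, otherwise the same order repeats.
        self.epoch = epoch

    def __iter__(self) -> Iterator[List[int]]:
        indices = list(range(self.dataset_len))
        random.Random(self.seed + self.epoch).shuffle(indices)
        # Each rank takes every num_replicas-th index, so shards are disjoint.
        shard = indices[self.rank::self.num_replicas]
        for start in range(0, len(shard), self.batch_size):
            yield shard[start:start + self.batch_size]

    def __len__(self) -> int:
        shard_len = len(range(self.rank, self.dataset_len, self.num_replicas))
        return (shard_len + self.batch_size - 1) // self.batch_size


sampler = EpochShuffledBatchSampler(dataset_len=10, batch_size=4)
epoch0 = list(sampler)
repeat = list(sampler)  # set_epoch not called: identical batches to epoch0
sampler.set_epoch(1)
epoch1 = list(sampler)  # reshuffled with seed + 1
```

If nothing calls `set_epoch` on this batch sampler, which is what currently happens in Lightning, every epoch sees exactly the batches in `epoch0`.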
Expected behavior
Lightning should call `set_epoch` on batch samplers, matching its behavior for samplers.
One option is the DataLoader's private `_index_sampler` property, which returns whichever object the DataLoader actually draws indices from (the `BatchSampler` when automatic batching is enabled, otherwise the `Sampler`). A simpler option is to call `set_epoch` on both the sampler and the batch sampler.
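The simpler option could look roughly like this sketch. `set_sampler_epoch` is a hypothetical helper name, not Lightning's API. It duck-types on the `sampler` and `batch_sampler` attributes (both exist on `torch.utils.data.DataLoader`), so the demo uses `SimpleNamespace` stand-ins and runs without torch. It also skips an object it has already handled, so a single object set as both attributes is not called twice.

```python
from types import SimpleNamespace
from typing import Any, List


def set_sampler_epoch(dataloader: Any, epoch: int) -> List[str]:
    """Hypothetical helper: call set_epoch on the sampler and/or batch sampler.

    Returns the attribute names whose set_epoch was called.
    """
    called: List[str] = []
    seen = set()  # ids of objects already handled
    for attr in ("sampler", "batch_sampler"):
        obj = getattr(dataloader, attr, None)
        if id(obj) in seen:
            continue
        if callable(getattr(obj, "set_epoch", None)):
            obj.set_epoch(epoch)
            seen.add(id(obj))
            called.append(attr)
    return called


class EpochRecorder:
    """Stand-in for a sampler or batch sampler that records set_epoch calls."""

    def __init__(self) -> None:
        self.epochs: List[int] = []

    def set_epoch(self, epoch: int) -> None:
        self.epochs.append(epoch)


# A loader whose custom batch sampler, not its sampler, defines set_epoch.
batch_sampler = EpochRecorder()
loader = SimpleNamespace(sampler=object(), batch_sampler=batch_sampler)
set_sampler_epoch(loader, 3)
```

Checking both attributes avoids having to reason about which one the DataLoader actually iterates, at the cost of possibly calling `set_epoch` on an object that is not in use.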
Additional context
A workaround is to use a `Callback` such as the following, which sets the batch sampler's epoch at the start of every training epoch, before the DataLoader iterator is created:

```python
import pytorch_lightning as pl
from pytorch_lightning import Callback
# These import paths match Lightning 1.x; they may differ in other versions.
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch.utils.data import DataLoader


class SetBatchSamplerEpoch(Callback):
    """Sets the epoch for the batch sampler before the dataloader iterator is initialized every training epoch."""

    @staticmethod
    def set_batch_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
        # Only touch batch samplers that actually define a callable set_epoch.
        if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
            dataloader.batch_sampler.set_epoch(epoch)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.train_dataloader is None:
            return
        if isinstance(trainer.train_dataloader, CombinedLoader):
            # Is the train_dataloader always wrapped in a CombinedLoader at this point?
            # apply_to_collection forwards `epoch` to every DataLoader found in `loaders`.
            apply_to_collection(
                data=trainer.train_dataloader.loaders,
                dtype=DataLoader,
                function=self.set_batch_sampler_epoch,
                epoch=trainer.current_epoch,
            )
        elif isinstance(trainer.train_dataloader, DataLoader):
            # Just in case the train_dataloader is not wrapped in a CombinedLoader.
            self.set_batch_sampler_epoch(trainer.train_dataloader, trainer.current_epoch)
        else:
            raise TypeError(f"Unexpected type of trainer.train_dataloader: {type(trainer.train_dataloader)}")
```
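The callback relies on `apply_to_collection` to reach every DataLoader inside `CombinedLoader.loaders`, which may be a single loader or a nested dict/list of loaders. The torch-free sketch below mirrors that traversal for this use case only. `apply_to_loaders` is a hypothetical stand-in, and treating any object with a `batch_sampler` attribute as a loader is an assumption made for the demo (the callback itself matches on `isinstance(..., DataLoader)`).

```python
from types import SimpleNamespace
from typing import Any, Callable, List, Mapping


def apply_to_loaders(loaders: Any, fn: Callable[[Any], None]) -> None:
    """Hypothetical stand-in for apply_to_collection: call fn on every loader
    found in arbitrarily nested dicts, lists and tuples."""
    if isinstance(loaders, Mapping):
        for value in loaders.values():
            apply_to_loaders(value, fn)
    elif isinstance(loaders, (list, tuple)):
        for value in loaders:
            apply_to_loaders(value, fn)
    elif hasattr(loaders, "batch_sampler"):  # assumption: this marks a loader
        fn(loaders)


def set_batch_sampler_epoch(dataloader: Any, epoch: int) -> None:
    # Same guard as the callback: skip batch samplers without set_epoch.
    if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
        dataloader.batch_sampler.set_epoch(epoch)


class EpochRecorder:
    """Stand-in batch sampler that records set_epoch calls."""

    def __init__(self) -> None:
        self.epochs: List[int] = []

    def set_epoch(self, epoch: int) -> None:
        self.epochs.append(epoch)


first, second = EpochRecorder(), EpochRecorder()
loaders = {
    "a": SimpleNamespace(batch_sampler=first),
    "b": [SimpleNamespace(batch_sampler=second), SimpleNamespace(batch_sampler=None)],
}
apply_to_loaders(loaders, lambda dl: set_batch_sampler_epoch(dl, 2))
```

Loaders whose batch sampler is `None` or lacks `set_epoch` are simply skipped, as in the callback.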
Assignees: awaelchli