Skip to content
Merged
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))

### Deprecated

Expand Down Expand Up @@ -122,7 +123,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974))
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
- Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996))


Expand Down Expand Up @@ -467,7 +468,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,8 @@ def world_size(self):
Enables auto adding of distributed sampler. By default it will add ``shuffle=True``
for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize
it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
If ``replace_sampler_ddp=True`` and a distributed sampler was already added,
Lightning will not replace the existing one.

.. testcode::

Expand Down
16 changes: 9 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
f' (try {num_cpus} which is the number of cpus on this machine)'
' in the `DataLoader` init to improve performance.')

def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
Expand All @@ -112,8 +112,9 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

if not is_dataloader or is_iterable_ds:
return dataloader
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
if self.replace_sampler_ddp and need_dist_sampler:
if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
raise MisconfigurationException(
Expand All @@ -123,7 +124,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
' `replace_sampler_ddp`=False if you want to use your custom sampler.')

# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader, train)
sampler = self._get_distributed_sampler(dataloader, shuffle)
dataloader = self.replace_sampler(dataloader, sampler)

return dataloader
Expand All @@ -136,10 +137,11 @@ def replace_sampler(self, dataloader, sampler):
}

dl_args['sampler'] = sampler
dl_args['shuffle'] = False
dataloader = type(dataloader)(**dl_args)
return dataloader

def _get_distributed_sampler(self, dataloader, train):
def _get_distributed_sampler(self, dataloader, shuffle):
if self.use_tpu:
kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.use_horovod:
Expand All @@ -154,7 +156,7 @@ def _get_distributed_sampler(self, dataloader, train):
assert self.distributed_backend is not None
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

kwargs['shuffle'] = train and not self.overfit_batches
kwargs['shuffle'] = shuffle and not self.overfit_batches
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler

Expand All @@ -179,7 +181,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
self.num_training_batches = 0

# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')
Expand Down Expand Up @@ -267,7 +269,7 @@ def _reset_eval_dataloader(
rank_zero_warn("One of given dataloaders is None and it will be skipped.")

# add samplers
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

loader_num_batches = []

Expand Down
41 changes: 37 additions & 4 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,17 +686,17 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
class CustomDummyObj:
sampler = None

result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

dataset = list(range(1000))
result = trainer.auto_add_sampler(CustomDataLoader(dataset), train=True)
result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
assert isinstance(result, torch.utils.data.DataLoader)
assert isinstance(result, CustomDataLoader)
assert hasattr(result, 'dummy_kwarg')

# Shuffled DataLoader should also work
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
assert isinstance(result, torch.utils.data.DataLoader)
assert isinstance(result, CustomDataLoader)
assert hasattr(result, 'dummy_kwarg')
Expand All @@ -707,7 +707,7 @@ class CustomSampler(torch.utils.data.Sampler):
# Should raise an error if existing sampler is being replaced
with pytest.raises(MisconfigurationException, match='DistributedSampler'):
trainer.auto_add_sampler(
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True)


class DistribSamplerCallback(Callback):
Expand Down Expand Up @@ -746,6 +746,39 @@ def test_dataloader_distributed_sampler(tmpdir):
trainer.test(ckpt_path=None)


class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):

def train_dataloader(self):
dataloader = super().train_dataloader()
dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True)
return DataLoader(
dataloader.dataset,
batch_size=self.batch_size,
drop_last=False,
sampler=dist_sampler,
shuffle=False
)


@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_distributed_sampler_already_attached(tmpdir):
""" Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """

model = ModelWithDataLoaderDistributedSampler()
trainer = Trainer(
gpus=[0, 1],
num_nodes=1,
distributed_backend='ddp_spawn',
default_root_dir=tmpdir,
max_steps=100,
callbacks=[DistribSamplerCallback()],
replace_sampler_ddp=True,
)
result = trainer.fit(model)
assert result == 1, "DDP Training failed"


@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus(tmpdir):
# we need at least 3 gpus for this test
Expand Down