Skip to content

Commit f6efb71

Browse files
ananthsubSeanNarenrohitgr7
authored
Skip replacing dataloader sampler if it's already a distributed sampler (#4273)
* Update data_loading.py * Update data_loading.py * add test + update flag description * add to changelog * Update test_dataloaders.py * fix-pickle * Update test_dataloaders.py * Added missing reference calls * Update tests/trainer/test_dataloaders.py * Apply suggestions from code review * Update data_loading.py * Update test_dataloaders.py Co-authored-by: Sean Naren <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 91c64e9 commit f6efb71

File tree

4 files changed

+51
-13
lines changed

4 files changed

+51
-13
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2020
- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
2121

2222
- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
23+
- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))
2324

2425
### Deprecated
2526

@@ -122,7 +123,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
122123
### Fixed
123124

124125
- Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974))
125-
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
126+
- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053))
126127
- Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996))
127128

128129

@@ -467,7 +468,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
467468
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
468469
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
469470
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
470-
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
471+
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
471472
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
472473
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))
473474

pytorch_lightning/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,8 @@ def world_size(self):
12781278
Enables auto adding of distributed sampler. By default it will add ``shuffle=True``
12791279
for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize
12801280
it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
1281+
If ``replace_sampler_ddp=True`` and a distributed sampler was already added,
1282+
Lightning will not replace the existing one.
12811283
12821284
.. testcode::
12831285

pytorch_lightning/trainer/data_loading.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
103103
f' (try {num_cpus} which is the number of cpus on this machine)'
104104
' in the `DataLoader` init to improve performance.')
105105

106-
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
106+
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
107107

108108
# don't do anything if it's not a dataloader
109109
is_dataloader = isinstance(dataloader, DataLoader)
@@ -112,8 +112,9 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
112112

113113
if not is_dataloader or is_iterable_ds:
114114
return dataloader
115-
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
116115

116+
is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
117+
need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
117118
if self.replace_sampler_ddp and need_dist_sampler:
118119
if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
119120
raise MisconfigurationException(
@@ -123,7 +124,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
123124
' `replace_sampler_ddp`=False if you want to use your custom sampler.')
124125

125126
# replace with distributed sampler
126-
sampler = self._get_distributed_sampler(dataloader, train)
127+
sampler = self._get_distributed_sampler(dataloader, shuffle)
127128
dataloader = self.replace_sampler(dataloader, sampler)
128129

129130
return dataloader
@@ -136,10 +137,11 @@ def replace_sampler(self, dataloader, sampler):
136137
}
137138

138139
dl_args['sampler'] = sampler
140+
dl_args['shuffle'] = False
139141
dataloader = type(dataloader)(**dl_args)
140142
return dataloader
141143

142-
def _get_distributed_sampler(self, dataloader, train):
144+
def _get_distributed_sampler(self, dataloader, shuffle):
143145
if self.use_tpu:
144146
kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
145147
elif self.use_horovod:
@@ -154,7 +156,7 @@ def _get_distributed_sampler(self, dataloader, train):
154156
assert self.distributed_backend is not None
155157
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
156158

157-
kwargs['shuffle'] = train and not self.overfit_batches
159+
kwargs['shuffle'] = shuffle and not self.overfit_batches
158160
sampler = DistributedSampler(dataloader.dataset, **kwargs)
159161
return sampler
160162

@@ -179,7 +181,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
179181
self.num_training_batches = 0
180182

181183
# automatically add samplers
182-
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
184+
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True)
183185

184186
self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
185187
self._worker_check(self.train_dataloader, 'train dataloader')
@@ -267,7 +269,7 @@ def _reset_eval_dataloader(
267269
rank_zero_warn("One of given dataloaders is None and it will be skipped.")
268270

269271
# add samplers
270-
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]
272+
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]
271273

272274
loader_num_batches = []
273275

tests/trainer/test_dataloaders.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -686,17 +686,17 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
686686
class CustomDummyObj:
687687
sampler = None
688688

689-
result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
689+
result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
690690
assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
691691

692692
dataset = list(range(1000))
693-
result = trainer.auto_add_sampler(CustomDataLoader(dataset), train=True)
693+
result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
694694
assert isinstance(result, torch.utils.data.DataLoader)
695695
assert isinstance(result, CustomDataLoader)
696696
assert hasattr(result, 'dummy_kwarg')
697697

698698
# Shuffled DataLoader should also work
699-
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
699+
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
700700
assert isinstance(result, torch.utils.data.DataLoader)
701701
assert isinstance(result, CustomDataLoader)
702702
assert hasattr(result, 'dummy_kwarg')
@@ -707,7 +707,7 @@ class CustomSampler(torch.utils.data.Sampler):
707707
# Should raise an error if existing sampler is being replaced
708708
with pytest.raises(MisconfigurationException, match='DistributedSampler'):
709709
trainer.auto_add_sampler(
710-
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
710+
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True)
711711

712712

713713
class DistribSamplerCallback(Callback):
@@ -746,6 +746,39 @@ def test_dataloader_distributed_sampler(tmpdir):
746746
trainer.test(ckpt_path=None)
747747

748748

749+
class ModelWithDataLoaderDistributedSampler(EvalModelTemplate):
750+
751+
def train_dataloader(self):
752+
dataloader = super().train_dataloader()
753+
dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True)
754+
return DataLoader(
755+
dataloader.dataset,
756+
batch_size=self.batch_size,
757+
drop_last=False,
758+
sampler=dist_sampler,
759+
shuffle=False
760+
)
761+
762+
763+
@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
764+
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
765+
def test_dataloader_distributed_sampler_already_attached(tmpdir):
766+
""" Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """
767+
768+
model = ModelWithDataLoaderDistributedSampler()
769+
trainer = Trainer(
770+
gpus=[0, 1],
771+
num_nodes=1,
772+
distributed_backend='ddp_spawn',
773+
default_root_dir=tmpdir,
774+
max_steps=100,
775+
callbacks=[DistribSamplerCallback()],
776+
replace_sampler_ddp=True,
777+
)
778+
result = trainer.fit(model)
779+
assert result == 1, "DDP Training failed"
780+
781+
749782
@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
750783
def test_batch_size_smaller_than_num_gpus(tmpdir):
751784
# we need at least 3 gpus for this test

0 commit comments

Comments
 (0)