From 34445a0cbe1f73ce7495920cad522a9bf71f98c9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 20 Oct 2020 22:50:39 -0700 Subject: [PATCH 01/12] Update data_loading.py --- pytorch_lightning/trainer/data_loading.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index bf85bb2c6122e..f70e9f8ec3799 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -115,7 +115,9 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu) if self.replace_sampler_ddp and need_dist_sampler: - if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): + if isinstance(dataloader.sampler, DistributedSampler): + return dataloader + elif not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( 'You seem to have configured a sampler in your DataLoader. This will be replaced ' ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using' From 75af01d3253ac321eb706b478a5863ec115ea7eb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 21 Oct 2020 00:00:10 -0700 Subject: [PATCH 02/12] Update data_loading.py --- pytorch_lightning/trainer/data_loading.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f70e9f8ec3799..5a570332cc3dd 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -112,12 +112,11 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu) + is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler) if self.replace_sampler_ddp and need_dist_sampler: - if isinstance(dataloader.sampler, DistributedSampler): - return dataloader - elif not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): + if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( 'You seem to have configured a sampler in your DataLoader. This will be replaced ' ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using' From 4f9d5ca1834cb0fe0ffab77e2ac7589c8b372292 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 21 Oct 2020 23:06:34 -0700 Subject: [PATCH 03/12] add test + update flag description --- pytorch_lightning/trainer/__init__.py | 2 ++ tests/trainer/test_dataloaders.py | 33 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 5deb460f544df..a79a41e3c3616 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -1278,6 +1278,8 @@ def world_size(self): Enables auto adding of distributed sampler. By default it will add ``shuffle=True`` for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. 
+If ``replace_sampler_ddp=True`` and a distributed sampler was already added, +Lightning will not replace the existing one. .. testcode:: diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index ddabe919c43c4..9fb7c394c10a3 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -746,6 +746,39 @@ def test_dataloader_distributed_sampler(tmpdir): trainer.test(ckpt_path=None) +@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') +def test_dataloader_distributed_sampler_already_attached(tmpdir): + """ Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """ + + class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): + + def train_dataloader(self): + dataloader = super().train_dataloader() + dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False) + dataloader = DataLoader( + dataset, + batch_size=self.batch_size, + drop_last=False, + sampler=dist_sampler, + shuffle=False + ) + return dataloader + + model = ModelWithDataLoaderDistributedSampler() + trainer = Trainer( + gpus=[0, 1], + num_nodes=1, + distributed_backend='ddp_spawn', + default_root_dir=tmpdir, + max_steps=100, + callbacks=[DistribSamplerCallback()] + ) + trainer.fit(model) + assert result == 1, "DDP Training failed" + + + @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs') def test_batch_size_smaller_than_num_gpus(tmpdir): # we need at least 3 gpus for this test From 504d9e85352e14e0010212782e7ee3631d60f8b4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 21 Oct 2020 23:09:00 -0700 Subject: [PATCH 04/12] add to changelog --- CHANGELOG.md | 5 +++-- tests/trainer/test_dataloaders.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c790f73729e2..e03d94e77efe1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587)) - Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130)) +- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273)) ### Deprecated @@ -119,7 +120,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974)) -- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053)) +- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053)) - Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996)) @@ -464,7 +465,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986)) - Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997)) - Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020)) -- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043)) +- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043)) - Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045)) - Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042)) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 9fb7c394c10a3..1f9c7df47a7f5 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -773,6 +773,7 @@ def train_dataloader(self): default_root_dir=tmpdir, max_steps=100, callbacks=[DistribSamplerCallback()] + replace_sampler_ddp=True, ) trainer.fit(model) assert result == 1, "DDP Training failed" From ade68df65a89cc4c7976872fb940f14cb09311e8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 21 Oct 2020 23:11:15 -0700 Subject: [PATCH 05/12] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1f9c7df47a7f5..488cf1947651d 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -756,14 +756,13 @@ class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): def train_dataloader(self): dataloader = super().train_dataloader() dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False) - dataloader = DataLoader( + return DataLoader( dataset, batch_size=self.batch_size, drop_last=False, sampler=dist_sampler, shuffle=False ) - return dataloader model = ModelWithDataLoaderDistributedSampler() trainer = Trainer( From 85ad318af441a01ca9b5d16ea97fab2e29c0eeb0 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 23 Oct 2020 00:25:48 -0700 Subject: [PATCH 06/12] fix-pickle --- tests/trainer/test_dataloaders.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 488cf1947651d..1c804760c6507 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -746,24 +746,25 @@ def test_dataloader_distributed_sampler(tmpdir): trainer.test(ckpt_path=None) +class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): + + def train_dataloader(self): + dataloader = super().train_dataloader() + dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False) + return DataLoader( + dataset, + batch_size=self.batch_size, + drop_last=False, + sampler=dist_sampler, + shuffle=False + ) + + @pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') def test_dataloader_distributed_sampler_already_attached(tmpdir): """ Test DistributedSampler and it's arguments for 
DDP backend when DistSampler already included on dataloader """ - class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): - - def train_dataloader(self): - dataloader = super().train_dataloader() - dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False) - return DataLoader( - dataset, - batch_size=self.batch_size, - drop_last=False, - sampler=dist_sampler, - shuffle=False - ) - model = ModelWithDataLoaderDistributedSampler() trainer = Trainer( gpus=[0, 1], @@ -778,7 +779,6 @@ def train_dataloader(self): assert result == 1, "DDP Training failed" - @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs') def test_batch_size_smaller_than_num_gpus(tmpdir): # we need at least 3 gpus for this test From 2ff354efc611196ca0d7d2920f94f6c505dbe785 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 23 Oct 2020 00:31:00 -0700 Subject: [PATCH 07/12] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1c804760c6507..a6169459a96ae 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -772,7 +772,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): distributed_backend='ddp_spawn', default_root_dir=tmpdir, max_steps=100, - callbacks=[DistribSamplerCallback()] + callbacks=[DistribSamplerCallback()], replace_sampler_ddp=True, ) trainer.fit(model) From 05b6b5ae46e72a08e1302778b4f9c6f7d629e8cc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 23 Oct 2020 12:03:16 +0100 Subject: [PATCH 08/12] Added missing reference calls --- tests/trainer/test_dataloaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index a6169459a96ae..dd7eb61ca7c82 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -752,7 +752,7 @@ def train_dataloader(self): dataloader = super().train_dataloader() dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False) return DataLoader( - dataset, + dataloader.dataset, batch_size=self.batch_size, drop_last=False, sampler=dist_sampler, @@ -775,7 +775,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): callbacks=[DistribSamplerCallback()], replace_sampler_ddp=True, ) - trainer.fit(model) + result = trainer.fit(model) assert result == 1, "DDP Training failed" From 910ca399badf5a633a93dfbfc32819daca59d9c7 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 23 Oct 2020 18:50:00 +0530 Subject: [PATCH 09/12] Update tests/trainer/test_dataloaders.py --- tests/trainer/test_dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index dd7eb61ca7c82..d31bf3560d558 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -756,7 +756,7 @@ def train_dataloader(self): batch_size=self.batch_size, drop_last=False, sampler=dist_sampler, - shuffle=False + shuffle=True ) From 295303b5c3a8c7ce340002c2d56c788591ab2104 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 23 Oct 2020 19:17:24 +0530 Subject: [PATCH 10/12] Apply suggestions from code review --- tests/trainer/test_dataloaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d31bf3560d558..9368ca8e274fd 100644 --- 
a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -750,13 +750,13 @@ class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): def train_dataloader(self): dataloader = super().train_dataloader() - dist_sampler = DistributedSampler(dataloader.dataset, shuffle=False) + dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True) return DataLoader( dataloader.dataset, batch_size=self.batch_size, drop_last=False, sampler=dist_sampler, - shuffle=True + shuffle=False ) From fd1359b7e18f72346df07ce43e7394a0251390b2 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 23 Oct 2020 20:18:07 +0530 Subject: [PATCH 11/12] Update data_loading.py --- pytorch_lightning/trainer/data_loading.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 5a570332cc3dd..2c8e8669b8560 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -103,7 +103,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: f' (try {num_cpus} which is the number of cpus on this machine)' ' in the `DataLoader` init to improve performance.') - def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: + def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: # don't do anything if it's not a dataloader is_dataloader = isinstance(dataloader, DataLoader) @@ -124,7 +124,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: ' `replace_sampler_ddp`=False if you want to use your custom sampler.') # replace with distributed sampler - sampler = self._get_distributed_sampler(dataloader, train) + sampler = self._get_distributed_sampler(dataloader, shuffle) dataloader = self.replace_sampler(dataloader, sampler) return dataloader @@ -137,10 +137,11 @@ def replace_sampler(self, dataloader, sampler): } dl_args['sampler'] = sampler + dl_args['shuffle'] = False dataloader = type(dataloader)(**dl_args) return dataloader - def _get_distributed_sampler(self, dataloader, train): + def _get_distributed_sampler(self, dataloader, shuffle): if self.use_tpu: kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) elif self.use_horovod: @@ -155,7 +156,7 @@ def _get_distributed_sampler(self, dataloader, train): assert self.distributed_backend is not None kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) - kwargs['shuffle'] = train and not self.overfit_batches + kwargs['shuffle'] = shuffle and not self.overfit_batches sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler @@ -180,7 +181,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: self.num_training_batches = 0 # automatically add samplers - self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) + self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True) self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') self._worker_check(self.train_dataloader, 'train dataloader') @@ -268,7 +269,7 @@ def _reset_eval_dataloader( rank_zero_warn("One of given dataloaders is None and it will be skipped.") # add samplers - dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None] + dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None] loader_num_batches = [] 
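PATCH 11/12 above is the behavioral core of the sampler handling: `auto_add_sampler` now takes a `shuffle` flag instead of `train`, the shuffle decision is pushed into the `DistributedSampler` that gets built, and the rebuilt `DataLoader` is given `shuffle=False`, since PyTorch rejects an explicit `sampler` combined with `shuffle=True`. The standalone sketch below mirrors the combined behavior of the series up to this point; it is not Lightning code, and the helper name plus the explicit `num_replicas`/`rank` arguments are illustrative only.

# Standalone sketch (not the Lightning implementation) of the replacement logic:
# keep an already-attached DistributedSampler, otherwise rebuild the loader with a
# DistributedSampler that owns the shuffle decision.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def add_distributed_sampler(dataloader: DataLoader, shuffle: bool,
                            num_replicas: int, rank: int) -> DataLoader:
    if isinstance(dataloader.sampler, DistributedSampler):
        # the user already configured distributed sampling; pass the loader through
        return dataloader

    sampler = DistributedSampler(dataloader.dataset, num_replicas=num_replicas,
                                 rank=rank, shuffle=shuffle)
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        num_workers=dataloader.num_workers,
        drop_last=dataloader.drop_last,
        sampler=sampler,
        shuffle=False,  # shuffling is now handled by the DistributedSampler itself
    )
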
From 9eaea2fbd121bae2770dc03ca0ae5124fa959b20 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 23 Oct 2020 20:23:45 +0530 Subject: [PATCH 12/12] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 9368ca8e274fd..03810309d9136 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -686,17 +686,17 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, class CustomDummyObj: sampler = None - result = trainer.auto_add_sampler(CustomDummyObj(), train=True) + result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True) assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader" dataset = list(range(1000)) - result = trainer.auto_add_sampler(CustomDataLoader(dataset), train=True) + result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True) assert isinstance(result, torch.utils.data.DataLoader) assert isinstance(result, CustomDataLoader) assert hasattr(result, 'dummy_kwarg') # Shuffled DataLoader should also work - result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True) + result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True) assert isinstance(result, torch.utils.data.DataLoader) assert isinstance(result, CustomDataLoader) assert hasattr(result, 'dummy_kwarg') @@ -707,7 +707,7 @@ class CustomSampler(torch.utils.data.Sampler): # Should raise an error if existing sampler is being replaced with pytest.raises(MisconfigurationException, match='DistributedSampler'): trainer.auto_add_sampler( - CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True) + CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True) class DistribSamplerCallback(Callback):
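
Taken together, the series changes what `replace_sampler_ddp=True` (the default) does when a `DistributedSampler` is already attached: the DataLoader is now passed through untouched instead of raising `MisconfigurationException`. Below is a minimal usage sketch of that behavior, not code from this PR; the model and dataset are illustrative, and the `Trainer` flags shown (`gpus`, `distributed_backend`) assume the Lightning version this series targets.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class ModelWithOwnSampler(pl.LightningModule):
    """Hypothetical model; only the dataloader wiring matters here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.forward(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32))
        # relies on the process group being initialized, which Lightning does
        # before calling train_dataloader under DDP
        sampler = DistributedSampler(dataset, shuffle=True)
        # shuffle stays False on the DataLoader because the sampler owns it
        return DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=False)


# With this series applied, the DistributedSampler above is kept as-is even though
# replace_sampler_ddp defaults to True; previously this raised MisconfigurationException.
trainer = pl.Trainer(gpus=2, distributed_backend="ddp_spawn", max_steps=10,
                     replace_sampler_ddp=True)
# trainer.fit(ModelWithOwnSampler())  # requires 2 GPUs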