From 3c936a118c4913dd5cf6543ddd5f22bb3af7c751 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 10 Feb 2021 23:20:54 +0000
Subject: [PATCH] Use property in connector for sampler

---
 pytorch_lightning/accelerators/accelerator_connector.py | 4 ++++
 pytorch_lightning/trainer/data_loading.py               | 6 +++---
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index af215f6accf27..8c941878fb348 100755
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -256,6 +256,10 @@ def use_ddp2(self):
     def use_horovod(self):
         return self._distrib_type == DistributedType.HOROVOD
 
+    @property
+    def is_distributed(self):
+        return self.use_ddp or self.use_ddp2 or self.use_horovod or self.on_tpu
+
     @property
     def num_gpus(self) -> int:
         gpus = self.parallel_device_ids
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 48684595847ef..352d2e1ce0429 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -97,9 +97,9 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
         if not is_dataloader or is_iterable_ds:
             return dataloader
 
-        is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
-
-        need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
+        need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance(
+            dataloader.sampler, DistributedSampler
+        )
         if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
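
For context, here is a minimal sketch of the pattern this patch introduces: the inline `use_ddp or use_ddp2 or use_horovod or ...` check is consolidated into a single `is_distributed` property on the connector, so callers query one flag instead of re-deriving it. The class and helper below are illustrative stand-ins, not Lightning's actual classes.

```python
from torch.utils.data import DataLoader, DistributedSampler


class AcceleratorConnector:
    """Simplified stand-in for Lightning's accelerator connector."""

    def __init__(self, use_ddp=False, use_ddp2=False, use_horovod=False, on_tpu=False):
        self.use_ddp = use_ddp
        self.use_ddp2 = use_ddp2
        self.use_horovod = use_horovod
        self.on_tpu = on_tpu
        self.replace_sampler_ddp = True

    @property
    def is_distributed(self) -> bool:
        # Single source of truth for "are we running distributed?",
        # mirroring the property added in the patch above.
        return self.use_ddp or self.use_ddp2 or self.use_horovod or self.on_tpu


def needs_distributed_sampler(connector: AcceleratorConnector, dataloader: DataLoader) -> bool:
    # Hypothetical helper mirroring the updated check in `auto_add_sampler`:
    # replace the sampler only when running distributed and the loader does
    # not already carry a DistributedSampler.
    return (
        connector.replace_sampler_ddp
        and connector.is_distributed
        and not isinstance(dataloader.sampler, DistributedSampler)
    )
```

The design benefit is the usual one for such refactors: when a new distributed backend is added later, only `is_distributed` needs to change, not every call site that previously spelled out the four conditions.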