Skip to content

Commit 222653d

Browse files
authored
Use property in connector for sampler (#5913)
1 parent 135c236 commit 222653d

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ def use_ddp2(self):
256256
def use_horovod(self):
257257
return self._distrib_type == DistributedType.HOROVOD
258258

259+
@property
260+
def is_distributed(self):
261+
return self.use_ddp or self.use_ddp2 or self.use_horovod or self.on_tpu
262+
259263
@property
260264
def num_gpus(self) -> int:
261265
gpus = self.parallel_device_ids

pytorch_lightning/trainer/data_loading.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
9797
if not is_dataloader or is_iterable_ds:
9898
return dataloader
9999

100-
is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
101-
102-
need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)
100+
need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance(
101+
dataloader.sampler, DistributedSampler
102+
)
103103
if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
104104
if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
105105
raise MisconfigurationException(

0 commit comments

Comments
 (0)