
Commit e35192d

s-rog and carmocca authored
Update DataLoader.persistent_workers warnings in ddp_spawn (#6762)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 5e4dfd7 commit e35192d

File tree

3 files changed: +53 −10 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
 
 
+- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))
+
+
 ### Deprecated
 
 - Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
pytorch_lightning/trainer/data_loading.py

Lines changed: 28 additions & 10 deletions
@@ -61,18 +61,36 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn"
         if is_dataloader and not on_windows:
             if dataloader.num_workers > 0 and using_spawn:
-                rank_zero_warn(
-                    'Dataloader(num_workers>0) and ddp_spawn do not mix well!'
-                    ' Your performance might suffer dramatically.'
-                    ' Please consider setting accelerator=ddp to use num_workers > 0'
-                    ' (this is a bottleneck of Python .spawn() and PyTorch'
-                )
+                # checks for the attr persistent_workers available in pytorch >= 1.7
+                if hasattr(dataloader, "persistent_workers"):
+                    if not dataloader.persistent_workers:
+                        rank_zero_warn(
+                            'num_workers>0, persistent_workers=False, and accelerator=ddp_spawn'
+                            ' may result in data loading bottlenecks.'
+                            ' Consider setting persistent_workers=True'
+                            ' (this is a limitation of Python .spawn() and PyTorch)'
+                        )
+                else:
+                    rank_zero_warn(
+                        'num_workers>0 and accelerator=ddp_spawn do not mix well'
+                        ' and may result in data loading bottlenecks.'
+                        ' Consider setting accelerator=ddp to use num_workers>0'
+                        ' (this is a limitation of Python .spawn() and PyTorch)'
+                    )
 
             elif dataloader.num_workers == 0 and using_spawn:
-                rank_zero_warn(
-                    'You are using `accelerator=ddp_spawn` with num_workers=0.'
-                    ' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`'
-                )
+                # checks for the attr persistent_workers available in pytorch >= 1.7
+                if hasattr(dataloader, "persistent_workers"):
+                    if not dataloader.persistent_workers:
+                        rank_zero_warn(
+                            'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
+                            ' Consider setting num_workers>0 and persistent_workers=True'
+                        )
+                else:
+                    rank_zero_warn(
+                        'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
+                        ' Consider setting accelerator=ddp and set num_workers>0'
+                    )
 
             elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
                 num_cpus = multiprocessing.cpu_count()
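
To make the new recommendation concrete, here is a minimal sketch (not part of the diff) of the dataloader setup the warnings point users toward when training with accelerator="ddp_spawn"; the dataset, batch size, and commented-out fit call are placeholders:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer

# placeholder dataset; any map-style dataset works the same way
dataset = TensorDataset(torch.randn(64, 32))

# persistent_workers=True (PyTorch >= 1.7) keeps worker processes alive between
# epochs, avoiding the per-epoch re-spawn cost that the warning describes
train_loader = DataLoader(dataset, batch_size=8, num_workers=2, persistent_workers=True)

trainer = Trainer(accelerator="ddp_spawn")  # Trainer API as of this commit
# trainer.fit(model, train_loader)  # `model` would be a LightningModule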

tests/trainer/test_data_loading.py

Lines changed: 22 additions & 0 deletions
@@ -102,3 +102,25 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,
 @pytest.mark.parametrize("mode", [1, 2])
 def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
     check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)
+
+
+@pytest.mark.parametrize("num_workers", [0, 1])
+def test_dataloader_warnings(num_workers):
+
+    class TestModel(BoringModel):
+
+        def on_train_start(self, *_) -> None:
+            raise SystemExit()
+
+    dl = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
+    if hasattr(dl, "persistent_workers"):
+        if num_workers == 0:
+            warn_str = "Consider setting num_workers>0 and persistent_workers=True"
+        else:
+            warn_str = "Consider setting persistent_workers=True"
+    else:
+        warn_str = "Consider setting accelerator=ddp"
+
+    trainer = Trainer(accelerator="ddp_spawn")
+    with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
+        trainer.fit(TestModel(), dl)
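
For readers wondering why the test branches on hasattr: `persistent_workers` only exists on DataLoader instances built with PyTorch 1.7 or newer (matching the comment in the diff above), so the attribute probe doubles as a version check. A rough, standalone sketch of that probe (the print is illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset

dl = DataLoader(TensorDataset(torch.randn(8, 4)))

# same feature probe as the test: True on PyTorch >= 1.7, False on older releases
has_persistent_workers = hasattr(dl, "persistent_workers")
print(torch.__version__, has_persistent_workers)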
