3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -109,6 +109,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))


### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
38 changes: 28 additions & 10 deletions pytorch_lightning/trainer/data_loading.py
@@ -61,18 +61,36 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn"
if is_dataloader and not on_windows:
if dataloader.num_workers > 0 and using_spawn:
rank_zero_warn(
'Dataloader(num_workers>0) and ddp_spawn do not mix well!'
' Your performance might suffer dramatically.'
' Please consider setting accelerator=ddp to use num_workers > 0'
' (this is a bottleneck of Python .spawn() and PyTorch'
)
# checks for the attr persistent_workers available in pytorch >= 1.7
if hasattr(dataloader, "persistent_workers"):
if not dataloader.persistent_workers:
rank_zero_warn(
'num_workers>0, persistent_workers=False, and accelerator=ddp_spawn'
' may result in data loading bottlenecks.'
' Consider setting persistent_workers=True'
' (this is a limitation of Python .spawn() and PyTorch)'
)
else:
rank_zero_warn(
'num_workers>0 and accelerator=ddp_spawn do not mix well'
' and may result in data loading bottlenecks.'
' Consider setting accelerator=ddp to use num_workers>0'
' (this is a limitation of Python .spawn() and PyTorch)'
)

elif dataloader.num_workers == 0 and using_spawn:
rank_zero_warn(
'You are using `accelerator=ddp_spawn` with num_workers=0.'
' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`'
)
# checks for the attr persistent_workers available in pytorch >= 1.7
if hasattr(dataloader, "persistent_workers"):
if not dataloader.persistent_workers:
rank_zero_warn(
'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
' Consider setting num_workers>0 and persistent_workers=True'
)
else:
rank_zero_warn(
'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
' Consider setting accelerator=ddp and num_workers>0'
)

elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
num_cpus = multiprocessing.cpu_count()
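For context, a minimal sketch of the configuration the updated warnings steer users toward when staying on `ddp_spawn`: `num_workers>0` together with `persistent_workers=True` (PyTorch >= 1.7). The `RandomDataset`/`LitModel` classes and the `num_processes=2`/`num_workers=2` values below are illustrative stand-ins, not part of this PR:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Tiny stand-in dataset, for illustration only."""

    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# The setup the new warning recommends when keeping accelerator="ddp_spawn":
# num_workers > 0 plus persistent_workers=True, so worker processes survive
# across epochs instead of being re-spawned every time.
train_dl = DataLoader(RandomDataset(32, 64), num_workers=2, persistent_workers=True)

trainer = Trainer(accelerator="ddp_spawn", num_processes=2, max_epochs=1)
trainer.fit(LitModel(), train_dl)
```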
22 changes: 22 additions & 0 deletions tests/trainer/test_data_loading.py
@@ -102,3 +102,25 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)


@pytest.mark.parametrize("num_workers", [0, 1])
def test_dataloader_warnings(num_workers):

class TestModel(BoringModel):

def on_train_start(self, *_) -> None:
raise SystemExit()

dl = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
if hasattr(dl, "persistent_workers"):
if num_workers == 0:
warn_str = "Consider setting num_workers>0 and persistent_workers=True"
else:
warn_str = "Consider setting persistent_workers=True"
else:
warn_str = "Consider setting accelerator=ddp"

trainer = Trainer(accelerator="ddp_spawn")
with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
trainer.fit(TestModel(), dl)