diff --git a/CHANGELOG.md b/CHANGELOG.md
index c05b5f2e28f34..4f31320844543 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -109,6 +109,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
 
+- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))
+
+
 ### Deprecated
 
 - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 59ec40c3df2e8..0480c8023c3f8 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -61,18 +61,36 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn"
         if is_dataloader and not on_windows:
             if dataloader.num_workers > 0 and using_spawn:
-                rank_zero_warn(
-                    'Dataloader(num_workers>0) and ddp_spawn do not mix well!'
-                    ' Your performance might suffer dramatically.'
-                    ' Please consider setting accelerator=ddp to use num_workers > 0'
-                    ' (this is a bottleneck of Python .spawn() and PyTorch'
-                )
+                # check for the persistent_workers attribute, available in PyTorch >= 1.7
+                if hasattr(dataloader, "persistent_workers"):
+                    if not dataloader.persistent_workers:
+                        rank_zero_warn(
+                            'num_workers>0, persistent_workers=False, and accelerator=ddp_spawn'
+                            ' may result in data loading bottlenecks.'
+                            ' Consider setting persistent_workers=True'
+                            ' (this is a limitation of Python .spawn() and PyTorch)'
+                        )
+                else:
+                    rank_zero_warn(
+                        'num_workers>0 and accelerator=ddp_spawn do not mix well'
+                        ' and may result in data loading bottlenecks.'
+                        ' Consider setting accelerator=ddp to use num_workers>0'
+                        ' (this is a limitation of Python .spawn() and PyTorch)'
+                    )
 
             elif dataloader.num_workers == 0 and using_spawn:
-                rank_zero_warn(
-                    'You are using `accelerator=ddp_spawn` with num_workers=0.'
-                    ' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`'
-                )
+                # check for the persistent_workers attribute, available in PyTorch >= 1.7
+                if hasattr(dataloader, "persistent_workers"):
+                    if not dataloader.persistent_workers:
+                        rank_zero_warn(
+                            'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
+                            ' Consider setting num_workers>0 and persistent_workers=True'
+                        )
+                else:
+                    rank_zero_warn(
+                        'accelerator=ddp_spawn and num_workers=0 may result in data loading bottlenecks.'
+                        ' Consider setting accelerator=ddp and num_workers>0'
+                    )
 
             elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
                 num_cpus = multiprocessing.cpu_count()
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index ec7f020faa4c3..382311c107958 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -102,3 +102,25 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,
 @pytest.mark.parametrize("mode", [1, 2])
 def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
     check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)
+
+
+@pytest.mark.parametrize("num_workers", [0, 1])
+def test_dataloader_warnings(num_workers):
+
+    class TestModel(BoringModel):
+
+        def on_train_start(self, *_) -> None:
+            raise SystemExit()
+
+    dl = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
+    if hasattr(dl, "persistent_workers"):
+        if num_workers == 0:
+            warn_str = "Consider setting num_workers>0 and persistent_workers=True"
+        else:
+            warn_str = "Consider setting persistent_workers=True"
+    else:
+        warn_str = "Consider setting accelerator=ddp"
+
+    trainer = Trainer(accelerator="ddp_spawn")
+    with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
+        trainer.fit(TestModel(), dl)
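For reviewers, a minimal sketch of the configuration these new warnings steer users toward, assuming PyTorch >= 1.7 (where DataLoader gained persistent_workers) and this version's Trainer(accelerator=...) API; LitModel is a hypothetical LightningModule stand-in, not part of this PR:

# Sketch only: assumes PyTorch >= 1.7; `LitModel` is a hypothetical LightningModule.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer

dataset = TensorDataset(torch.randn(64, 32))

# DataLoader worker processes are torn down and re-created between epochs unless
# persistent_workers=True keeps them alive (valid only with num_workers > 0),
# which is the respawn overhead the updated warning points at under ddp_spawn.
train_dl = DataLoader(dataset, batch_size=8, num_workers=2, persistent_workers=True)

trainer = Trainer(accelerator="ddp_spawn", gpus=2, max_epochs=1)
# trainer.fit(LitModel(), train_dl)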