
Commit 57f5268

awaelchli and carmocca authored
Improve the suggested num_workers warning (#18591)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent d13ad1e · commit 57f5268

10 files changed: +153 −20 lines

docs/source-fabric/api/utilities.rst
Lines changed: 2 additions & 0 deletions

@@ -9,3 +9,5 @@ lightning.fabric.utilities
 .. autofunction:: lightning.fabric.utilities.seed.seed_everything
 
 .. autofunction:: lightning.fabric.utilities.seed.pl_worker_init_function
+
+.. autofunction:: lightning.fabric.utilities.data.suggested_max_num_workers

src/lightning/fabric/CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled the default process group configuration for FSDP's hybrid sharding ([#18583](https://github.com/Lightning-AI/lightning/pull/18583))
 
 
+- Added `lightning.fabric.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))
+
+
 ### Changed
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

src/lightning/fabric/utilities/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 """General utilities."""
 
 from lightning.fabric.utilities.apply_func import move_data_to_device  # noqa: F401
+from lightning.fabric.utilities.data import suggested_max_num_workers  # noqa: F401
 from lightning.fabric.utilities.enums import LightningEnum  # noqa: F401
 from lightning.fabric.utilities.rank_zero import (  # noqa: F401
     rank_zero_deprecation,

src/lightning/fabric/utilities/data.py
Lines changed: 22 additions & 0 deletions

@@ -433,3 +433,25 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None:
         set_epoch = getattr(obj, "set_epoch", None)
         if callable(set_epoch):
             set_epoch(epoch)
+
+
+def suggested_max_num_workers(local_world_size: int) -> int:
+    """Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on
+    the number of CPU cores available on the system and the number of distributed processes in the current machine.
+
+    Args:
+        local_world_size: The number of distributed processes running on the current machine. Set this to the number
+            of devices configured in Fabric/Trainer.
+    """
+    if local_world_size < 1:
+        raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.")
+    cpu_count = _num_cpus_available()
+    return max(1, cpu_count // local_world_size)
+
+
+def _num_cpus_available() -> int:
+    if hasattr(os, "sched_getaffinity"):
+        return len(os.sched_getaffinity(0))
+
+    cpu_count = os.cpu_count()
+    return 1 if cpu_count is None else cpu_count
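
A minimal usage sketch (not part of the diff) of the new utility when building a DataLoader by hand; `devices_per_node = 4` is an assumed value standing in for the number of devices configured in Fabric/Trainer:

    from torch.utils.data import DataLoader

    from lightning.fabric.utilities.data import suggested_max_num_workers

    devices_per_node = 4  # assumption: four training processes share this machine
    loader = DataLoader(
        range(1000),
        batch_size=32,
        # cap the worker count so the node's processes together don't oversubscribe the CPUs
        num_workers=suggested_max_num_workers(local_world_size=devices_per_node),
    )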

src/lightning/pytorch/CHANGELOG.md
Lines changed: 7 additions & 0 deletions

@@ -131,6 +131,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled the default process group configuration for FSDP's hybrid sharding ([#18583](https://github.com/Lightning-AI/lightning/pull/18583))
 
 
+
+- Added `lightning.pytorch.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))
+
+
+- Improved the `num_workers` warning to give a more accurate upper limit on the `num_workers` suggestion ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

src/lightning/pytorch/trainer/connectors/data_connector.py
Lines changed: 10 additions & 12 deletions

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import multiprocessing
 import os
 from dataclasses import dataclass, field
 from typing import Any, Iterable, Optional, Tuple, Union
@@ -25,6 +24,7 @@
     _replace_dunder_methods,
     _set_sampler_epoch,
     has_iterable_dataset,
+    suggested_max_num_workers,
 )
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
@@ -420,11 +420,11 @@ def _check_dataloader_iterable(
         )
 
 
-def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None:
+def _worker_check(trainer: "pl.Trainer", using_spawn: bool, dataloader: object, name: str) -> None:
     if not isinstance(dataloader, DataLoader):
         return
 
-    num_cpus = multiprocessing.cpu_count()
+    upper_bound = suggested_max_num_workers(trainer.num_devices)
 
     # ddp_spawn + num_workers > 0 don't mix! tell the user
     if dataloader.num_workers > 0 and using_spawn:
@@ -442,14 +442,11 @@ def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None:
             "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
             " Consider setting num_workers>0 and persistent_workers=True"
         )
-
-    elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
+    elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound:
         # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
         rank_zero_warn(
-            f"The dataloader, {name}, does not have many workers which may be a bottleneck."
-            " Consider increasing the value of the `num_workers` argument`"
-            f" (try {num_cpus} which is the number of cpus on this machine)"
-            " in the `DataLoader` init to improve performance.",
+            f"The '{name}' does not have many workers which may be a bottleneck. Consider increasing the value of the"
+            f" `num_workers` argument` to `num_workers={upper_bound}` in the `DataLoader` to improve performance.",
             category=PossibleUserWarning,
         )
 
@@ -507,9 +504,10 @@ def _process_dataloader(
 
     # check the workers
     _worker_check(
-        dataloader,
-        isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
-        f"{stage.dataloader_prefix}_dataloader",
+        trainer=trainer,
+        using_spawn=isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
+        dataloader=dataloader,
+        name=f"{stage.dataloader_prefix}_dataloader",
     )
 
     # add worker_init_fn for correct seeding in worker processes
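
As a side note (illustrative code, not part of the commit), the rewritten condition warns only when at most two workers are configured while the per-process upper bound would allow more; `upper_bound` below stands for the value returned by `suggested_max_num_workers(trainer.num_devices)`:

    def would_warn(num_workers: int, upper_bound: int) -> bool:
        # warn when 2 or fewer workers are configured but this process could use more
        return num_workers <= 2 < upper_bound or num_workers < 2 <= upper_bound

    assert would_warn(num_workers=1, upper_bound=8)      # few workers, many spare CPUs
    assert would_warn(num_workers=1, upper_bound=2)      # below even a small bound
    assert not would_warn(num_workers=2, upper_bound=2)  # already at the suggested maximum
    assert not would_warn(num_workers=3, upper_bound=8)  # more than 2 workers never warns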

src/lightning/pytorch/utilities/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 
 from lightning.fabric.utilities import LightningEnum  # noqa: F401
 from lightning.fabric.utilities import move_data_to_device  # noqa: F401
+from lightning.fabric.utilities import suggested_max_num_workers  # noqa: F401
 from lightning.pytorch.utilities.combined_loader import CombinedLoader  # noqa: F401
 from lightning.pytorch.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from lightning.pytorch.utilities.grads import grad_norm  # noqa: F401

tests/tests_fabric/utilities/test_data.py
Lines changed: 65 additions & 0 deletions

@@ -1,13 +1,17 @@
 import contextlib
+import os
 import random
+from unittest import mock
 from unittest.mock import Mock
 
 import numpy as np
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 
+import lightning.fabric
 from lightning.fabric.utilities.data import (
     _get_dataloader_init_args_and_kwargs,
     _replace_dunder_methods,
@@ -17,6 +21,7 @@
     _WrapAttrTag,
     has_iterable_dataset,
     has_len,
+    suggested_max_num_workers,
 )
 from lightning.fabric.utilities.exceptions import MisconfigurationException
 from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
@@ -575,3 +580,63 @@ def test_set_sampler_epoch():
     _set_sampler_epoch(dataloader, 55)
     dataloader.sampler.set_epoch.assert_called_once_with(55)
     dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)
+
+
+@pytest.mark.parametrize(
+    ("cpu_count", "local_world_size", "expected"),
+    [
+        (0, 1, 1),
+        (1, 1, 1),
+        (2, 1, 2),
+        (1, 2, 1),
+        (1, 2, 1),
+        (2, 2, 1),
+        (3, 2, 1),
+        (4, 2, 2),
+        (4, 3, 1),
+        (4, 1, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "affinity",
+    [
+        False,
+        pytest.param(
+            True,
+            marks=pytest.mark.skipif(
+                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
+            ),
+        ),
+    ],
+)
+@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
+def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
+    if affinity:
+        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
+    else:
+        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
+    cpu_count_mock.return_value = cpu_count
+
+    assert suggested_max_num_workers(local_world_size) == expected
+
+
+@pytest.mark.parametrize("invalid", [-1, 0])
+def test_suggested_max_num_workers_input_validation(invalid):
+    with pytest.raises(ValueError, match="should be >= 1"):
+        suggested_max_num_workers(invalid)
+
+
+@pytest.mark.parametrize("cpu_count", [1, 2, 3])
+@pytest.mark.parametrize("local_world_size", [1, 2, 3])
+def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
+    """Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
+    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
+    monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
+    monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
+    monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)
+
+    # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
+    with pytest.warns(UserWarning, match="This DataLoader will create"):
+        DataLoader(range(2), num_workers=(cpu_count + 1))
+    with no_warning_call():
+        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))

tests/tests_pytorch/trainer/connectors/test_data_connector.py
Lines changed: 37 additions & 0 deletions

@@ -15,13 +15,15 @@
 from io import StringIO
 from re import escape
 from typing import Sized
+from unittest import mock
 from unittest.mock import Mock
 
 import pytest
 from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
+import lightning.fabric
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
@@ -30,6 +32,7 @@
     _check_dataloader_iterable,
     _DataHookSelector,
     _DataLoaderSource,
+    _worker_check,
     warning_cache,
 )
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -146,6 +149,40 @@ def test_dataloader_warnings(tmpdir, num_workers):
     trainer.fit(TestSpawnBoringModel(num_workers))
 
 
+@pytest.mark.parametrize(
+    ("num_devices", "num_workers", "cpu_count", "expected_warning"),
+    [
+        (1, 0, 1, False),
+        (8, 0, 1, False),
+        (8, 0, None, False),
+        (1, 1, None, False),
+        (1, 2, 2, False),
+        (1, 1, 8, True),
+        (1, 2, 8, True),
+        (1, 3, 8, False),
+        (4, 1, 8, True),
+        (4, 2, 8, False),
+        (8, 2, 8, False),
+    ],
+)
+@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
+def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
+    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
+    trainer = Mock(spec=Trainer)
+    dataloader = Mock(spec=DataLoader)
+    trainer.num_devices = num_devices
+    dataloader.num_workers = num_workers
+    cpu_count_mock.return_value = cpu_count
+
+    if expected_warning:
+        ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument`")
+    else:
+        ctx = no_warning_call(UserWarning)
+
+    with ctx:
+        _worker_check(trainer, using_spawn=False, dataloader=dataloader, name="train_dataloader")
+
+
 def test_update_dataloader_raises():
     with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
         _update_dataloader(object(), object(), mode="fit")

tests/tests_pytorch/trainer/test_dataloaders.py
Lines changed: 5 additions & 8 deletions

@@ -532,7 +532,7 @@ def test_warning_on_zero_len_dataloader():
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
 @pytest.mark.parametrize("stage", ["train", "test", "val"])
-@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
+@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
 def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
     """Test that error is raised if dataloader with only a few workers is used."""
     model = BoringModel()
@@ -545,10 +545,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
 
-    with pytest.warns(
-        UserWarning,
-        match=f"The dataloader, {stage}_dataloader, does not have many workers",
-    ):
+    with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
         if stage == "test":
             if ckpt_path in ("specific", "best"):
                 trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
@@ -561,9 +558,9 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
 @pytest.mark.parametrize("stage", ["train", "test", "val"])
-@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
+@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
 def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
-    """Test that error is raised if dataloader with only a few workers is used."""
+    """Test that a warning is emitted if the dataloader only has a few workers."""
 
     class CustomModel(MultiEvalDataLoaderModel):
         def training_step(self, batch, batch_idx):
@@ -584,7 +581,7 @@ def training_step(self, batch, batch_idx):
 
     with pytest.warns(
         UserWarning,
-        match=f"The dataloader, {stage}_dataloader, does not have many workers",
+        match=f"The '{stage}_dataloader' does not have many workers",
     ):
         if stage == "test":
             if ckpt_path in ("specific", "best"):