diff --git a/docs/source-fabric/api/utilities.rst b/docs/source-fabric/api/utilities.rst
index b4bd1f564131c..bf23827b6dfe8 100644
--- a/docs/source-fabric/api/utilities.rst
+++ b/docs/source-fabric/api/utilities.rst
@@ -9,3 +9,5 @@ lightning.fabric.utilities
 
 .. autofunction:: lightning.fabric.utilities.seed.seed_everything
 .. autofunction:: lightning.fabric.utilities.seed.pl_worker_init_function
+
+.. autofunction:: lightning.fabric.utilities.data.suggested_max_num_workers
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 91bd87a3dbc8f..0a44a71ba244b 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Enabled the default process group configuration for FSDP's hybrid sharding ([#18583](https://github.com/Lightning-AI/lightning/pull/18583))
 
+- Added `lightning.fabric.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))
+
+
 ### Changed
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
diff --git a/src/lightning/fabric/utilities/__init__.py b/src/lightning/fabric/utilities/__init__.py
index 53bd2a4526612..c706d64463189 100644
--- a/src/lightning/fabric/utilities/__init__.py
+++ b/src/lightning/fabric/utilities/__init__.py
@@ -14,6 +14,7 @@
 """General utilities."""
 
 from lightning.fabric.utilities.apply_func import move_data_to_device  # noqa: F401
+from lightning.fabric.utilities.data import suggested_max_num_workers  # noqa: F401
 from lightning.fabric.utilities.enums import LightningEnum  # noqa: F401
 from lightning.fabric.utilities.rank_zero import (  # noqa: F401
     rank_zero_deprecation,
diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py
index 7b3af926cfda0..60e666cc4e981 100644
--- a/src/lightning/fabric/utilities/data.py
+++ b/src/lightning/fabric/utilities/data.py
@@ -433,3 +433,25 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None:
         set_epoch = getattr(obj, "set_epoch", None)
         if callable(set_epoch):
             set_epoch(epoch)
+
+
+def suggested_max_num_workers(local_world_size: int) -> int:
+    """Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on
+    the number of CPU cores available on the system and the number of distributed processes in the current machine.
+
+    Args:
+        local_world_size: The number of distributed processes running on the current machine. Set this to the number
+            of devices configured in Fabric/Trainer.
+    """
+    if local_world_size < 1:
+        raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.")
+    cpu_count = _num_cpus_available()
+    return max(1, cpu_count // local_world_size)
+
+
+def _num_cpus_available() -> int:
+    if hasattr(os, "sched_getaffinity"):
+        return len(os.sched_getaffinity(0))
+
+    cpu_count = os.cpu_count()
+    return 1 if cpu_count is None else cpu_count
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index d42775efd7c51..74d2c352b146f 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -131,6 +131,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Enabled the default process group configuration for FSDP's hybrid sharding ([#18583](https://github.com/Lightning-AI/lightning/pull/18583))
+
+- Added `lightning.pytorch.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))
+
+
+- Improved the `num_workers` warning to give a more accurate upper limit on the `num_workers` suggestion ([#18591](https://github.com/Lightning-AI/lightning/pull/18591))
+
+
 
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index ede53b999d290..5832559ffd619 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import multiprocessing
 import os
 from dataclasses import dataclass, field
 from typing import Any, Iterable, Optional, Tuple, Union
@@ -25,6 +24,7 @@
     _replace_dunder_methods,
     _set_sampler_epoch,
     has_iterable_dataset,
+    suggested_max_num_workers,
 )
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
@@ -420,11 +420,11 @@ def _check_dataloader_iterable(
     )
 
 
-def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None:
+def _worker_check(trainer: "pl.Trainer", using_spawn: bool, dataloader: object, name: str) -> None:
     if not isinstance(dataloader, DataLoader):
         return
 
-    num_cpus = multiprocessing.cpu_count()
+    upper_bound = suggested_max_num_workers(trainer.num_devices)
 
     # ddp_spawn + num_workers > 0 don't mix! tell the user
     if dataloader.num_workers > 0 and using_spawn:
@@ -442,14 +442,11 @@ def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None:
                 "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
                 " Consider setting num_workers>0 and persistent_workers=True"
             )
-
-    elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
+    elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound:
         # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
         rank_zero_warn(
-            f"The dataloader, {name}, does not have many workers which may be a bottleneck."
-            " Consider increasing the value of the `num_workers` argument`"
-            f" (try {num_cpus} which is the number of cpus on this machine)"
-            " in the `DataLoader` init to improve performance.",
+            f"The '{name}' does not have many workers which may be a bottleneck. Consider increasing the value of the"
+            f" `num_workers` argument to `num_workers={upper_bound}` in the `DataLoader` to improve performance.",
             category=PossibleUserWarning,
         )
 
@@ -507,9 +504,10 @@ def _process_dataloader(
 
     # check the workers
     _worker_check(
-        dataloader,
-        isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
-        f"{stage.dataloader_prefix}_dataloader",
+        trainer=trainer,
+        using_spawn=isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
+        dataloader=dataloader,
+        name=f"{stage.dataloader_prefix}_dataloader",
     )
 
     # add worker_init_fn for correct seeding in worker processes
diff --git a/src/lightning/pytorch/utilities/__init__.py b/src/lightning/pytorch/utilities/__init__.py
index e2e0c0a8d941e..699120da7ea36 100644
--- a/src/lightning/pytorch/utilities/__init__.py
+++ b/src/lightning/pytorch/utilities/__init__.py
@@ -17,6 +17,7 @@
 
 from lightning.fabric.utilities import LightningEnum  # noqa: F401
 from lightning.fabric.utilities import move_data_to_device  # noqa: F401
+from lightning.fabric.utilities import suggested_max_num_workers  # noqa: F401
 from lightning.pytorch.utilities.combined_loader import CombinedLoader  # noqa: F401
 from lightning.pytorch.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from lightning.pytorch.utilities.grads import grad_norm  # noqa: F401
diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py
index 072d7a545677a..5cfcb8e747a85 100644
--- a/tests/tests_fabric/utilities/test_data.py
+++ b/tests/tests_fabric/utilities/test_data.py
@@ -1,13 +1,17 @@
 import contextlib
+import os
 import random
+from unittest import mock
 from unittest.mock import Mock
 
 import numpy as np
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 
+import lightning.fabric
 from lightning.fabric.utilities.data import (
     _get_dataloader_init_args_and_kwargs,
     _replace_dunder_methods,
@@ -17,6 +21,7 @@
     _WrapAttrTag,
     has_iterable_dataset,
     has_len,
+    suggested_max_num_workers,
 )
 from lightning.fabric.utilities.exceptions import MisconfigurationException
 from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
@@ -575,3 +580,63 @@ def test_set_sampler_epoch():
     _set_sampler_epoch(dataloader, 55)
     dataloader.sampler.set_epoch.assert_called_once_with(55)
     dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)
+
+
+@pytest.mark.parametrize(
+    ("cpu_count", "local_world_size", "expected"),
+    [
+        (0, 1, 1),
+        (1, 1, 1),
+        (2, 1, 2),
+        (1, 2, 1),
+        (1, 2, 1),
+        (2, 2, 1),
+        (3, 2, 1),
+        (4, 2, 2),
+        (4, 3, 1),
+        (4, 1, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "affinity",
+    [
+        False,
+        pytest.param(
+            True,
+            marks=pytest.mark.skipif(
+                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
+            ),
+        ),
+    ],
+)
+@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
+def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
+    if affinity:
+        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
+    else:
+        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
+    cpu_count_mock.return_value = cpu_count
+
+    assert suggested_max_num_workers(local_world_size) == expected
+
+
+@pytest.mark.parametrize("invalid", [-1, 0])
+def test_suggested_max_num_workers_input_validation(invalid):
+    with pytest.raises(ValueError, match="should be >= 1"):
+        suggested_max_num_workers(invalid)
+
+
+@pytest.mark.parametrize("cpu_count", [1, 2, 3])
+@pytest.mark.parametrize("local_world_size", [1, 2, 3])
+def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
+    """Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
+    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
+    monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
+    monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
+    monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)
+
+    # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
+    with pytest.warns(UserWarning, match="This DataLoader will create"):
+        DataLoader(range(2), num_workers=(cpu_count + 1))
+    with no_warning_call():
+        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 7c9dc9126dc0e..d9aaff068764a 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -15,6 +15,7 @@
 from io import StringIO
 from re import escape
 from typing import Sized
+from unittest import mock
 from unittest.mock import Mock
 
 import pytest
@@ -22,6 +23,7 @@
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
+import lightning.fabric
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
@@ -30,6 +32,7 @@
     _check_dataloader_iterable,
     _DataHookSelector,
     _DataLoaderSource,
+    _worker_check,
     warning_cache,
 )
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -146,6 +149,40 @@ def test_dataloader_warnings(tmpdir, num_workers):
     trainer.fit(TestSpawnBoringModel(num_workers))
 
 
+@pytest.mark.parametrize(
+    ("num_devices", "num_workers", "cpu_count", "expected_warning"),
+    [
+        (1, 0, 1, False),
+        (8, 0, 1, False),
+        (8, 0, None, False),
+        (1, 1, None, False),
+        (1, 2, 2, False),
+        (1, 1, 8, True),
+        (1, 2, 8, True),
+        (1, 3, 8, False),
+        (4, 1, 8, True),
+        (4, 2, 8, False),
+        (8, 2, 8, False),
+    ],
+)
+@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
+def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
+    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
+    trainer = Mock(spec=Trainer)
+    dataloader = Mock(spec=DataLoader)
+    trainer.num_devices = num_devices
+    dataloader.num_workers = num_workers
+    cpu_count_mock.return_value = cpu_count
+
+    if expected_warning:
+        ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument")
+    else:
+        ctx = no_warning_call(UserWarning)
+
+    with ctx:
+        _worker_check(trainer, using_spawn=False, dataloader=dataloader, name="train_dataloader")
+
+
 def test_update_dataloader_raises():
     with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
         _update_dataloader(object(), object(), mode="fit")
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index a0cd3fdf74cdb..19515cd5399e5 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -532,7 +532,7 @@ def test_warning_on_zero_len_dataloader():
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
 @pytest.mark.parametrize("stage", ["train", "test", "val"])
-@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
+@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
 def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
     """Test that error is raised if dataloader with only a few workers is used."""
     model = BoringModel()
@@ -545,10 +545,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
 
-    with pytest.warns(
-        UserWarning,
-        match=f"The dataloader, {stage}_dataloader, does not have many workers",
-    ):
+    with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
         if stage == "test":
             if ckpt_path in ("specific", "best"):
                 trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
@@ -561,9 +558,9 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
 @pytest.mark.parametrize("stage", ["train", "test", "val"])
-@patch("lightning.pytorch.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
+@patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4)
 def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
-    """Test that error is raised if dataloader with only a few workers is used."""
+    """Test that a warning is emitted if the dataloader only has a few workers."""
 
     class CustomModel(MultiEvalDataLoaderModel):
         def training_step(self, batch, batch_idx):
@@ -584,7 +581,7 @@ def training_step(self, batch, batch_idx):
 
     with pytest.warns(
         UserWarning,
-        match=f"The dataloader, {stage}_dataloader, does not have many workers",
+        match=f"The '{stage}_dataloader' does not have many workers",
     ):
         if stage == "test":
             if ckpt_path in ("specific", "best"):
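For readers of this patch, a minimal usage sketch of the new helper on the caller side. This is illustrative only and not part of the change: the dataset, batch size, and the `local_world_size` value below are placeholders.

from torch.utils.data import DataLoader

from lightning.fabric.utilities.data import suggested_max_num_workers

# Assume 4 training processes on this machine (e.g. devices=4 in Fabric/Trainer) -- placeholder value.
local_world_size = 4
max_workers = suggested_max_num_workers(local_world_size)  # roughly cpu_count // local_world_size, never below 1

# Use the suggestion as an upper bound when configuring the DataLoader.
train_loader = DataLoader(range(1000), batch_size=32, num_workers=max_workers)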