From f553a82aa2134bd637090fa3e92bf438140abcd8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 9 Sep 2022 23:03:00 +0200 Subject: [PATCH 01/13] Attempt to query device count via NVML --- src/lightning_lite/utilities/device_parser.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index 78bf8a9a8c93f..a1c9117831bdc 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,9 +1,11 @@ import multiprocessing import os +from functools import lru_cache from typing import Any, List, MutableSequence, Optional, Tuple, Union - +from ctypes import CDLL, c_int import torch + # TODO(lite): Fix the imports # from lightning_lite.plugins.environments import TorchElasticEnvironment # from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled @@ -286,16 +288,30 @@ def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] +@lru_cache def num_cuda_devices() -> int: - """Returns the number of GPUs available. + """Returns the number of available CUDA devices. Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): + + # Code adapted from Nikita Shulga, @malfet: + # https://github.com/pytorch/pytorch/issues/83973#issuecomment-1238633755 + try: + nvml_h = CDLL("libnvidia-ml.so.1") + except OSError: + # Can't load NVML shared libraries, fall back to torch.cuda return torch.cuda.device_count() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.device_count) + + rc = nvml_h.nvmlInit() + assert rc == 0 + dev_arr = (c_int * 1)(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) + assert rc == 0 + del nvml_h + + return dev_arr[0] def is_cuda_available() -> bool: @@ -304,10 +320,7 @@ def is_cuda_available() -> bool: Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): - return torch.cuda.is_available() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.is_available) + return num_cuda_devices() > 0 # TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved From ba860756f6fd2e03106cacb193a46e59e93aaccc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 9 Sep 2022 23:05:32 +0200 Subject: [PATCH 02/13] remove fork disablers --- src/lightning_lite/utilities/device_parser.py | 11 ++--------- .../strategies/launchers/multiprocessing.py | 9 --------- .../trainer/connectors/accelerator_connector.py | 5 ----- .../strategies/launchers/test_multiprocessing.py | 10 ---------- .../trainer/connectors/test_accelerator_connector.py | 9 --------- 5 files changed, 2 insertions(+), 42 deletions(-) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index a1c9117831bdc..4943eed5595e8 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,14 +1,13 @@ import multiprocessing import os +from ctypes import c_int, CDLL from functools import lru_cache from typing import Any, List, MutableSequence, Optional, Tuple, Union -from ctypes import CDLL, c_int -import torch +import torch # TODO(lite): Fix the imports # from lightning_lite.plugins.environments import TorchElasticEnvironment -# from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.types import _DEVICE @@ -321,9 +320,3 @@ def is_cuda_available() -> bool: if the platform allows it. """ return num_cuda_devices() > 0 - - -# TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved -def _is_forking_disabled() -> bool: - """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" - return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index 31508067abf36..4a008a5351df8 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -67,10 +67,6 @@ def __init__(self, strategy: Strategy, start_method: Literal["spawn", "fork", "f f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - if start_method in ("fork", "forkserver") and _is_forking_disabled(): - raise ValueError( - "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method." - ) @property def is_interactive_compatible(self) -> bool: @@ -286,8 +282,3 @@ def restore(self) -> None: torch.use_deterministic_algorithms(self.use_deterministic_algorithms) torch.backends.cudnn.benchmark = self.cudnn_benchmark _set_rng_states(self.rng_states) - - -def _is_forking_disabled() -> bool: - """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" - return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index d50fc86140da2..9f76b84c0a3ec 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -74,7 +74,6 @@ TPUSpawnStrategy, ) from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES -from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import ( @@ -638,10 +637,6 @@ def _check_strategy_and_fallback(self) -> None: f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this" f" platform. We recommed `Trainer(strategy='ddp_spawn')` instead." ) - if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled(): - raise ValueError( - "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different strategy." - ) if strategy_flag: self._strategy_flag = strategy_flag diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 4859ddba60574..ad3e891ad607f 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from unittest import mock from unittest.mock import ANY, Mock @@ -19,7 +18,6 @@ import torch from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher -from tests_pytorch.helpers.runif import RunIf @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[]) @@ -28,14 +26,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_): _MultiProcessingLauncher(strategy=Mock(), start_method="fork") -@RunIf(skip_windows=True) -@pytest.mark.parametrize("start_method", ["fork", "forkserver"]) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_multiprocessing_launcher_disabled_forking(start_method): - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) - - @pytest.mark.parametrize("start_method", ["spawn", "fork"]) @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp") def test_multiprocessing_launcher_start_method(mp_mock, start_method): diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 6625f191c3190..1640b60644342 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -812,12 +812,3 @@ def test_accelerator_specific_checkpoint_io(*_): def test_ddp_fork_on_unsupported_platform(_, strategy): with pytest.raises(ValueError, match="process forking is not supported on this platform"): Trainer(strategy=strategy) - - -@RunIf(skip_windows=True) -@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy): - """Test there is an error when forking is disabled via the environment variable and the user requests fork.""" - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - Trainer(devices=2, strategy=strategy) From 64d03724e9122522183efb15a6ff8298289a30ae Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 10 Sep 2022 00:36:30 +0200 Subject: [PATCH 03/13] remove unused imports --- src/lightning_lite/utilities/device_parser.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index 4943eed5595e8..1f1217e176ca3 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,5 +1,3 @@ -import multiprocessing -import os from ctypes import c_int, CDLL from functools import lru_cache from typing import Any, List, MutableSequence, Optional, Tuple, Union From e3cf7082d2868026eb222a62a59caf520e263688 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 14 Sep 2022 19:33:42 +0200 Subject: [PATCH 04/13] add upstream implementation --- src/lightning_lite/utilities/device_parser.py | 82 +++++++++++++++---- 1 file changed, 65 insertions(+), 17 deletions(-) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index 1f1217e176ca3..ebfdfa8c23d4b 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,6 +1,8 @@ +import os +import warnings from ctypes import c_int, CDLL from functools import lru_cache -from typing import Any, List, MutableSequence, Optional, Tuple, Union +from typing import Any, List, MutableSequence, Optional, Tuple, Union, Set import torch @@ -8,6 +10,7 @@ # from lightning_lite.plugins.environments import TorchElasticEnvironment from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.types import _DEVICE +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13 def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: @@ -285,30 +288,20 @@ def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] -@lru_cache +@lru_cache(1) def num_cuda_devices() -> int: """Returns the number of available CUDA devices. Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """ - - # Code adapted from Nikita Shulga, @malfet: - # https://github.com/pytorch/pytorch/issues/83973#issuecomment-1238633755 - try: - nvml_h = CDLL("libnvidia-ml.so.1") - except OSError: - # Can't load NVML shared libraries, fall back to torch.cuda + if _TORCH_GREATER_EQUAL_1_13: return torch.cuda.device_count() - rc = nvml_h.nvmlInit() - assert rc == 0 - dev_arr = (c_int * 1)(-1) - rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) - assert rc == 0 - del nvml_h - - return dev_arr[0] + # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879 + # TODO: Remove once minimum supported PyTorch version is 1.13 + nvml_count = _device_count_nvml() + return torch.cuda.device_count() if nvml_count < 0 else nvml_count def is_cuda_available() -> bool: @@ -318,3 +311,58 @@ def is_cuda_available() -> bool: if the platform allows it. """ return num_cuda_devices() > 0 + + +def _parse_visible_devices() -> Set[int]: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879""" + var = os.getenv("CUDA_VISIBLE_DEVICES") + if var is None: + return set(x for x in range(64)) + + def _strtoul(s: str) -> int: + """ Return -1 or integer sequence string starts with """ + if len(s) == 0: + return -1 + for idx, c in enumerate(s): + if not c.isdigit(): + break + if idx + 1 == len(s): + idx += 1 + return int(s[:idx]) if idx > 0 else -1 + + # CUDA_VISIBLE_DEVICES uses something like strtoul + # which makes `1gpu2,2ampere` is equivalent to `1,2` + rc: Set[int] = set() + for elem in var.split(","): + rc.add(_strtoul(elem.strip())) + return rc + + +def _raw_device_count_nvml() -> int: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879""" + from ctypes import CDLL, c_int + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML") + return -1 + dev_arr = (c_int * 1)(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) + if rc != 0: + warnings.warn("Can't get nvml device count") + return -1 + del nvml_h + return dev_arr[0] + + +def _device_count_nvml() -> int: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879""" + try: + raw_cnt = _raw_device_count_nvml() + if raw_cnt <= 0: + return raw_cnt + return len(set(range(raw_cnt)).intersection(_parse_visible_devices())) + except OSError: + return -1 + except AttributeError: + return -1 From 6396bcf85ca4968c5628f8efe00283ec5f2e9e93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Sep 2022 18:36:40 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/utilities/device_parser.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index c1056b257b1a5..4c5bc2b1858d5 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,14 +1,14 @@ import os import warnings from functools import lru_cache -from typing import Any, List, MutableSequence, Optional, Tuple, Union, Set +from typing import Any, List, MutableSequence, Optional, Set, Tuple, Union import torch from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment from lightning_lite.utilities.exceptions import MisconfigurationException -from lightning_lite.utilities.types import _DEVICE from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13 +from lightning_lite.utilities.types import _DEVICE def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: @@ -312,13 +312,13 @@ def is_cuda_available() -> bool: def _parse_visible_devices() -> Set[int]: - """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879""" + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" var = os.getenv("CUDA_VISIBLE_DEVICES") if var is None: - return set(x for x in range(64)) + return {x for x in range(64)} def _strtoul(s: str) -> int: - """ Return -1 or integer sequence string starts with """ + """Return -1 or integer sequence string starts with.""" if len(s) == 0: return -1 for idx, c in enumerate(s): @@ -337,8 +337,9 @@ def _strtoul(s: str) -> int: def _raw_device_count_nvml() -> int: - """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879""" - from ctypes import CDLL, c_int + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + from ctypes import c_int, CDLL + nvml_h = CDLL("libnvidia-ml.so.1") rc = nvml_h.nvmlInit() if rc != 0: @@ -354,7 +355,7 @@ def _raw_device_count_nvml() -> int: def _device_count_nvml() -> int: - """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879""" + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" try: raw_cnt = _raw_device_count_nvml() if raw_cnt <= 0: From bcf45e8b5200eebb3dea54e0c233a94eb8737355 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 18 Sep 2022 23:01:54 +0200 Subject: [PATCH 06/13] update test --- .../utilities/test_device_parser.py | 14 +++------ .../utilities/test_device_parser.py | 31 ------------------- 2 files changed, 5 insertions(+), 40 deletions(-) delete mode 100644 tests/tests_pytorch/utilities/test_device_parser.py diff --git a/tests/tests_lite/utilities/test_device_parser.py b/tests/tests_lite/utilities/test_device_parser.py index 09e35fb61d51c..892f3f3e23d60 100644 --- a/tests/tests_lite/utilities/test_device_parser.py +++ b/tests/tests_lite/utilities/test_device_parser.py @@ -14,7 +14,6 @@ from unittest import mock import pytest -import torch from lightning_lite.utilities import device_parser from lightning_lite.utilities.exceptions import MisconfigurationException @@ -102,16 +101,13 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun device_parser.parse_gpu_ids(devices, include_cuda=True) -@pytest.mark.skipif( - "fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support" -) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser._device_count_nvml", return_value=-1) +@mock.patch("torch.cuda.device_count", return_value=100) def test_num_cuda_devices_without_forking(*_): - """This merely tests that on platforms without fork support our helper functions fall back to the default - implementation for determining cuda availability.""" + """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for determining + CUDA availability.""" assert device_parser.is_cuda_available() - assert device_parser.num_cuda_devices() == 2 + assert device_parser.num_cuda_devices() == 100 @pytest.mark.parametrize("devices", ([3], -1)) diff --git a/tests/tests_pytorch/utilities/test_device_parser.py b/tests/tests_pytorch/utilities/test_device_parser.py deleted file mode 100644 index a4a84892a6e8d..0000000000000 --- a/tests/tests_pytorch/utilities/test_device_parser.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest import mock - -import pytest -import torch - -from lightning_lite.utilities import device_parser - - -@pytest.mark.skipif( - "fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support" -) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) -def test_num_cuda_devices_without_forking(*_): - """This merely tests that on platforms without fork support our helper functions fall back to the default - implementation for determining cuda availability.""" - assert device_parser.is_cuda_available() - assert device_parser.num_cuda_devices() == 2 From 9a882be06d738eb3451a891ad6480f83586a8ff8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Sep 2022 21:03:41 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/utilities/test_device_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_lite/utilities/test_device_parser.py b/tests/tests_lite/utilities/test_device_parser.py index 892f3f3e23d60..af4cf7b9776ca 100644 --- a/tests/tests_lite/utilities/test_device_parser.py +++ b/tests/tests_lite/utilities/test_device_parser.py @@ -104,8 +104,8 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun @mock.patch("lightning_lite.utilities.device_parser._device_count_nvml", return_value=-1) @mock.patch("torch.cuda.device_count", return_value=100) def test_num_cuda_devices_without_forking(*_): - """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for determining - CUDA availability.""" + """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for + determining CUDA availability.""" assert device_parser.is_cuda_available() assert device_parser.num_cuda_devices() == 100 From 01076bad934ff9a628231aae3f5abb095c8a01a8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Sep 2022 18:20:20 +0200 Subject: [PATCH 08/13] wip --- src/lightning_lite/accelerators/cuda.py | 81 +++++++++++++++--- src/lightning_lite/utilities/device_parser.py | 82 ------------------- 2 files changed, 70 insertions(+), 93 deletions(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index f2b412a9713a5..02c1a027c5632 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing -from typing import Dict, List, Optional, Union +import os +import warnings +from functools import lru_cache +from typing import Dict, List, Optional, Union, Set import torch from lightning_lite.accelerators.accelerator import Accelerator -from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13 class CUDAAccelerator(Accelerator): @@ -75,16 +77,20 @@ def _get_all_available_cuda_gpus() -> List[int]: return list(range(num_cuda_devices())) +@lru_cache(1) def num_cuda_devices() -> int: - """Returns the number of GPUs available. + """Returns the number of available CUDA devices. Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): + if _TORCH_GREATER_EQUAL_1_13: return torch.cuda.device_count() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.device_count) + + # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879 + # TODO: Remove once minimum supported PyTorch version is 1.13 + nvml_count = _device_count_nvml() + return torch.cuda.device_count() if nvml_count < 0 else nvml_count def is_cuda_available() -> bool: @@ -93,7 +99,60 @@ def is_cuda_available() -> bool: Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): - return torch.cuda.is_available() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.is_available) + return num_cuda_devices() > 0 + + +def _parse_visible_devices() -> Set[int]: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + var = os.getenv("CUDA_VISIBLE_DEVICES") + if var is None: + return {x for x in range(64)} + + def _strtoul(s: str) -> int: + """Return -1 or integer sequence string starts with.""" + if len(s) == 0: + return -1 + for idx, c in enumerate(s): + if not c.isdigit(): + break + if idx + 1 == len(s): + idx += 1 + return int(s[:idx]) if idx > 0 else -1 + + # CUDA_VISIBLE_DEVICES uses something like strtoul + # which makes `1gpu2,2ampere` is equivalent to `1,2` + rc: Set[int] = set() + for elem in var.split(","): + rc.add(_strtoul(elem.strip())) + return rc + + +def _raw_device_count_nvml() -> int: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + from ctypes import c_int, CDLL + + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML") + return -1 + dev_arr = (c_int * 1)(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) + if rc != 0: + warnings.warn("Can't get nvml device count") + return -1 + del nvml_h + return dev_arr[0] + + +def _device_count_nvml() -> int: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + try: + raw_cnt = _raw_device_count_nvml() + if raw_cnt <= 0: + return raw_cnt + return len(set(range(raw_cnt)).intersection(_parse_visible_devices())) + except OSError: + return -1 + except AttributeError: + return -1 diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index b40508934f0df..fd7a6bf0edc43 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -4,7 +4,6 @@ from lightning_lite.accelerators.mps import _get_all_available_mps_gpus from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment from lightning_lite.utilities.exceptions import MisconfigurationException -from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13 from lightning_lite.utilities.types import _DEVICE @@ -210,84 +209,3 @@ def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: if tpu_cores in ("1", "8"): return int(tpu_cores) return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] - - -@lru_cache(1) -def num_cuda_devices() -> int: - """Returns the number of available CUDA devices. - - Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, - if the platform allows it. - """ - if _TORCH_GREATER_EQUAL_1_13: - return torch.cuda.device_count() - - # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879 - # TODO: Remove once minimum supported PyTorch version is 1.13 - nvml_count = _device_count_nvml() - return torch.cuda.device_count() if nvml_count < 0 else nvml_count - - -def is_cuda_available() -> bool: - """Returns a bool indicating if CUDA is currently available. - - Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, - if the platform allows it. - """ - return num_cuda_devices() > 0 - - -def _parse_visible_devices() -> Set[int]: - """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" - var = os.getenv("CUDA_VISIBLE_DEVICES") - if var is None: - return {x for x in range(64)} - - def _strtoul(s: str) -> int: - """Return -1 or integer sequence string starts with.""" - if len(s) == 0: - return -1 - for idx, c in enumerate(s): - if not c.isdigit(): - break - if idx + 1 == len(s): - idx += 1 - return int(s[:idx]) if idx > 0 else -1 - - # CUDA_VISIBLE_DEVICES uses something like strtoul - # which makes `1gpu2,2ampere` is equivalent to `1,2` - rc: Set[int] = set() - for elem in var.split(","): - rc.add(_strtoul(elem.strip())) - return rc - - -def _raw_device_count_nvml() -> int: - """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" - from ctypes import c_int, CDLL - - nvml_h = CDLL("libnvidia-ml.so.1") - rc = nvml_h.nvmlInit() - if rc != 0: - warnings.warn("Can't initialize NVML") - return -1 - dev_arr = (c_int * 1)(-1) - rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) - if rc != 0: - warnings.warn("Can't get nvml device count") - return -1 - del nvml_h - return dev_arr[0] - - -def _device_count_nvml() -> int: - """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" - try: - raw_cnt = _raw_device_count_nvml() - if raw_cnt <= 0: - return raw_cnt - return len(set(range(raw_cnt)).intersection(_parse_visible_devices())) - except OSError: - return -1 - except AttributeError: - return -1 From 7c92ae7c821b5f67830a79ab0928d98215c563e9 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Sep 2022 18:26:45 +0200 Subject: [PATCH 09/13] resolve merge conflicts --- src/lightning_lite/accelerators/cuda.py | 2 +- tests/tests_lite/accelerators/test_cpu.py | 9 ++++++++- tests/tests_lite/accelerators/test_cuda.py | 11 ++++++++++- .../tests_lite/utilities/test_device_parser.py | 18 ------------------ 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 02c1a027c5632..9179a0015548c 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -14,7 +14,7 @@ import os import warnings from functools import lru_cache -from typing import Dict, List, Optional, Union, Set +from typing import Dict, List, Optional, Set, Union import torch diff --git a/tests/tests_lite/accelerators/test_cpu.py b/tests/tests_lite/accelerators/test_cpu.py index f6ccba3560b96..012efe9f9aa26 100644 --- a/tests/tests_lite/accelerators/test_cpu.py +++ b/tests/tests_lite/accelerators/test_cpu.py @@ -15,7 +15,7 @@ import pytest import torch -from lightning_lite.accelerators.cpu import CPUAccelerator +from lightning_lite.accelerators.cpu import CPUAccelerator, parse_cpu_cores def test_auto_device_count(): @@ -41,3 +41,10 @@ def test_init_device_with_wrong_device_type(): ) def test_get_parallel_devices(devices, expected): assert CPUAccelerator.get_parallel_devices(devices) == expected + + +@pytest.mark.parametrize("devices", ([3], -1)) +def test_invalid_devices_with_cpu_accelerator(devices): + """Test invalid device flag raises MisconfigurationException.""" + with pytest.raises(TypeError, match="should be an int > 0"): + parse_cpu_cores(devices) diff --git a/tests/tests_lite/accelerators/test_cuda.py b/tests/tests_lite/accelerators/test_cuda.py index 1c2c7a8ac33d8..f3d25f9bfb750 100644 --- a/tests/tests_lite/accelerators/test_cuda.py +++ b/tests/tests_lite/accelerators/test_cuda.py @@ -17,7 +17,7 @@ import torch from tests_lite.helpers.runif import RunIf -from lightning_lite.accelerators.cuda import CUDAAccelerator +from lightning_lite.accelerators.cuda import CUDAAccelerator, is_cuda_available, num_cuda_devices @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2) @@ -51,3 +51,12 @@ def test_get_parallel_devices(devices, expected): def test_set_cuda_device(set_device_mock): CUDAAccelerator().setup_device(torch.device("cuda", 1)) set_device_mock.assert_called_once_with(torch.device("cuda", 1)) + + +@mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1) +@mock.patch("torch.cuda.device_count", return_value=100) +def test_num_cuda_devices_without_forking(*_): + """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for + determining CUDA availability.""" + assert is_cuda_available() + assert num_cuda_devices() == 100 diff --git a/tests/tests_lite/utilities/test_device_parser.py b/tests/tests_lite/utilities/test_device_parser.py index 7d73cc414de3d..3c78e30f75d5d 100644 --- a/tests/tests_lite/utilities/test_device_parser.py +++ b/tests/tests_lite/utilities/test_device_parser.py @@ -15,8 +15,6 @@ import pytest -from lightning_lite.accelerators.cpu import parse_cpu_cores -from lightning_lite.accelerators.cuda import is_cuda_available, num_cuda_devices from lightning_lite.utilities import device_parser from lightning_lite.utilities.exceptions import MisconfigurationException @@ -86,19 +84,3 @@ def test_parse_gpu_fail_on_non_existent_id_2(_): def test_parse_gpu_returns_none_when_no_devices_are_available(_, devices): with pytest.raises(MisconfigurationException): device_parser.parse_gpu_ids(devices, include_cuda=True) - - -@mock.patch("lightning_lite.utilities.device_parser._device_count_nvml", return_value=-1) -@mock.patch("torch.cuda.device_count", return_value=100) -def test_num_cuda_devices_without_forking(*_): - """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for - determining CUDA availability.""" - assert device_parser.is_cuda_available() - assert device_parser.num_cuda_devices() == 100 - - -@pytest.mark.parametrize("devices", ([3], -1)) -def test_invalid_devices_with_cpu_accelerator(devices): - """Test invalid device flag raises MisconfigurationException.""" - with pytest.raises(TypeError, match="should be an int > 0"): - parse_cpu_cores(devices) From c17a2862494bd7cc116b9bda771a6c43bb3caba9 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Sep 2022 18:28:15 +0200 Subject: [PATCH 10/13] merge conflicts --- src/lightning_lite/utilities/device_parser.py | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index fd7a6bf0edc43..9c0feec8f7275 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, List, MutableSequence, Optional, Tuple, Union from lightning_lite.accelerators.cuda import _get_all_available_cuda_gpus @@ -187,25 +200,3 @@ def _check_data_type(device_ids: Any) -> None: raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.") elif type(device_ids) not in (int, str): raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.") - - -def _tpu_cores_valid(tpu_cores: Any) -> bool: - # allow 1 or 8 cores - if tpu_cores in (1, 8, None): - return True - - # allow picking 1 of 8 indexes - if isinstance(tpu_cores, (list, tuple, set)): - has_1_tpu_idx = len(tpu_cores) == 1 - is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8 - - is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx - return is_valid_tpu_core_choice - - return False - - -def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: - if tpu_cores in ("1", "8"): - return int(tpu_cores) - return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] From 9ec7ac64e299ffdd4b48da46fbc86bc7cd99e093 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Sep 2022 18:29:25 +0200 Subject: [PATCH 11/13] rename --- tests/tests_lite/accelerators/test_cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/accelerators/test_cuda.py b/tests/tests_lite/accelerators/test_cuda.py index f3d25f9bfb750..94e51d77358e6 100644 --- a/tests/tests_lite/accelerators/test_cuda.py +++ b/tests/tests_lite/accelerators/test_cuda.py @@ -55,7 +55,7 @@ def test_set_cuda_device(set_device_mock): @mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1) @mock.patch("torch.cuda.device_count", return_value=100) -def test_num_cuda_devices_without_forking(*_): +def test_num_cuda_devices_without_nvml(*_): """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for determining CUDA availability.""" assert is_cuda_available() From 9e7ebc8cfea011b7ebf994bdb216fb1026ee044c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Sep 2022 18:32:44 +0200 Subject: [PATCH 12/13] changelog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index fe7cdeceff1eb..badd62fcd0539 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300)) +- Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631)) + + ### Deprecated - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) From 71ae9ab74fc5a017cdbfdf0ed68e974198594760 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 22 Sep 2022 11:23:16 +0200 Subject: [PATCH 13/13] remove it also from lite --- .../strategies/launchers/multiprocessing.py | 9 --------- .../strategies/launchers/test_multiprocessing.py | 9 --------- tests/tests_lite/test_connector.py | 9 --------- 3 files changed, 27 deletions(-) diff --git a/src/lightning_lite/strategies/launchers/multiprocessing.py b/src/lightning_lite/strategies/launchers/multiprocessing.py index d416efee56185..20cf765f76187 100644 --- a/src/lightning_lite/strategies/launchers/multiprocessing.py +++ b/src/lightning_lite/strategies/launchers/multiprocessing.py @@ -63,10 +63,6 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - if start_method in ("fork", "forkserver") and _is_forking_disabled(): - raise ValueError( - "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method." - ) @property def is_interactive_compatible(self) -> bool: @@ -170,8 +166,3 @@ def restore(self) -> None: torch.use_deterministic_algorithms(self.use_deterministic_algorithms) torch.backends.cudnn.benchmark = self.cudnn_benchmark _set_rng_states(self.rng_states) - - -def _is_forking_disabled() -> bool: - """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" - return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/tests/tests_lite/strategies/launchers/test_multiprocessing.py b/tests/tests_lite/strategies/launchers/test_multiprocessing.py index 70b45763fe2df..fef19f06715ab 100644 --- a/tests/tests_lite/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_lite/strategies/launchers/test_multiprocessing.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from unittest import mock from unittest.mock import ANY, Mock @@ -35,14 +34,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_): _MultiProcessingLauncher(strategy=Mock(), start_method="fork") -@RunIf(skip_windows=True) -@pytest.mark.parametrize("start_method", ["fork", "forkserver"]) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_multiprocessing_launcher_disabled_forking(start_method): - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) - - @pytest.mark.parametrize("start_method", ["spawn", "fork"]) @mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp") def test_multiprocessing_launcher_start_method(mp_mock, start_method): diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index 09b0f359ec0d8..258934f1ca847 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -692,12 +692,3 @@ def test_gpu_accelerator_no_gpu_backend_found_error(*_): def test_ddp_fork_on_unsupported_platform(_, strategy): with pytest.raises(ValueError, match="process forking is not supported on this platform"): _Connector(strategy=strategy) - - -@RunIf(skip_windows=True) -@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy): - """Test there is an error when forking is disabled via the environment variable and the user requests fork.""" - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - _Connector(devices=2, strategy=strategy)