diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index f2b412a9713a5..9179a0015548c 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing -from typing import Dict, List, Optional, Union +import os +import warnings +from functools import lru_cache +from typing import Dict, List, Optional, Set, Union import torch from lightning_lite.accelerators.accelerator import Accelerator -from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13 class CUDAAccelerator(Accelerator): @@ -75,16 +77,20 @@ def _get_all_available_cuda_gpus() -> List[int]: return list(range(num_cuda_devices())) +@lru_cache(1) def num_cuda_devices() -> int: - """Returns the number of GPUs available. + """Returns the number of available CUDA devices. Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, if the platform allows it. 
""" - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): + if _TORCH_GREATER_EQUAL_1_13: return torch.cuda.device_count() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.device_count) + + # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879 + # TODO: Remove once minimum supported PyTorch version is 1.13 + nvml_count = _device_count_nvml() + return torch.cuda.device_count() if nvml_count < 0 else nvml_count def is_cuda_available() -> bool: @@ -93,7 +99,60 @@ def is_cuda_available() -> bool: Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): - return torch.cuda.is_available() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.is_available) + return num_cuda_devices() > 0 + + +def _parse_visible_devices() -> Set[int]: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + var = os.getenv("CUDA_VISIBLE_DEVICES") + if var is None: + return {x for x in range(64)} + + def _strtoul(s: str) -> int: + """Return the integer the string starts with, or -1 if it does not start with a digit.""" + if len(s) == 0: + return -1 + for idx, c in enumerate(s): + if not c.isdigit(): + break + if idx + 1 == len(s): + idx += 1 + return int(s[:idx]) if idx > 0 else -1 + + # CUDA_VISIBLE_DEVICES uses something like strtoul + # which makes `1gpu2,2ampere` equivalent to `1,2` + rc: Set[int] = set() + for elem in var.split(","): + rc.add(_strtoul(elem.strip())) + return rc + + +def _raw_device_count_nvml() -> int: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + from ctypes import c_int, CDLL + + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize 
NVML") + return -1 + dev_arr = (c_int * 1)(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) + if rc != 0: + warnings.warn("Can't get nvml device count") + return -1 + del nvml_h + return dev_arr[0] + + +def _device_count_nvml() -> int: + """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" + try: + raw_cnt = _raw_device_count_nvml() + if raw_cnt <= 0: + return raw_cnt + return len(set(range(raw_cnt)).intersection(_parse_visible_devices())) + except OSError: + return -1 + except AttributeError: + return -1 diff --git a/src/lightning_lite/strategies/launchers/multiprocessing.py b/src/lightning_lite/strategies/launchers/multiprocessing.py index d416efee56185..20cf765f76187 100644 --- a/src/lightning_lite/strategies/launchers/multiprocessing.py +++ b/src/lightning_lite/strategies/launchers/multiprocessing.py @@ -63,10 +63,6 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - if start_method in ("fork", "forkserver") and _is_forking_disabled(): - raise ValueError( - "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method." 
- ) @property def is_interactive_compatible(self) -> bool: @@ -170,8 +166,3 @@ def restore(self) -> None: torch.use_deterministic_algorithms(self.use_deterministic_algorithms) torch.backends.cudnn.benchmark = self.cudnn_benchmark _set_rng_states(self.rng_states) - - -def _is_forking_disabled() -> bool: - """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" - return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index 8a04e9b625e25..9c0feec8f7275 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, List, MutableSequence, Optional, Tuple, Union from lightning_lite.accelerators.cuda import _get_all_available_cuda_gpus diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 98a5566a2a484..9528ca7d8b492 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -81,8 +81,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300)) -- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292)) +- Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631)) + +- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292)) ### Deprecated diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index 1eb036cfee81d..dc5916e3e22ee 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -68,10 +68,6 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - if start_method in ("fork", "forkserver") and _is_forking_disabled(): - raise ValueError( - "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method." 
- ) @property def is_interactive_compatible(self) -> bool: @@ -287,8 +283,3 @@ def restore(self) -> None: torch.use_deterministic_algorithms(self.use_deterministic_algorithms) torch.backends.cudnn.benchmark = self.cudnn_benchmark _set_rng_states(self.rng_states) - - -def _is_forking_disabled() -> bool: - """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" - return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index 4bef4f876796b..c74d1144f3c9a 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -75,7 +75,6 @@ TPUSpawnStrategy, ) from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES -from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import ( @@ -632,10 +631,6 @@ def _check_strategy_and_fallback(self) -> None: f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this" f" platform. We recommed `Trainer(strategy='ddp_spawn')` instead." ) - if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled(): - raise ValueError( - "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different strategy." 
- ) if strategy_flag: self._strategy_flag = strategy_flag diff --git a/tests/tests_lite/accelerators/test_cpu.py b/tests/tests_lite/accelerators/test_cpu.py index f6ccba3560b96..012efe9f9aa26 100644 --- a/tests/tests_lite/accelerators/test_cpu.py +++ b/tests/tests_lite/accelerators/test_cpu.py @@ -15,7 +15,7 @@ import pytest import torch -from lightning_lite.accelerators.cpu import CPUAccelerator +from lightning_lite.accelerators.cpu import CPUAccelerator, parse_cpu_cores def test_auto_device_count(): @@ -41,3 +41,10 @@ def test_init_device_with_wrong_device_type(): ) def test_get_parallel_devices(devices, expected): assert CPUAccelerator.get_parallel_devices(devices) == expected + + +@pytest.mark.parametrize("devices", ([3], -1)) +def test_invalid_devices_with_cpu_accelerator(devices): + """Test that an invalid device flag raises a TypeError.""" + with pytest.raises(TypeError, match="should be an int > 0"): + parse_cpu_cores(devices) diff --git a/tests/tests_lite/accelerators/test_cuda.py b/tests/tests_lite/accelerators/test_cuda.py index 1c2c7a8ac33d8..94e51d77358e6 100644 --- a/tests/tests_lite/accelerators/test_cuda.py +++ b/tests/tests_lite/accelerators/test_cuda.py @@ -17,7 +17,7 @@ import torch from tests_lite.helpers.runif import RunIf -from lightning_lite.accelerators.cuda import CUDAAccelerator +from lightning_lite.accelerators.cuda import CUDAAccelerator, is_cuda_available, num_cuda_devices @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2) @@ -51,3 +51,12 @@ def test_get_parallel_devices(devices, expected): def test_set_cuda_device(set_device_mock): CUDAAccelerator().setup_device(torch.device("cuda", 1)) set_device_mock.assert_called_once_with(torch.device("cuda", 1)) + + +@mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1) +@mock.patch("torch.cuda.device_count", return_value=100) +def test_num_cuda_devices_without_nvml(*_): + """Test that if NVML can't be loaded, our helper functions 
fall back to the default implementation for + determining CUDA availability.""" + assert is_cuda_available() + assert num_cuda_devices() == 100 diff --git a/tests/tests_lite/strategies/launchers/test_multiprocessing.py b/tests/tests_lite/strategies/launchers/test_multiprocessing.py index 70b45763fe2df..fef19f06715ab 100644 --- a/tests/tests_lite/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_lite/strategies/launchers/test_multiprocessing.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from unittest import mock from unittest.mock import ANY, Mock @@ -35,14 +34,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_): _MultiProcessingLauncher(strategy=Mock(), start_method="fork") -@RunIf(skip_windows=True) -@pytest.mark.parametrize("start_method", ["fork", "forkserver"]) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_multiprocessing_launcher_disabled_forking(start_method): - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) - - @pytest.mark.parametrize("start_method", ["spawn", "fork"]) @mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp") def test_multiprocessing_launcher_start_method(mp_mock, start_method): diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index 09b0f359ec0d8..258934f1ca847 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -692,12 +692,3 @@ def test_gpu_accelerator_no_gpu_backend_found_error(*_): def test_ddp_fork_on_unsupported_platform(_, strategy): with pytest.raises(ValueError, match="process forking is not supported on this platform"): _Connector(strategy=strategy) - - -@RunIf(skip_windows=True) 
-@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy): - """Test there is an error when forking is disabled via the environment variable and the user requests fork.""" - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - _Connector(devices=2, strategy=strategy) diff --git a/tests/tests_lite/utilities/test_device_parser.py b/tests/tests_lite/utilities/test_device_parser.py index 9f1d9d9a782f7..3c78e30f75d5d 100644 --- a/tests/tests_lite/utilities/test_device_parser.py +++ b/tests/tests_lite/utilities/test_device_parser.py @@ -14,10 +14,7 @@ from unittest import mock import pytest -import torch -from lightning_lite.accelerators.cpu import parse_cpu_cores -from lightning_lite.accelerators.cuda import is_cuda_available, num_cuda_devices from lightning_lite.utilities import device_parser from lightning_lite.utilities.exceptions import MisconfigurationException @@ -87,22 +84,3 @@ def test_parse_gpu_fail_on_non_existent_id_2(_): def test_parse_gpu_returns_none_when_no_devices_are_available(_, devices): with pytest.raises(MisconfigurationException): device_parser.parse_gpu_ids(devices, include_cuda=True) - - -@pytest.mark.skipif( - "fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support" -) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) -def test_num_cuda_devices_without_forking(*_): - """This merely tests that on platforms without fork support our helper functions fall back to the default - implementation for determining cuda availability.""" - assert is_cuda_available() - assert num_cuda_devices() == 2 - - -@pytest.mark.parametrize("devices", ([3], -1)) -def test_invalid_devices_with_cpu_accelerator(devices): - """Test invalid device flag raises MisconfigurationException.""" 
- with pytest.raises(TypeError, match="should be an int > 0"): - parse_cpu_cores(devices) diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 4859ddba60574..ad3e891ad607f 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from unittest import mock from unittest.mock import ANY, Mock @@ -19,7 +18,6 @@ import torch from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher -from tests_pytorch.helpers.runif import RunIf @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[]) @@ -28,14 +26,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_): _MultiProcessingLauncher(strategy=Mock(), start_method="fork") -@RunIf(skip_windows=True) -@pytest.mark.parametrize("start_method", ["fork", "forkserver"]) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_multiprocessing_launcher_disabled_forking(start_method): - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) - - @pytest.mark.parametrize("start_method", ["spawn", "fork"]) @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp") def test_multiprocessing_launcher_start_method(mp_mock, start_method): diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index b2a784525520e..0496ed8e2b465 100644 --- 
a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -794,12 +794,3 @@ def test_accelerator_specific_checkpoint_io(*_): def test_ddp_fork_on_unsupported_platform(_, strategy): with pytest.raises(ValueError, match="process forking is not supported on this platform"): Trainer(strategy=strategy) - - -@RunIf(skip_windows=True) -@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES) -@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) -def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy): - """Test there is an error when forking is disabled via the environment variable and the user requests fork.""" - with pytest.raises(ValueError, match="Forking is disabled in this environment"): - Trainer(devices=2, strategy=strategy)