Commit 09a42e5

add context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible
1 parent 5e7f225 commit 09a42e5

3 files changed: +29 −9 lines changed

src/lightning_lite/accelerators/cuda.py

Lines changed: 21 additions & 1 deletion
@@ -13,8 +13,9 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, Generator, List, Optional, Set, Union

 import torch

@@ -77,6 +78,25 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))


+@contextmanager
+def _patch_cuda_is_available() -> Generator:
+    """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
+    possible."""
+    orig_check = None
+    new_check = torch.cuda.device_count if _TORCH_GREATER_EQUAL_1_13 else _device_count_nvml
+
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
+        # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
+        # otherwise, patching is_available could lead to attribute errors or infinite recursion
+        orig_check = torch.cuda.is_available
+        torch.cuda.is_available = new_check  # type: ignore[assignment]
+    try:
+        yield
+    finally:
+        if orig_check:
+            torch.cuda.is_available = orig_check
+
+
 @lru_cache(1)
 def num_cuda_devices() -> int:
     """Returns the number of available CUDA devices.

src/lightning_lite/plugins/precision/native_amp.py

Lines changed: 4 additions & 4 deletions
@@ -20,7 +20,7 @@
 from torch.optim import LBFGS
 from typing_extensions import Literal

-from lightning_lite.accelerators.cuda import is_cuda_available
+from lightning_lite.accelerators.cuda import _patch_cuda_is_available
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -48,9 +48,9 @@ def __init__(
         if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
             raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.")
         if scaler is None and precision == 16:
-            # if possible, we defer CUDA initialization to support strategies that will attempt forks
-            torch.cuda.is_available = is_cuda_available
-            scaler = torch.cuda.amp.GradScaler()
+            with _patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
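
A small sketch of what this scoping buys (illustrative, not part of the commit): unlike the previous in-place assignment, the original function is restored once the block exits, whether or not the patch was actually applied.

# Illustrative check of the restore-on-exit behaviour (not part of this commit).
import torch

from lightning_lite.accelerators.cuda import _patch_cuda_is_available

original = torch.cuda.is_available
with _patch_cuda_is_available():
    torch.cuda.is_available()  # may run the NVML-based check if patching was possible
assert torch.cuda.is_available is original  # the original function is back either way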

src/pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 4 additions & 4 deletions
@@ -19,7 +19,7 @@
 from torch.optim import LBFGS

 import pytorch_lightning as pl
-from lightning_lite.accelerators.cuda import is_cuda_available
+from lightning_lite.accelerators.cuda import _patch_cuda_is_available
 from lightning_lite.utilities.types import Steppable
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
@@ -51,9 +51,9 @@ def __init__(
                 "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
             )
         if scaler is None and precision == 16:
-            # if possible, we defer CUDA initialization to support strategies that will attempt forks
-            torch.cuda.is_available = is_cuda_available
-            scaler = torch.cuda.amp.GradScaler()
+            with _patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
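
For reference, both plugins now rely on the same patch-and-restore idiom; a distilled, generic sketch of that pattern (hypothetical helper name, not part of this commit):

# Generic patch-and-restore sketch (hypothetical helper, not part of this commit).
from contextlib import contextmanager
from typing import Any, Callable, Generator


@contextmanager
def _swap_attribute(owner: Any, name: str, replacement: Callable) -> Generator:
    original = getattr(owner, name)
    setattr(owner, name, replacement)  # install the replacement for the duration of the block
    try:
        yield
    finally:
        setattr(owner, name, original)  # always restore, even if the block raised

The commit's `_patch_cuda_is_available` adds one guard on top of this pattern: it only installs the replacement when `torch._C._cuda_getDeviceCount` exists and the NVML count succeeds, and otherwise leaves `torch.cuda.is_available` untouched.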
