
Commit 274c9aa

is_available

1 parent 19bc8fa · commit 274c9aa


11 files changed: +56 −15 lines changed


CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))


-- Added checks to `GPUAccelerator` to assert CUDA availability at initialization ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
+- Added `Accelerator.is_available` to assert device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))


 ### Changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 11 additions & 0 deletions
@@ -35,7 +35,13 @@ def setup_environment(self, root_device: torch.device) -> None:

         This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
         environment before setup is complete.
+
+        Raises:
+            RuntimeError:
+                If corresponding hardware is not found.
         """
+        if not self.is_available():
+            raise RuntimeError(f"{self.__class__.__qualname__} is not configured to run on this hardware.")

     def setup(self, trainer: "pl.Trainer") -> None:
         """Setup plugins for the trainer fit and creates optimizers.

@@ -59,3 +65,8 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
     @abstractmethod
     def auto_device_count() -> int:
         """Get the device count when set to auto."""
+
+    @staticmethod
+    @abstractmethod
+    def is_available() -> bool:
+        """Detect if the hardware is available."""

pytorch_lightning/accelerators/cpu.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,7 @@ def setup_environment(self, root_device: torch.device) -> None:
         MisconfigurationException:
             If the selected device is not CPU.
         """
+        super().setup_environment(root_device)
         if root_device.type != "cpu":
             raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")


@@ -42,3 +43,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return 1
+
+    @staticmethod
+    def is_available() -> bool:
+        """CPU is always available for execution."""
+        return True

pytorch_lightning/accelerators/gpu.py

Lines changed: 5 additions & 12 deletions
@@ -33,24 +33,13 @@
 class GPUAccelerator(Accelerator):
     """Accelerator for GPU devices."""

-    def __init__(self) -> None:
-        """
-        Raises:
-            MisconfigurationException:
-                If torch.cuda isn't available.
-                If no CUDA devices are found.
-        """
-        if not torch.cuda.is_available():
-            raise MisconfigurationException("GPU Accelerator used, but CUDA isn't available.")
-        if torch.cuda.device_count() == 0:
-            raise MisconfigurationException("GPU Accelerator used, but found no CUDA devices available.")
-
     def setup_environment(self, root_device: torch.device) -> None:
         """
         Raises:
             MisconfigurationException:
                 If the selected device is not GPU.
         """
+        super().setup_environment(root_device)
         if root_device.type != "cuda":
             raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
         torch.cuda.set_device(root_device)

@@ -91,6 +80,10 @@ def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return torch.cuda.device_count()

+    @staticmethod
+    def is_available() -> bool:
+        return torch.cuda.is_available() and torch.cuda.device_count() > 0
+

 def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

pytorch_lightning/accelerators/ipu.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,7 @@
 import torch

 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import _IPU_AVAILABLE


 class IPUAccelerator(Accelerator):

@@ -31,3 +32,7 @@ def auto_device_count() -> int:
         # TODO (@kaushikb11): 4 is the minimal unit they are shipped in.
         # Update this when api is exposed by the Graphcore team.
         return 4
+
+    @staticmethod
+    def is_available() -> bool:
+        return _IPU_AVAILABLE

pytorch_lightning/accelerators/tpu.py

Lines changed: 5 additions & 1 deletion
@@ -16,7 +16,7 @@
 import torch

 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm

@@ -47,3 +47,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return 8
+
+    @staticmethod
+    def is_available() -> bool:
+        return _TPU_AVAILABLE

tests/accelerators/test_cpu.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ def test_restore_checkpoint_after_pre_setup_default():
     assert not plugin.restore_checkpoint_after_setup


+def test_availability():
+    assert CPUAccelerator.is_available()
+
+
 @pytest.mark.parametrize("restore_after_pre_setup", [True, False])
 def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
     """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre-

tests/accelerators/test_dp.py

Lines changed: 0 additions & 1 deletion
@@ -160,7 +160,6 @@ def _assert_extra_outputs(self, outputs):
 @mock.patch("torch.cuda.device_count", return_value=2)
 def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, mock_is_available, mock_device_count):
     """Test that an exception is raised when overriding batch_transfer_hooks in DP model."""
-    # monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

     class CustomModel(BoringModel):
         def transfer_batch_to_device(self, batch, device, dataloader_idx):

tests/accelerators/test_gpu.py

Lines changed: 17 additions & 0 deletions
@@ -60,3 +60,20 @@ def test_set_cuda_device(set_device_mock, tmpdir):
     )
     trainer.fit(model)
     set_device_mock.assert_called_once()
+
+
+@RunIf(min_gpus=1)
+def test_gpu_availability():
+    assert GPUAccelerator.is_available()
+
+
+@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("torch.cuda.device_count", return_value=2)
+def test_mocked_gpu_available(*_):
+    assert GPUAccelerator.is_available()
+
+
+@mock.patch("torch.cuda.is_available", return_value=False)
+@mock.patch("torch.cuda.device_count", return_value=0)
+def test_mocked_gpu_availability(*_):
+    assert not GPUAccelerator.is_available()

tests/accelerators/test_ipu.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ def test_fail_if_no_ipus(tmpdir):

 @RunIf(ipu=True)
 def test_accelerator_selected(tmpdir):
+    assert IPUAccelerator.is_available()
     trainer = Trainer(default_root_dir=tmpdir, ipus=1)
     assert isinstance(trainer.accelerator, IPUAccelerator)
     trainer = Trainer(default_root_dir=tmpdir, ipus=1, accelerator="ipu")
