
Commit 1b107c5

Add Accelerator.is_available() interface requirement (#11797)
1 parent c618e59 commit 1b107c5

16 files changed: 73 additions and 9 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
 
 
+- Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
+
+
 ### Changed
 
 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
```

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,8 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
5959
@abstractmethod
6060
def auto_device_count() -> int:
6161
"""Get the device count when set to auto."""
62+
63+
@staticmethod
64+
@abstractmethod
65+
def is_available() -> bool:
66+
"""Detect if the hardware is available."""

pytorch_lightning/accelerators/cpu.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@ def setup_environment(self, root_device: torch.device) -> None:
             MisconfigurationException:
                 If the selected device is not CPU.
         """
+        super().setup_environment(root_device)
         if root_device.type != "cpu":
             raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")
 
@@ -42,3 +43,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return 1
+
+    @staticmethod
+    def is_available() -> bool:
+        """CPU is always available for execution."""
+        return True
```
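
Because `is_available()` is a static method, it can be called without instantiating the accelerator. A quick sanity check, mirroring the `test_availability` test added later in this commit (import path assumed from the package shown here):

```python
from pytorch_lightning.accelerators import CPUAccelerator

# The CPU accelerator always reports availability, so this never fails.
assert CPUAccelerator.is_available()
```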

pytorch_lightning/accelerators/gpu.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -39,6 +39,7 @@ def setup_environment(self, root_device: torch.device) -> None:
             MisconfigurationException:
                 If the selected device is not GPU.
         """
+        super().setup_environment(root_device)
         if root_device.type != "cuda":
             raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
         torch.cuda.set_device(root_device)
@@ -79,6 +80,10 @@ def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return torch.cuda.device_count()
 
+    @staticmethod
+    def is_available() -> bool:
+        return torch.cuda.device_count() > 0
+
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
```

pytorch_lightning/accelerators/ipu.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import _IPU_AVAILABLE
 
 
 class IPUAccelerator(Accelerator):
@@ -31,3 +32,7 @@ def auto_device_count() -> int:
         # TODO (@kaushikb11): 4 is the minimal unit they are shipped in.
         # Update this when api is exposed by the Graphcore team.
         return 4
+
+    @staticmethod
+    def is_available() -> bool:
+        return _IPU_AVAILABLE
```

pytorch_lightning/accelerators/tpu.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -47,3 +47,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return 8
+
+    @staticmethod
+    def is_available() -> bool:
+        return _TPU_AVAILABLE
```
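
The IPU and TPU accelerators simply delegate to the existing `_IPU_AVAILABLE` and `_TPU_AVAILABLE` flags, so the new method is a thin, uniform wrapper over those checks. One plausible use is guarding hardware-specific tests (a sketch, not taken from this commit):

```python
import pytest

from pytorch_lightning.accelerators import TPUAccelerator


@pytest.mark.skipif(not TPUAccelerator.is_available(), reason="test requires a TPU machine")
def test_something_on_tpu():
    # Hypothetical TPU-only test body.
    assert TPUAccelerator.auto_device_count() == 8
```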

tests/accelerators/test_accelerator_connector.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -98,6 +98,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=2)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_accelerator_choice_ddp_slurm(*_):
     with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"):
         trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2)
@@ -123,6 +124,7 @@ def test_accelerator_choice_ddp_slurm(*_):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=2)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_accelerator_choice_ddp2_slurm(*_):
     with pytest.deprecated_call(match=r"accelerator='ddp2'\)` has been deprecated in v1.5"):
         trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2)
@@ -148,6 +150,7 @@ def test_accelerator_choice_ddp2_slurm(*_):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=1)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_accelerator_choice_ddp_te(*_):
     with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"):
         trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2)
@@ -172,6 +175,7 @@ def test_accelerator_choice_ddp_te(*_):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=1)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_accelerator_choice_ddp2_te(*_):
     with pytest.deprecated_call(match=r"accelerator='ddp2'\)` has been deprecated in v1.5"):
         trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2)
@@ -210,6 +214,7 @@ def test_accelerator_choice_ddp_cpu_te(*_):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=1)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_accelerator_choice_ddp_kubeflow(*_):
     with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"):
         trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=1)
@@ -340,6 +345,10 @@ class Accel(Accelerator):
         def auto_device_count() -> int:
             return 1
 
+        @staticmethod
+        def is_available() -> bool:
+            return True
+
     class Prec(PrecisionPlugin):
         pass
 
@@ -735,8 +744,11 @@ def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=2)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 @pytest.mark.parametrize("strategy", ["ddp2", DDP2Strategy()])
-def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy):
+def test_strategy_choice_ddp2_slurm(
+    set_device_mock, device_count_mock, setup_distributed_mock, is_available_mock, strategy
+):
     trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2)
     assert trainer._accelerator_connector._is_slurm_managing_tasks()
     assert isinstance(trainer.accelerator, GPUAccelerator)
@@ -760,6 +772,7 @@ def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_di
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=2)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp_te(*_):
     trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=2)
     assert isinstance(trainer.accelerator, GPUAccelerator)
@@ -783,6 +796,7 @@ def test_strategy_choice_ddp_te(*_):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=2)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp2_te(*_):
     trainer = Trainer(fast_dev_run=True, strategy="ddp2", gpus=2)
     assert isinstance(trainer.accelerator, GPUAccelerator)
@@ -820,6 +834,7 @@ def test_strategy_choice_ddp_cpu_te(*_):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.device_count", return_value=1)
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp_kubeflow(*_):
     trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=1)
     assert isinstance(trainer.accelerator, GPUAccelerator)
```
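
The recurring change in these hunks is an extra `@mock.patch("torch.cuda.is_available", return_value=True)` so that, on CPU-only CI machines, the accelerator connector still believes a GPU is present. A condensed sketch of the pattern (the test name and assertion are illustrative, not taken from the commit):

```python
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator


@mock.patch("torch.cuda.set_device")
@mock.patch("torch.cuda.device_count", return_value=2)
@mock.patch("torch.cuda.is_available", return_value=True)
def test_gpu_selected_without_hardware(is_available_mock, device_count_mock, set_device_mock):
    # With the CUDA queries patched, the connector resolves a GPU accelerator
    # even on a machine without GPUs.
    trainer = Trainer(fast_dev_run=True, gpus=2, strategy="ddp")
    assert isinstance(trainer.accelerator, GPUAccelerator)
```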

tests/accelerators/test_cpu.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -22,6 +22,10 @@ def test_restore_checkpoint_after_pre_setup_default():
     assert not plugin.restore_checkpoint_after_setup
 
 
+def test_availability():
+    assert CPUAccelerator.is_available()
+
+
 @pytest.mark.parametrize("restore_after_pre_setup", [True, False])
 def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
     """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre-
```

tests/accelerators/test_ddp.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -79,6 +79,7 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module):
 
 @RunIf(skip_windows=True)
 @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine")
+@mock.patch("torch.cuda.is_available", return_value=True)
 def test_torch_distributed_backend_env_variables(tmpdir):
     """This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError."""
     _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
@@ -90,11 +91,14 @@ def test_torch_distributed_backend_env_variables(tmpdir):
 
 
 @RunIf(skip_windows=True)
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.set_device")
+@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("torch.cuda.device_count", return_value=1)
+@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True)
 @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True)
-def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir):
+def test_ddp_torch_dist_is_available_in_setup(
+    mock_gpu_is_available, mock_device_count, mock_cuda_available, mock_set_device, tmpdir
+):
     """Test to ensure torch distributed is available within the setup hook using ddp."""
 
     class TestModel(BoringModel):
```
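
The reordering above matters because stacked `mock.patch` decorators inject their mocks bottom-up: the decorator closest to the function supplies the first argument. A small illustration of that rule, independent of the Lightning test (names are made up):

```python
from unittest import mock


@mock.patch("os.getcwd", return_value="/top")  # top decorator -> last mock argument
@mock.patch("os.cpu_count", return_value=4)    # bottom decorator -> first mock argument
def check(mock_cpu_count, mock_getcwd):
    # Parameter order mirrors the decorators read from bottom to top.
    assert mock_cpu_count.return_value == 4
    assert mock_getcwd.return_value == "/top"


check()  # both patches are active only inside the call
```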

tests/accelerators/test_dp.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 import torch
 import torch.nn.functional as F
@@ -154,9 +156,10 @@ def _assert_extra_outputs(self, outputs):
         assert out.dtype is torch.float
 
 
-def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
+@mock.patch("torch.cuda.device_count", return_value=2)
+@mock.patch("torch.cuda.is_available", return_value=True)
+def test_dp_raise_exception_with_batch_transfer_hooks(mock_is_available, mock_device_count, tmpdir):
     """Test that an exception is raised when overriding batch_transfer_hooks in DP model."""
-    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
 
     class CustomModel(BoringModel):
         def transfer_batch_to_device(self, batch, device, dataloader_idx):
```
