Commit e2dcbfe

tests

1 parent 72b36e1 commit e2dcbfe

3 files changed: +11 -16 lines changed

tests/accelerators/test_accelerator_connector.py

Lines changed: 4 additions & 0 deletions

@@ -345,6 +345,10 @@ class Accel(Accelerator):
         def auto_device_count() -> int:
             return 1
 
+        @staticmethod
+        def is_available() -> bool:
+            return True
+
     class Prec(PrecisionPlugin):
         pass
 
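The stub accelerator gains an is_available override, presumably because the accelerator connector now checks availability before it will use a custom accelerator class. A minimal sketch of that contract, with hypothetical names (Accelerator here is a local stand-in for the real base class, and select_accelerator is an invented helper, not the actual connector API):

from abc import ABC, abstractmethod


class Accelerator(ABC):
    # Stand-in for pytorch_lightning.accelerators.Accelerator.
    @staticmethod
    @abstractmethod
    def auto_device_count() -> int:
        ...

    @staticmethod
    @abstractmethod
    def is_available() -> bool:
        ...


def select_accelerator(accelerator: Accelerator) -> Accelerator:
    # Hypothetical connector-side gate: refuse an accelerator that
    # reports itself unavailable on the current machine.
    if not accelerator.is_available():
        raise RuntimeError(f"{type(accelerator).__name__} is not available")
    return accelerator

Under a contract like this, a test double such as Accel above must answer True from is_available to get past the gate, which is exactly what the added lines provide.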

tests/accelerators/test_ddp.py

Lines changed: 6 additions & 3 deletions

@@ -91,11 +91,14 @@ def test_torch_distributed_backend_env_variables(tmpdir):
 
 
 @RunIf(skip_windows=True)
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.set_device")
+@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("torch.cuda.device_count", return_value=1)
+@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True)
 @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True)
-def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir):
+def test_ddp_torch_dist_is_available_in_setup(
+    mock_set_device, mock_cuda_available, mock_device_count, mock_gpu_is_available, tmpdir
+):
     """Test to ensure torch distributed is available within the setup hook using ddp."""
 
     class TestModel(BoringModel):
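The decorator reshuffle matters because stacked mock.patch decorators are applied bottom-up: the arg-injecting patch closest to the def supplies the first mock argument, and mock.patch.dict injects no argument at all. A self-contained sketch of that ordering rule (os.cpu_count and os.getcwd are just convenient patch targets, unrelated to the test above):

import os
from unittest import mock


@mock.patch("os.getcwd")     # farther from the def -> second mock argument
@mock.patch("os.cpu_count")  # closer to the def -> first mock argument
@mock.patch.dict(os.environ, {"EXAMPLE": "1"})  # patch.dict injects no argument
def show_patch_order(mock_cpu_count, mock_getcwd):
    # Each parameter receives the MagicMock created by its decorator.
    assert mock_cpu_count is not mock_getcwd
    assert os.environ["EXAMPLE"] == "1"


show_patch_order()

So when a new patch such as GPUAccelerator.is_available joins the stack, the test signature has to grow by one parameter, as the rewritten def in the diff does.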

tests/accelerators/test_gpu.py

Lines changed: 1 addition & 13 deletions

@@ -63,17 +63,5 @@ def test_set_cuda_device(set_device_mock, tmpdir):
 
 
 @RunIf(min_gpus=1)
-def test_real_gpu_availability():
+def test_gpu_availability():
     assert GPUAccelerator.is_available()
-
-
-@mock.patch("torch.cuda.is_available", return_value=True)
-@mock.patch("torch.cuda.device_count", return_value=2)
-def test_gpu_available(*_):
-    assert GPUAccelerator.is_available()
-
-
-@mock.patch("torch.cuda.is_available", return_value=False)
-@mock.patch("torch.cuda.device_count", return_value=0)
-def test_gpu_not_available(*_):
-    assert not GPUAccelerator.is_available()
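The two mock-based tests disappear, leaving only the real-hardware check under @RunIf(min_gpus=1). Assuming GPUAccelerator.is_available() is a thin wrapper over the visible device count (something like torch.cuda.device_count() > 0; an assumption, not confirmed by this diff), the deleted cases were arguably exercising torch mocks more than accelerator logic. They could still be reproduced hardware-free along these lines:

from unittest import mock

import torch


def gpu_is_available() -> bool:
    # Assumed implementation of GPUAccelerator.is_available.
    return torch.cuda.device_count() > 0


with mock.patch("torch.cuda.device_count", return_value=2):
    assert gpu_is_available()

with mock.patch("torch.cuda.device_count", return_value=0):
    assert not gpu_is_available()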
