diff --git a/CHANGELOG.md b/CHANGELOG.md index 65bfeee76bc15..d98dbc4cde2db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637)) +- Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797)) + + ### Changed - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 724b5b6f244c1..06f82fb8d4b96 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,3 +59,8 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @abstractmethod def auto_device_count() -> int: """Get the device count when set to auto.""" + + @staticmethod + @abstractmethod + def is_available() -> bool: + """Detect if the hardware is available.""" diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 75c55fdf5f047..2fbe3bf18b079 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -31,6 +31,7 @@ def setup_environment(self, root_device: torch.device) -> None: MisconfigurationException: If the selected device is not CPU. """ + super().setup_environment(root_device) if root_device.type != "cpu": raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.") @@ -42,3 +43,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: def auto_device_count() -> int: """Get the devices when set to auto.""" return 1 + + @staticmethod + def is_available() -> bool: + """CPU is always available for execution.""" + return True diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 3ccf2e4a7f919..aa8b0d56dbf63 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -39,6 +39,7 @@ def setup_environment(self, root_device: torch.device) -> None: MisconfigurationException: If the selected device is not GPU. """ + super().setup_environment(root_device) if root_device.type != "cuda": raise MisconfigurationException(f"Device should be GPU, got {root_device} instead") torch.cuda.set_device(root_device) @@ -79,6 +80,10 @@ def auto_device_count() -> int: """Get the devices when set to auto.""" return torch.cuda.device_count() + @staticmethod + def is_available() -> bool: + return torch.cuda.device_count() > 0 + def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 155dce5275a9b..6928546cf8c50 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import _IPU_AVAILABLE class IPUAccelerator(Accelerator): @@ -31,3 +32,7 @@ def auto_device_count() -> int: # TODO (@kaushikb11): 4 is the minimal unit they are shipped in. # Update this when api is exposed by the Graphcore team. return 4 + + @staticmethod + def is_available() -> bool: + return _IPU_AVAILABLE diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 34c37dcd95e7f..f1f598c3f1b3c 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -16,7 +16,7 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities import _XLA_AVAILABLE +from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm @@ -47,3 +47,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: def auto_device_count() -> int: """Get the devices when set to auto.""" return 8 + + @staticmethod + def is_available() -> bool: + return _TPU_AVAILABLE diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 3e2ec15216841..8ceb2de96c59c 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -98,6 +98,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_accelerator_choice_ddp_slurm(*_): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2) @@ -123,6 +124,7 @@ def test_accelerator_choice_ddp_slurm(*_): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_accelerator_choice_ddp2_slurm(*_): with pytest.deprecated_call(match=r"accelerator='ddp2'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2) @@ -148,6 +150,7 @@ def test_accelerator_choice_ddp2_slurm(*_): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=1) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_accelerator_choice_ddp_te(*_): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2) @@ -172,6 +175,7 @@ def test_accelerator_choice_ddp_te(*_): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=1) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_accelerator_choice_ddp2_te(*_): with pytest.deprecated_call(match=r"accelerator='ddp2'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2) @@ -210,6 +214,7 @@ def test_accelerator_choice_ddp_cpu_te(*_): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=1) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_accelerator_choice_ddp_kubeflow(*_): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=1) @@ -340,6 +345,10 @@ class Accel(Accelerator): def auto_device_count() -> int: return 1 + @staticmethod + def is_available() -> bool: + return True + class Prec(PrecisionPlugin): pass @@ -735,8 +744,11 @@ def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) @pytest.mark.parametrize("strategy", ["ddp2", DDP2Strategy()]) -def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy): +def test_strategy_choice_ddp2_slurm( + set_device_mock, device_count_mock, setup_distributed_mock, is_available_mock, strategy +): trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2) assert trainer._accelerator_connector._is_slurm_managing_tasks() assert isinstance(trainer.accelerator, GPUAccelerator) @@ -760,6 +772,7 @@ def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_di @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_strategy_choice_ddp_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=2) assert isinstance(trainer.accelerator, GPUAccelerator) @@ -783,6 +796,7 @@ def test_strategy_choice_ddp_te(*_): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_strategy_choice_ddp2_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp2", gpus=2) assert isinstance(trainer.accelerator, GPUAccelerator) @@ -820,6 +834,7 @@ def test_strategy_choice_ddp_cpu_te(*_): @mock.patch("torch.cuda.set_device") @mock.patch("torch.cuda.device_count", return_value=1) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@mock.patch("torch.cuda.is_available", return_value=True) def test_strategy_choice_ddp_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 28011aa497eaa..bb3ebfe487fd9 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -22,6 +22,10 @@ def test_restore_checkpoint_after_pre_setup_default(): assert not plugin.restore_checkpoint_after_setup +def test_availability(): + assert CPUAccelerator.is_available() + + @pytest.mark.parametrize("restore_after_pre_setup", [True, False]) def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup): """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre- diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index c356ecf935ae1..342b001abf92c 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -79,6 +79,7 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module): @RunIf(skip_windows=True) @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") +@mock.patch("torch.cuda.is_available", return_value=True) def test_torch_distributed_backend_env_variables(tmpdir): """This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError.""" _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} @@ -90,11 +91,14 @@ def test_torch_distributed_backend_env_variables(tmpdir): @RunIf(skip_windows=True) -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("torch.cuda.is_available", return_value=True) @mock.patch("torch.cuda.set_device") +@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("torch.cuda.device_count", return_value=1) +@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True) @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True) -def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir): +def test_ddp_torch_dist_is_available_in_setup( + mock_gpu_is_available, mock_device_count, mock_cuda_available, mock_set_device, tmpdir +): """Test to ensure torch distributed is available within the setup hook using ddp.""" class TestModel(BoringModel): diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index f24876197a5f4..9173db2644d77 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import pytest import torch import torch.nn.functional as F @@ -154,9 +156,10 @@ def _assert_extra_outputs(self, outputs): assert out.dtype is torch.float -def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch): +@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("torch.cuda.is_available", return_value=True) +def test_dp_raise_exception_with_batch_transfer_hooks(mock_is_available, mock_device_count, tmpdir): """Test that an exception is raised when overriding batch_transfer_hooks in DP model.""" - monkeypatch.setattr("torch.cuda.device_count", lambda: 2) class CustomModel(BoringModel): def transfer_batch_to_device(self, batch, device, dataloader_idx): diff --git a/tests/accelerators/test_gpu.py b/tests/accelerators/test_gpu.py index 110ba1be9a82c..dc3d76cd866e7 100644 --- a/tests/accelerators/test_gpu.py +++ b/tests/accelerators/test_gpu.py @@ -60,3 +60,8 @@ def test_set_cuda_device(set_device_mock, tmpdir): ) trainer.fit(model) set_device_mock.assert_called_once() + + +@RunIf(min_gpus=1) +def test_gpu_availability(): + assert GPUAccelerator.is_available() diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index 3a250de38a7a8..861b149733c0c 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -106,6 +106,7 @@ def test_fail_if_no_ipus(tmpdir): @RunIf(ipu=True) def test_accelerator_selected(tmpdir): + assert IPUAccelerator.is_available() trainer = Trainer(default_root_dir=tmpdir, ipus=1) assert isinstance(trainer.accelerator, IPUAccelerator) trainer = Trainer(default_root_dir=tmpdir, ipus=1, accelerator="ipu") diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index a4eb26a4bc505..608d98304c757 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -83,6 +83,7 @@ def test_if_test_works_after_train(tmpdir): @RunIf(tpu=True) def test_accelerator_tpu(): + assert TPUAccelerator.is_available() trainer = Trainer(accelerator="tpu", tpu_cores=8) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index a3d9977b31c80..c494c0c1c18e6 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -235,8 +235,9 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun }, ) @mock.patch("torch.cuda.device_count", return_value=1) +@mock.patch("torch.cuda.is_available", return_value=True) @pytest.mark.parametrize("gpus", [[0, 1, 2], 2, "0"]) -def test_torchelastic_gpu_parsing(mocked_device_count, gpus): +def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus): """Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device That we omit sanitizing the gpus as only one of the GPUs is visible.""" trainer = Trainer(gpus=gpus) diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 24c04de6604ef..f3a5504f398ed 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -45,6 +45,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin): "SLURM_LOCALID": "0", }, ) +@mock.patch("torch.cuda.is_available", return_value=True) @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize("strategy,gpus", [("ddp", 2), ("ddp2", 2), ("ddp_spawn", 2)]) @pytest.mark.parametrize( @@ -56,7 +57,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin): pytest.param("apex", True, MyApexPlugin, marks=RunIf(amp_apex=True)), ], ) -def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, plugin_cls): +def test_amp_apex_ddp(mocked_device_count, mocked_is_available, strategy, gpus, amp, custom_plugin, plugin_cls): plugin = None if custom_plugin: plugin = plugin_cls(16, "cpu") if amp == "native" else plugin_cls() diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 3809b8b3e2eb6..9fc994e2e4338 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -184,6 +184,7 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + monkeypatch.setattr("torch.cuda.is_available", lambda: True) cli_args = cli_args.split(" ") if cli_args else [] with mock.patch("sys.argv", ["any.py"] + cli_args): parser = LightningArgumentParser(add_help=False, parse_as_dict=False)