|
33 | 33 | class GPUAccelerator(Accelerator): |
34 | 34 | """Accelerator for GPU devices.""" |
35 | 35 |
|
36 | | - def __init__(self) -> None: |
37 | | - """ |
38 | | - Raises: |
39 | | - MisconfigurationException: |
40 | | - If torch.cuda isn't available. |
41 | | - If no CUDA devices are found. |
42 | | - """ |
43 | | - if not torch.cuda.is_available(): |
44 | | - raise MisconfigurationException("GPU Accelerator used, but CUDA isn't available.") |
45 | | - if torch.cuda.device_count() == 0: |
46 | | - raise MisconfigurationException("GPU Accelerator used, but found no CUDA devices available.") |
47 | | - |
48 | 36 | def setup_environment(self, root_device: torch.device) -> None: |
49 | 37 | """ |
50 | 38 | Raises: |
51 | 39 | MisconfigurationException: |
52 | 40 | If the selected device is not GPU. |
53 | 41 | """ |
| 42 | + super().setup_environment(root_device) |
54 | 43 | if root_device.type != "cuda": |
55 | 44 | raise MisconfigurationException(f"Device should be GPU, got {root_device} instead") |
56 | 45 | torch.cuda.set_device(root_device) |
@@ -91,6 +80,10 @@ def auto_device_count() -> int: |
91 | 80 | """Get the devices when set to auto.""" |
92 | 81 | return torch.cuda.device_count() |
93 | 82 |
|
| 83 | + @staticmethod |
| 84 | + def is_available() -> bool: |
| 85 | + return torch.cuda.is_available() and torch.cuda.device_count > 0 |
| 86 | + |
94 | 87 |
|
95 | 88 | def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: |
96 | 89 | """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. |
|
0 commit comments