diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index ba95e74428a15..96a061a941b57 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -12,7 +12,7 @@ accelerators
 
     Accelerator
     CPUAccelerator
-    GPUAccelerator
+    CUDAAccelerator
     HPUAccelerator
     IPUAccelerator
     TPUAccelerator
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index ee4cd6c5c0005..b7c4c21018dcc 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -249,8 +249,8 @@ Example::
 
 .. code-block:: python
 
-    # This is part of the built-in `GPUAccelerator`
-    class GPUAccelerator(Accelerator):
+    # This is part of the built-in `CUDAAccelerator`
+    class CUDAAccelerator(Accelerator):
         """Accelerator for GPU devices."""
 
         @staticmethod
@@ -603,8 +603,8 @@ based on the accelerator type (``"cpu", "gpu", "tpu", "ipu", "auto"``).
 
 .. code-block:: python
 
-    # This is part of the built-in `GPUAccelerator`
-    class GPUAccelerator(Accelerator):
+    # This is part of the built-in `CUDAAccelerator`
+    class CUDAAccelerator(Accelerator):
         """Accelerator for GPU devices."""
 
         @staticmethod
diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst
index fdfe9660b90aa..5fedd441fdd2c 100644
--- a/docs/source-pytorch/extensions/accelerator.rst
+++ b/docs/source-pytorch/extensions/accelerator.rst
@@ -125,7 +125,7 @@ Accelerator API
 
     Accelerator
     CPUAccelerator
-    GPUAccelerator
+    CUDAAccelerator
     HPUAccelerator
     IPUAccelerator
     MPSAccelerator
diff --git a/src/pytorch_lightning/accelerators/__init__.py b/src/pytorch_lightning/accelerators/__init__.py
index e7d757cd73149..1bba4a42879bc 100644
--- a/src/pytorch_lightning/accelerators/__init__.py
+++ b/src/pytorch_lightning/accelerators/__init__.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 from pytorch_lightning.accelerators.accelerator import Accelerator  # noqa: F401
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.cuda import CUDAAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.hpu import HPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
diff --git a/src/pytorch_lightning/accelerators/cuda.py b/src/pytorch_lightning/accelerators/cuda.py
new file mode 100644
index 0000000000000..89d1a5b284b0c
--- /dev/null
+++ b/src/pytorch_lightning/accelerators/cuda.py
@@ -0,0 +1,167 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import shutil
+import subprocess
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+import pytorch_lightning as pl
+from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _DEVICE
+
+_log = logging.getLogger(__name__)
+
+
+class CUDAAccelerator(Accelerator):
+    """Accelerator for NVIDIA CUDA devices."""
+
+    def setup_environment(self, root_device: torch.device) -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not GPU.
+        """
+        super().setup_environment(root_device)
+        if root_device.type != "cuda":
+            raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
+        torch.cuda.set_device(root_device)
+
+    def setup(self, trainer: "pl.Trainer") -> None:
+        # TODO refactor input from trainer to local_rank @four4fish
+        self.set_nvidia_flags(trainer.local_rank)
+        # clear cache before training
+        torch.cuda.empty_cache()
+
+    @staticmethod
+    def set_nvidia_flags(local_rank: int) -> None:
+        # set the correct cuda visible devices (using pci order)
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
+        devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
+        _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
+
+    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+        """Gets stats for the given GPU device.
+
+        Args:
+            device: GPU device for which to get stats
+
+        Returns:
+            A dictionary mapping the metrics to their values.
+
+        Raises:
+            FileNotFoundError:
+                If nvidia-smi installation not found
+        """
+        return torch.cuda.memory_stats(device)
+
+    @staticmethod
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
+        """Accelerator device parsing logic."""
+        return device_parser.parse_gpu_ids(devices, include_cuda=True)
+
+    @staticmethod
+    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
+        """Gets parallel devices for the Accelerator."""
+        return [torch.device("cuda", i) for i in devices]
+
+    @staticmethod
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+        return torch.cuda.device_count()
+
+    @staticmethod
+    def is_available() -> bool:
+        return torch.cuda.device_count() > 0
+
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "cuda",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
+        # temporarily enable "gpu" to point to the CUDA Accelerator
+        accelerator_registry.register(
+            "gpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
+
+    def teardown(self) -> None:
+        # clean up memory
+        torch.cuda.empty_cache()
+
+
+def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:  # pragma: no-cover
+    """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
+
+    Args:
+        device: GPU device for which to get stats
+
+    Returns:
+        A dictionary mapping the metrics to their values.
+
+    Raises:
+        FileNotFoundError:
+            If nvidia-smi installation not found
+    """
+    nvidia_smi_path = shutil.which("nvidia-smi")
+    if nvidia_smi_path is None:
+        raise FileNotFoundError("nvidia-smi: command not found")
+
+    gpu_stat_metrics = [
+        ("utilization.gpu", "%"),
+        ("memory.used", "MB"),
+        ("memory.free", "MB"),
+        ("utilization.memory", "%"),
+        ("fan.speed", "%"),
+        ("temperature.gpu", "°C"),
+        ("temperature.memory", "°C"),
+    ]
+    gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
+    gpu_query = ",".join(gpu_stat_keys)
+
+    index = torch._utils._get_device_index(device)
+    gpu_id = _get_gpu_id(index)
+    result = subprocess.run(
+        [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
+        encoding="utf-8",
+        capture_output=True,
+        check=True,
+    )
+
+    def _to_float(x: str) -> float:
+        try:
+            return float(x)
+        except ValueError:
+            return 0.0
+
+    s = result.stdout.strip()
+    stats = [_to_float(x) for x in s.split(", ")]
+    gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)}
+    return gpu_stats
+
+
+def _get_gpu_id(device_id: int) -> str:
+    """Get the unmasked real GPU IDs."""
+    # All devices if `CUDA_VISIBLE_DEVICES` unset
+    default = ",".join(str(i) for i in range(torch.cuda.device_count()))
+    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
+    return cuda_visible_devices[device_id].strip()
diff --git a/src/pytorch_lightning/accelerators/gpu.py b/src/pytorch_lightning/accelerators/gpu.py
index 898ce09b91431..a7d054b946393 100644
--- a/src/pytorch_lightning/accelerators/gpu.py
+++ b/src/pytorch_lightning/accelerators/gpu.py
@@ -11,151 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-import os
-import shutil
-import subprocess
-from typing import Any, Dict, List, Optional, Union
+from pytorch_lightning.accelerators.cuda import CUDAAccelerator
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
-import torch
 
-import pytorch_lightning as pl
-from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import device_parser
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import _DEVICE
+class GPUAccelerator(CUDAAccelerator):
+    """Accelerator for NVIDIA GPU devices.
 
-_log = logging.getLogger(__name__)
+    .. deprecated:: 1.9
 
-
-class GPUAccelerator(Accelerator):
-    """Accelerator for GPU devices."""
-
-    def setup_environment(self, root_device: torch.device) -> None:
-        """
-        Raises:
-            MisconfigurationException:
-                If the selected device is not GPU.
- """ - super().setup_environment(root_device) - if root_device.type != "cuda": - raise MisconfigurationException(f"Device should be GPU, got {root_device} instead") - torch.cuda.set_device(root_device) - - def setup(self, trainer: "pl.Trainer") -> None: - # TODO refactor input from trainer to local_rank @four4fish - self.set_nvidia_flags(trainer.local_rank) - # clear cache before training - torch.cuda.empty_cache() - - @staticmethod - def set_nvidia_flags(local_rank: int) -> None: - # set the correct cuda visible devices (using pci order) - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count())) - devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) - _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: - """Gets stats for the given GPU device. - - Args: - device: GPU device for which to get stats - - Returns: - A dictionary mapping the metrics to their values. - - Raises: - FileNotFoundError: - If nvidia-smi installation not found - """ - return torch.cuda.memory_stats(device) - - @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: - """Accelerator device parsing logic.""" - return device_parser.parse_gpu_ids(devices, include_cuda=True) - - @staticmethod - def get_parallel_devices(devices: List[int]) -> List[torch.device]: - """Gets parallel devices for the Accelerator.""" - return [torch.device("cuda", i) for i in devices] - - @staticmethod - def auto_device_count() -> int: - """Get the devices when set to auto.""" - return torch.cuda.device_count() - - @staticmethod - def is_available() -> bool: - return torch.cuda.device_count() > 0 - - @classmethod - def register_accelerators(cls, accelerator_registry: Dict) -> None: - accelerator_registry.register( - "gpu", - cls, - description=f"{cls.__class__.__name__}", - ) - - def teardown(self) -> None: - # clean up memory - torch.cuda.empty_cache() - - -def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover - """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. - - Args: - device: GPU device for which to get stats - - Returns: - A dictionary mapping the metrics to their values. - - Raises: - FileNotFoundError: - If nvidia-smi installation not found + Please use the ``CUDAAccelerator`` instead. 
""" - nvidia_smi_path = shutil.which("nvidia-smi") - if nvidia_smi_path is None: - raise FileNotFoundError("nvidia-smi: command not found") - gpu_stat_metrics = [ - ("utilization.gpu", "%"), - ("memory.used", "MB"), - ("memory.free", "MB"), - ("utilization.memory", "%"), - ("fan.speed", "%"), - ("temperature.gpu", "°C"), - ("temperature.memory", "°C"), - ] - gpu_stat_keys = [k for k, _ in gpu_stat_metrics] - gpu_query = ",".join(gpu_stat_keys) - - index = torch._utils._get_device_index(device) - gpu_id = _get_gpu_id(index) - result = subprocess.run( - [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"], - encoding="utf-8", - capture_output=True, - check=True, - ) - - def _to_float(x: str) -> float: - try: - return float(x) - except ValueError: - return 0.0 - - s = result.stdout.strip() - stats = [_to_float(x) for x in s.split(", ")] - gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)} - return gpu_stats - - -def _get_gpu_id(device_id: int) -> str: - """Get the unmasked real GPU IDs.""" - # All devices if `CUDA_VISIBLE_DEVICES` unset - default = ",".join(str(i) for i in range(torch.cuda.device_count())) - cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") - return cuda_visible_devices[device_id].strip() + def __init__(self) -> None: + rank_zero_deprecation( + "The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9." + " Please use the `CUDAAccelerator` instead!" + ) + super().__init__() diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py index baf65d566d2dc..3a9c9ec0ac391 100644 --- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -23,7 +23,7 @@ from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl -from pytorch_lightning.accelerators import GPUAccelerator +from pytorch_lightning.accelerators import CUDAAccelerator from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop @@ -411,7 +411,7 @@ def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher ) return DataLoaderIterDataFetcher elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": - if not isinstance(trainer.accelerator, GPUAccelerator): + if not isinstance(trainer.accelerator, CUDAAccelerator): raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") return InterBatchParallelDataFetcher return DataFetcher diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index ab63b0e6df3be..8b54579a6bbfb 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -17,7 +17,7 @@ from typing import Optional, Type import pytorch_lightning as pl -from pytorch_lightning.accelerators import GPUAccelerator +from pytorch_lightning.accelerators import CUDAAccelerator from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE @@ -340,7 +340,7 @@ def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: ) return DataLoaderIterDataFetcher elif 
os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": - if not isinstance(trainer.accelerator, GPUAccelerator): + if not isinstance(trainer.accelerator, CUDAAccelerator): raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") return InterBatchParallelDataFetcher return DataFetcher diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 9b4d3513c1aab..ede42754aafc9 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -27,7 +27,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.cuda import CUDAAccelerator from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -452,7 +452,7 @@ def init_deepspeed(self): if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE: raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.") - if not isinstance(self.accelerator, GPUAccelerator): + if not isinstance(self.accelerator, CUDAAccelerator): raise MisconfigurationException( f"DeepSpeed strategy is only supported on GPU but `{self.accelerator.__class__.__name__}` is used." ) diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py index b274856bb6113..b258fe7f738ad 100644 --- a/src/pytorch_lightning/strategies/hivemind.py +++ b/src/pytorch_lightning/strategies/hivemind.py @@ -172,9 +172,9 @@ def num_peers(self) -> int: @property def root_device(self) -> torch.device: from pytorch_lightning.accelerators.cpu import CPUAccelerator - from pytorch_lightning.accelerators.gpu import GPUAccelerator + from pytorch_lightning.accelerators.cuda import CUDAAccelerator - if isinstance(self.accelerator, GPUAccelerator): + if isinstance(self.accelerator, CUDAAccelerator): return torch.device(f"cuda:{torch.cuda.current_device()}") elif isinstance(self.accelerator, CPUAccelerator): return torch.device("cpu") diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index 2e112c754cbc5..ece0e5d27bdce 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -22,7 +22,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator -from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.cuda import CUDAAccelerator from pytorch_lightning.accelerators.hpu import HPUAccelerator from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.accelerators.mps import MPSAccelerator @@ -370,12 +370,12 @@ def _check_config_and_set_final_flags( ) self._accelerator_flag = "cpu" if self._strategy_flag.parallel_devices[0].type == "cuda": - if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"): + if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"): raise MisconfigurationException( f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," f" 
but accelerator set to {self._accelerator_flag}, please choose one device type" ) - self._accelerator_flag = "gpu" + self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices amp_type = amp_type if isinstance(amp_type, str) else None @@ -475,7 +475,7 @@ def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag( if tpu_cores: self._accelerator_flag = "tpu" if gpus: - self._accelerator_flag = "gpu" + self._accelerator_flag = "cuda" if num_processes: self._accelerator_flag = "cpu" @@ -497,7 +497,7 @@ def _choose_accelerator(self) -> str: if MPSAccelerator.is_available(): return "mps" if torch.cuda.is_available() and torch.cuda.device_count() > 0: - return "gpu" + return "cuda" return "cpu" def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -534,7 +534,7 @@ def _set_devices_flag_if_auto_passed(self) -> None: self._devices_flag = self.accelerator.auto_device_count() def _set_devices_flag_if_auto_select_gpus_passed(self) -> None: - if self._auto_select_gpus and isinstance(self._gpus, int) and isinstance(self.accelerator, GPUAccelerator): + if self._auto_select_gpus and isinstance(self._gpus, int) and isinstance(self.accelerator, CUDAAccelerator): self._devices_flag = pick_multiple_gpus(self._gpus) log.info(f"Auto select gpus: {self._devices_flag}") @@ -579,8 +579,8 @@ def _choose_strategy(self) -> Union[Strategy, str]: return DDPStrategy.strategy_name if len(self._parallel_devices) <= 1: # TODO: Change this once gpu accelerator was renamed to cuda accelerator - if isinstance(self._accelerator_flag, (GPUAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("gpu", "mps") + if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") ): device = device_parser.determine_root_gpu_device(self._parallel_devices) else: @@ -609,7 +609,7 @@ def _check_strategy_and_fallback(self) -> None: if ( strategy_flag in DDPFullyShardedNativeStrategy.get_registered_strategies() or isinstance(self._strategy_flag, DDPFullyShardedNativeStrategy) - ) and self._accelerator_flag != "gpu": + ) and self._accelerator_flag not in ("cuda", "gpu"): raise MisconfigurationException( f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, " "but GPU accelerator is not used." 
@@ -632,7 +632,7 @@ def _handle_horovod(self) -> None:
             )
 
         hvd.init()
-        if isinstance(self.accelerator, GPUAccelerator):
+        if isinstance(self.accelerator, CUDAAccelerator):
             # Horovod assigns one local GPU per process
             self._parallel_devices = [torch.device(f"cuda:{i}") for i in range(hvd.local_size())]
         else:
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index 25357578ea24e..882326f870de6 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -38,7 +38,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import (
     Accelerator,
-    GPUAccelerator,
+    CUDAAccelerator,
     HPUAccelerator,
     IPUAccelerator,
     MPSAccelerator,
@@ -1735,7 +1735,7 @@ def __setup_profiler(self) -> None:
 
     def _log_device_info(self) -> None:
-        if GPUAccelerator.is_available():
+        if CUDAAccelerator.is_available():
             gpu_available = True
             gpu_type = " (cuda)"
         elif MPSAccelerator.is_available():
@@ -1745,7 +1745,7 @@ def _log_device_info(self) -> None:
             gpu_available = False
             gpu_type = ""
 
-        gpu_used = isinstance(self.accelerator, (GPUAccelerator, MPSAccelerator))
+        gpu_used = isinstance(self.accelerator, (CUDAAccelerator, MPSAccelerator))
         rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
 
         num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
@@ -1758,10 +1758,10 @@ def _log_device_info(self) -> None:
         rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")
 
         # TODO: Integrate MPS Accelerator here, once gpu maps to both
-        if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
+        if torch.cuda.is_available() and not isinstance(self.accelerator, CUDAAccelerator):
             rank_zero_warn(
                 "GPU available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='gpu', devices={GPUAccelerator.auto_device_count()})`.",
+                f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.",
                 category=PossibleUserWarning,
             )
 
@@ -2069,7 +2069,7 @@ def root_gpu(self) -> Optional[int]:
             "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
             "Please use `Trainer.strategy.root_device.index` instead."
         )
-        return self.strategy.root_device.index if isinstance(self.accelerator, GPUAccelerator) else None
+        return self.strategy.root_device.index if isinstance(self.accelerator, CUDAAccelerator) else None
 
     @property
     def tpu_cores(self) -> int:
@@ -2093,7 +2093,7 @@ def num_gpus(self) -> int:
             "`Trainer.num_gpus` was deprecated in v1.6 and will be removed in v1.8."
             " Please use `Trainer.num_devices` instead."
         )
-        return self.num_devices if isinstance(self.accelerator, GPUAccelerator) else 0
+        return self.num_devices if isinstance(self.accelerator, CUDAAccelerator) else 0
 
     @property
     def devices(self) -> int:
@@ -2109,7 +2109,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:
             "`Trainer.data_parallel_device_ids` was deprecated in v1.6 and will be removed in v1.8."
             " Please use `Trainer.device_ids` instead."
         )
-        return self.device_ids if isinstance(self.accelerator, GPUAccelerator) else None
+        return self.device_ids if isinstance(self.accelerator, CUDAAccelerator) else None
 
     @property
     def lightning_module(self) -> "pl.LightningModule":
diff --git a/src/pytorch_lightning/utilities/memory.py b/src/pytorch_lightning/utilities/memory.py
index 286a571001b0f..573dd6ed0f129 100644
--- a/src/pytorch_lightning/utilities/memory.py
+++ b/src/pytorch_lightning/utilities/memory.py
@@ -101,7 +101,7 @@ def get_gpu_memory_map() -> Dict[str, float]:
     r"""
     .. deprecated:: v1.5
         This function was deprecated in v1.5 in favor of
-        `pytorch_lightning.accelerators.gpu._get_nvidia_gpu_stats` and will be removed in v1.7.
+        `pytorch_lightning.accelerators.cuda._get_nvidia_gpu_stats` and will be removed in v1.7.
 
     Get the current gpu usage.
 
diff --git a/tests/tests_pytorch/accelerators/test_accelerator_connector.py b/tests/tests_pytorch/accelerators/test_accelerator_connector.py
index 100a4cc1d1c7a..33911bffb0eb7 100644
--- a/tests/tests_pytorch/accelerators/test_accelerator_connector.py
+++ b/tests/tests_pytorch/accelerators/test_accelerator_connector.py
@@ -24,7 +24,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
-from pytorch_lightning.accelerators.gpu import GPUAccelerator
+from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.accelerators.mps import MPSAccelerator
 from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin
 from pytorch_lightning.plugins.environments import (
@@ -259,14 +259,14 @@ def test_accelerator_cpu(_):
 
     with pytest.raises(
         MisconfigurationException,
-        match="GPUAccelerator can not run on your system since the accelerator is not available.",
+        match="CUDAAccelerator can not run on your system since the accelerator is not available.",
     ):
         with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"):
             Trainer(gpus=1)
 
     with pytest.raises(
         MisconfigurationException,
-        match="GPUAccelerator can not run on your system since the accelerator is not available.",
+        match="CUDAAccelerator can not run on your system since the accelerator is not available.",
     ):
         Trainer(accelerator="gpu")
 
@@ -287,13 +287,13 @@ def test_accelererator_invalid_type_devices(mock_is_available, mock_device_count
 @RunIf(min_cuda_gpus=1)
 def test_accelerator_gpu():
     trainer = Trainer(accelerator="gpu", devices=1)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
 
     trainer = Trainer(accelerator="gpu")
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
 
     trainer = Trainer(accelerator="auto", devices=1)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
 
 
 @pytest.mark.parametrize(["devices", "strategy_class"], [(1, SingleDeviceStrategy), (5, DDPSpawnStrategy)])
@@ -312,13 +312,13 @@ def test_accelerator_gpu_with_devices(devices, strategy_class):
     trainer = Trainer(accelerator="gpu", devices=devices)
     assert trainer.num_devices == len(devices) if isinstance(devices, list) else devices
     assert isinstance(trainer.strategy, strategy_class)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
 
 
 @RunIf(min_cuda_gpus=1)
 def test_accelerator_auto_with_devices_gpu():
     trainer = Trainer(accelerator="auto", devices=1)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert trainer.num_devices == 1
 
 
@@ -392,7 +392,7 @@ def test_device_type_when_strategy_instance_gpu_passed(strategy_class):
 
     trainer = Trainer(strategy=strategy_class(), accelerator="gpu", devices=2)
     assert isinstance(trainer.strategy, strategy_class)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
 
 
 @pytest.mark.parametrize("precision", [1, 12, "invalid"])
@@ -419,7 +419,7 @@ def test_strategy_choice_ddp_spawn_cpu():
 @mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp(*_):
     trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert isinstance(trainer.strategy, DDPStrategy)
     assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)
 
@@ -429,7 +429,7 @@ def test_strategy_choice_ddp(*_):
 @mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
     trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="gpu", devices=1)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert isinstance(trainer.strategy, DDPSpawnStrategy)
     assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment)
 
@@ -451,7 +451,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
     trainer = Trainer(fast_dev_run=True, strategy=strategy, accelerator="gpu", devices=2)
     assert trainer._accelerator_connector._is_slurm_managing_tasks()
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert isinstance(trainer.strategy, DDPStrategy)
     assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment)
     assert trainer.strategy.cluster_environment.local_rank() == 1
@@ -477,7 +477,7 @@ def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
 @mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp_te(*_):
     trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=2)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert isinstance(trainer.strategy, DDPStrategy)
     assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment)
     assert trainer.strategy.cluster_environment.local_rank() == 1
@@ -524,7 +524,7 @@ def test_strategy_choice_ddp_cpu_te(*_):
 @mock.patch("torch.cuda.is_available", return_value=True)
 def test_strategy_choice_ddp_kubeflow(*_):
     trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1)
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert isinstance(trainer.strategy, DDPStrategy)
     assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment)
     assert trainer.strategy.cluster_environment.local_rank() == 0
@@ -649,7 +649,7 @@ def test_devices_auto_choice_cpu(
 
 def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert isinstance(trainer.accelerator, GPUAccelerator)
+    assert isinstance(trainer.accelerator, CUDAAccelerator)
     assert trainer.num_devices == 2
 
 
diff --git a/tests/tests_pytorch/accelerators/test_accelerator_registry.py b/tests/tests_pytorch/accelerators/test_accelerator_registry.py
index 11c806a21c740..791d4c33dbbe8 100644
--- a/tests/tests_pytorch/accelerators/test_accelerator_registry.py
+++ b/tests/tests_pytorch/accelerators/test_accelerator_registry.py
@@ -63,4 +63,4 @@ def is_available():
 
 
 def test_available_accelerators_in_registry():
-    assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "hpu", "ipu", "mps", "tpu"]
+    assert AcceleratorRegistry.available_accelerators() == ["cpu", "cuda", "gpu", "hpu", "ipu", "mps", "tpu"]
diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py
index 3de26b5888390..9395c7e84c709 100644
--- a/tests/tests_pytorch/accelerators/test_common.py
+++ b/tests/tests_pytorch/accelerators/test_common.py
@@ -14,14 +14,14 @@
 from unittest import mock
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, CUDAAccelerator, IPUAccelerator, TPUAccelerator
 from pytorch_lightning.strategies import DDPStrategy
 
 
 @mock.patch("torch.cuda.device_count", return_value=2)
 def test_auto_device_count(device_count_mock):
     assert CPUAccelerator.auto_device_count() == 1
-    assert GPUAccelerator.auto_device_count() == 2
+    assert CUDAAccelerator.auto_device_count() == 2
     assert TPUAccelerator.auto_device_count() == 8
     assert IPUAccelerator.auto_device_count() == 4
 
diff --git a/tests/tests_pytorch/accelerators/test_gpu.py b/tests/tests_pytorch/accelerators/test_gpu.py
index f6334780d75a5..e660ff270f921 100644
--- a/tests/tests_pytorch/accelerators/test_gpu.py
+++ b/tests/tests_pytorch/accelerators/test_gpu.py
@@ -17,8 +17,8 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import GPUAccelerator
-from pytorch_lightning.accelerators.gpu import get_nvidia_gpu_stats
+from pytorch_lightning.accelerators import CUDAAccelerator
+from pytorch_lightning.accelerators.cuda import get_nvidia_gpu_stats
 from pytorch_lightning.demos.boring_classes import BoringModel
 from tests_pytorch.helpers.runif import RunIf
 
@@ -26,7 +26,7 @@
 @RunIf(min_cuda_gpus=1)
 def test_get_torch_gpu_stats(tmpdir):
     current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
-    gpu_stats = GPUAccelerator().get_device_stats(current_device)
+    gpu_stats = CUDAAccelerator().get_device_stats(current_device)
     fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
 
     for f in fields:
@@ -62,7 +62,7 @@ def test_set_cuda_device(set_device_mock, tmpdir):
 
 @RunIf(min_cuda_gpus=1)
 def test_gpu_availability():
-    assert GPUAccelerator.is_available()
+    assert CUDAAccelerator.is_available()
 
 
 @RunIf(min_cuda_gpus=1)
diff --git a/tests/tests_pytorch/callbacks/test_quantization.py b/tests/tests_pytorch/callbacks/test_quantization.py
index 1a12728c7face..41d0810a0aec8 100644
--- a/tests/tests_pytorch/callbacks/test_quantization.py
+++ b/tests/tests_pytorch/callbacks/test_quantization.py
@@ -20,7 +20,7 @@
 from torchmetrics.functional import mean_absolute_percentage_error as mape
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.accelerators import GPUAccelerator
+from pytorch_lightning.accelerators import CUDAAccelerator
 from pytorch_lightning.callbacks import QuantizationAwareTraining
 from pytorch_lightning.demos.boring_classes import RandomDataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -38,9 +38,9 @@
 @RunIf(quantization=True, max_torch="1.11")
 def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
     """Parity test for quant model."""
-    cuda_available = GPUAccelerator.is_available()
+    cuda_available = CUDAAccelerator.is_available()
 
-    if observe == "average" and not fuse and GPUAccelerator.is_available():
+    if observe == "average" and not fuse and CUDAAccelerator.is_available():
         pytest.xfail("TODO: flakiness in GPU CI")
 
     seed_everything(42)
diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py
index 66bbf80d4e3ea..9c7d02d499ab4 100644
--- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py
+++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py
@@ -18,6 +18,7 @@
 
 import pytorch_lightning.loggers.base as logger_base
 from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.profiler.advanced import AdvancedProfiler
@@ -195,3 +196,13 @@ def test_pytorch_profiler_schedule_wrapper_deprecation_warning():
 def test_pytorch_profiler_register_record_function_deprecation_warning():
     with pytest.deprecated_call(match="RegisterRecordFunction` is deprecated in v1.7 and will be removed in in v1.9."):
         _ = RegisterRecordFunction(None)
+
+
+def test_gpu_accelerator_deprecation_warning():
+    with pytest.deprecated_call(
+        match=(
+            "The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9."
+            " Please use the `CUDAAccelerator` instead!"
+        )
+    ):
+        GPUAccelerator()
diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py
index ffd093e6ee0e3..bdc61ca399e12 100644
--- a/tests/tests_pytorch/models/test_gpu.py
+++ b/tests/tests_pytorch/models/test_gpu.py
@@ -23,7 +23,7 @@
 import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator
+from pytorch_lightning.accelerators import CPUAccelerator, CUDAAccelerator
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import device_parser
@@ -196,7 +196,7 @@ def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus
     assert isinstance(trainer._accelerator_connector.cluster_environment, TorchElasticEnvironment)
     # when use gpu
     if device_parser.parse_gpu_ids(gpus, include_cuda=True) is not None:
-        assert isinstance(trainer.accelerator, GPUAccelerator)
+        assert isinstance(trainer.accelerator, CUDAAccelerator)
         assert trainer.num_devices == len(gpus) if isinstance(gpus, list) else gpus
         assert trainer.device_ids == device_parser.parse_gpu_ids(gpus, include_cuda=True)
     # fall back to cpu
diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py
index a5c4f7e101761..c413b0015db61 100644
--- a/tests/tests_pytorch/plugins/test_cluster_integration.py
+++ b/tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -59,7 +59,7 @@ def environment_combinations():
     "strategy_cls",
     [DDPStrategy, DDPShardedStrategy, pytest.param(DeepSpeedStrategy, marks=RunIf(deepspeed=True))],
 )
-@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True)
+@mock.patch("pytorch_lightning.accelerators.cuda.CUDAAccelerator.is_available", return_value=True)
 def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strategy_cls):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index 58fa28559b97f..003fe2250b575 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -103,7 +103,7 @@ def test_torch_distributed_backend_env_variables(tmpdir):
 @mock.patch("torch.cuda.set_device")
 @mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("pytorch_lightning.accelerators.gpu.GPUAccelerator.is_available", return_value=True)
+@mock.patch("pytorch_lightning.accelerators.cuda.CUDAAccelerator.is_available", return_value=True)
 @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True)
 def test_ddp_torch_dist_is_available_in_setup(
     mock_gpu_is_available, mock_device_count, mock_cuda_available, mock_set_device, tmpdir
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index a0d20fc58ed1c..c46c0168db558 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -34,7 +34,7 @@
 import pytorch_lightning
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator
+from pytorch_lightning.accelerators import CPUAccelerator, CUDAAccelerator
 from pytorch_lightning.callbacks import EarlyStopping, GradientAccumulationScheduler, ModelCheckpoint, Timer
 from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
@@ -1967,21 +1967,21 @@ def training_step(self, batch, batch_idx):
             {"strategy": None, "accelerator": "gpu", "devices": 1},
             SingleDeviceStrategy,
             "single_device",
-            GPUAccelerator,
+            CUDAAccelerator,
             1,
         ),
-        ({"strategy": "dp", "accelerator": "gpu", "devices": 1}, DataParallelStrategy, "dp", GPUAccelerator, 1),
-        ({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, DDPStrategy, "ddp", GPUAccelerator, 1),
+        ({"strategy": "dp", "accelerator": "gpu", "devices": 1}, DataParallelStrategy, "dp", CUDAAccelerator, 1),
+        ({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, DDPStrategy, "ddp", CUDAAccelerator, 1),
         (
             {"strategy": "ddp_spawn", "accelerator": "gpu", "devices": 1},
             DDPSpawnStrategy,
             "ddp_spawn",
-            GPUAccelerator,
+            CUDAAccelerator,
             1,
         ),
-        ({"strategy": None, "accelerator": "gpu", "devices": 2}, DDPSpawnStrategy, "ddp_spawn", GPUAccelerator, 2),
-        ({"strategy": "dp", "accelerator": "gpu", "devices": 2}, DataParallelStrategy, "dp", GPUAccelerator, 2),
-        ({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, DDPStrategy, "ddp", GPUAccelerator, 2),
+        ({"strategy": None, "accelerator": "gpu", "devices": 2}, DDPSpawnStrategy, "ddp_spawn", CUDAAccelerator, 2),
+        ({"strategy": "dp", "accelerator": "gpu", "devices": 2}, DataParallelStrategy, "dp", CUDAAccelerator, 2),
+        ({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2),
         ({"strategy": "ddp", "accelerator": "cpu", "devices": 2}, DDPStrategy, "ddp", CPUAccelerator, 2),
         (
             {"strategy": "ddp_spawn", "accelerator": "cpu", "devices": 2},
@@ -2001,7 +2001,7 @@ def training_step(self, batch, batch_idx):
             {"strategy": "ddp_fully_sharded", "accelerator": "gpu", "devices": 1},
             DDPFullyShardedStrategy,
             "ddp_fully_sharded",
-            GPUAccelerator,
+            CUDAAccelerator,
             1,
         ),
         (
@@ -2015,65 +2015,65 @@ def training_step(self, batch, batch_idx):
             {"strategy": DDPSpawnStrategy(), "accelerator": "gpu", "devices": 2},
             DDPSpawnStrategy,
             "ddp_spawn",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         ({"strategy": DDPStrategy()}, DDPStrategy, "ddp", CPUAccelerator, 1),
-        ({"strategy": DDPStrategy(), "accelerator": "gpu", "devices": 2}, DDPStrategy, "ddp", GPUAccelerator, 2),
+        ({"strategy": DDPStrategy(), "accelerator": "gpu", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2),
         (
             {"strategy": DataParallelStrategy(), "accelerator": "gpu", "devices": 2},
             DataParallelStrategy,
             "dp",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         (
             {"strategy": DDPFullyShardedStrategy(), "accelerator": "gpu", "devices": 2},
             DDPFullyShardedStrategy,
             "ddp_fully_sharded",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         (
             {"strategy": DDPSpawnShardedStrategy(), "accelerator": "gpu", "devices": 2},
             DDPSpawnShardedStrategy,
             "ddp_sharded_spawn",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         (
             {"strategy": DDPShardedStrategy(), "accelerator": "gpu", "devices": 2},
             DDPShardedStrategy,
             "ddp_sharded",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         (
             {"strategy": "ddp_spawn", "accelerator": "gpu", "devices": 2, "num_nodes": 2},
             DDPSpawnStrategy,
             "ddp_spawn",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         (
             {"strategy": "ddp_fully_sharded", "accelerator": "gpu", "devices": 1, "num_nodes": 2},
             DDPFullyShardedStrategy,
             "ddp_fully_sharded",
-            GPUAccelerator,
+            CUDAAccelerator,
             1,
         ),
         (
             {"strategy": "ddp_sharded", "accelerator": "gpu", "devices": 2, "num_nodes": 2},
             DDPShardedStrategy,
             "ddp_sharded",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
         (
             {"strategy": "ddp_sharded_spawn", "accelerator": "gpu", "devices": 2, "num_nodes": 2},
             DDPSpawnShardedStrategy,
             "ddp_sharded_spawn",
-            GPUAccelerator,
+            CUDAAccelerator,
             2,
         ),
     ],