Skip to content

Commit e9c571d

Browse files
carmocca and awaelchli authored
Move accelerator-specific parsing functions with their accelerators (#14753)
Co-authored-by: awaelchli <[email protected]>
1 parent 4f9c779 commit e9c571d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+387
-477
lines changed

src/lightning_lite/accelerators/cpu.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import torch
1717

1818
from lightning_lite.accelerators.accelerator import Accelerator
19-
from lightning_lite.utilities import device_parser
2019

2120

2221
class CPUAccelerator(Accelerator):
@@ -37,13 +36,13 @@ def teardown(self) -> None:
3736
@staticmethod
3837
def parse_devices(devices: Union[int, str, List[int]]) -> int:
3938
"""Accelerator device parsing logic."""
40-
devices = device_parser.parse_cpu_cores(devices)
39+
devices = parse_cpu_cores(devices)
4140
return devices
4241

4342
@staticmethod
4443
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
4544
"""Gets parallel devices for the Accelerator."""
46-
devices = device_parser.parse_cpu_cores(devices)
45+
devices = parse_cpu_cores(devices)
4746
return [torch.device("cpu")] * devices
4847

4948
@staticmethod
@@ -63,3 +62,26 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
6362
cls,
6463
description=cls.__class__.__name__,
6564
)
65+
66+
67+
def parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
    """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
    :class:`~pytorch_lightning.trainer.Trainer`.

    Args:
        cpu_cores: An int > 0, or a string holding such an int (surrounding whitespace is ignored).

    Returns:
        An int representing the number of processes

    Raises:
        TypeError:
            If cpu_cores is not an int > 0
    """
    # ``str.isdecimal`` (unlike ``str.isdigit``) only accepts characters that ``int()`` can convert,
    # so inputs such as "²" fall through to the ``TypeError`` below instead of raising ``ValueError``.
    if isinstance(cpu_cores, str) and cpu_cores.strip().isdecimal():
        cpu_cores = int(cpu_cores)

    if not isinstance(cpu_cores, int) or cpu_cores <= 0:
        raise TypeError("`devices` selected with `CPUAccelerator` should be an int > 0.")

    return cpu_cores

src/lightning_lite/accelerators/cuda.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import multiprocessing
1415
from typing import Dict, List, Optional, Union
1516

1617
import torch
1718

1819
from lightning_lite.accelerators.accelerator import Accelerator
19-
from lightning_lite.utilities import device_parser
20+
from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
2021

2122

2223
class CUDAAccelerator(Accelerator):
@@ -39,7 +40,9 @@ def teardown(self) -> None:
3940
@staticmethod
4041
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
4142
"""Accelerator device parsing logic."""
42-
return device_parser.parse_gpu_ids(devices, include_cuda=True)
43+
from lightning_lite.utilities.device_parser import parse_gpu_ids
44+
45+
return parse_gpu_ids(devices, include_cuda=True)
4346

4447
@staticmethod
4548
def get_parallel_devices(devices: List[int]) -> List[torch.device]:
@@ -49,11 +52,11 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]:
4952
@staticmethod
5053
def auto_device_count() -> int:
5154
"""Get the devices when set to auto."""
52-
return device_parser.num_cuda_devices()
55+
return num_cuda_devices()
5356

5457
@staticmethod
5558
def is_available() -> bool:
56-
return device_parser.num_cuda_devices() > 0
59+
return num_cuda_devices() > 0
5760

5861
@classmethod
5962
def register_accelerators(cls, accelerator_registry: Dict) -> None:
@@ -62,3 +65,35 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
6265
cls,
6366
description=cls.__class__.__name__,
6467
)
68+
69+
70+
def _get_all_available_cuda_gpus() -> List[int]:
    """Enumerate the CUDA device indices reported by :func:`num_cuda_devices`.

    Returns:
        The indices ``0 .. n - 1`` where ``n`` is the value returned by :func:`num_cuda_devices`
    """
    device_count = num_cuda_devices()
    return [*range(device_count)]
76+
77+
78+
def num_cuda_devices() -> int:
    """Returns the number of GPUs available.

    Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
    if the platform allows it.
    """
    fork_supported = "fork" in torch.multiprocessing.get_all_start_methods()
    if fork_supported and not _is_forking_disabled():
        # Ask a single forked worker process for the count so the query runs outside the current process.
        with multiprocessing.get_context("fork").Pool(1) as worker_pool:
            return worker_pool.apply(torch.cuda.device_count)
    # Forking is unavailable or disabled: query directly in this process.
    return torch.cuda.device_count()
88+
89+
90+
def is_cuda_available() -> bool:
    """Returns a bool indicating if CUDA is currently available.

    Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
    if the platform allows it.
    """
    fork_supported = "fork" in torch.multiprocessing.get_all_start_methods()
    if fork_supported and not _is_forking_disabled():
        # Ask a single forked worker process so the check runs outside the current process.
        with multiprocessing.get_context("fork").Pool(1) as worker_pool:
            return worker_pool.apply(torch.cuda.is_available)
    # Forking is unavailable or disabled: check directly in this process.
    return torch.cuda.is_available()

src/lightning_lite/accelerators/mps.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919

2020
from lightning_lite.accelerators.accelerator import Accelerator
21-
from lightning_lite.utilities import device_parser
2221
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
2322

2423

@@ -40,15 +39,16 @@ def teardown(self) -> None:
4039
@staticmethod
4140
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
4241
"""Accelerator device parsing logic."""
43-
parsed_devices = device_parser.parse_gpu_ids(devices, include_mps=True)
42+
from lightning_lite.utilities.device_parser import parse_gpu_ids
43+
44+
parsed_devices = parse_gpu_ids(devices, include_mps=True)
4445
return parsed_devices
4546

4647
@staticmethod
4748
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
4849
"""Gets parallel devices for the Accelerator."""
4950
parsed_devices = MPSAccelerator.parse_devices(devices)
5051
assert parsed_devices is not None
51-
5252
return [torch.device("mps", i) for i in range(len(parsed_devices))]
5353

5454
@staticmethod
@@ -72,3 +72,11 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
7272
cls,
7373
description=cls.__class__.__name__,
7474
)
75+
76+
77+
def _get_all_available_mps_gpus() -> List[int]:
    """List the MPS device indices that can be used.

    Returns:
        ``[0]`` when :meth:`MPSAccelerator.is_available` reports availability, otherwise an empty list
    """
    if MPSAccelerator.is_available():
        return [0]
    return []

src/lightning_lite/accelerators/tpu.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Dict, List, Optional, Union
14+
from typing import Any, Dict, List, Optional, Union
1515

1616
import torch
1717

1818
from lightning_lite.accelerators.accelerator import Accelerator
19-
from lightning_lite.utilities import device_parser
19+
from lightning_lite.utilities.device_parser import _check_data_type
2020
from lightning_lite.utilities.imports import _TPU_AVAILABLE
2121

2222

@@ -32,7 +32,7 @@ def teardown(self) -> None:
3232
@staticmethod
3333
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
3434
"""Accelerator device parsing logic."""
35-
return device_parser.parse_tpu_cores(devices)
35+
return parse_tpu_cores(devices)
3636

3737
@staticmethod
3838
def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
@@ -57,3 +57,54 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
5757
cls,
5858
description=cls.__class__.__name__,
5959
)
60+
61+
62+
def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
    """
    Parses the tpu_cores given in the format as accepted by the
    :class:`~pytorch_lightning.trainer.Trainer`.

    Args:
        tpu_cores: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
            A list of ints or a string containing a list of comma separated integers
            indicates the specific TPU core to use.

    Returns:
        The validated selection: an int, a list of ints, or ``None`` if ``None`` was passed in

    Raises:
        TypeError:
            If TPU cores aren't 1, 8 or [<1-8>]
    """
    # NOTE(review): ``_check_data_type`` is defined in ``device_parser``; presumably it rejects
    # unsupported input types up front -- confirm against that module.
    _check_data_type(tpu_cores)

    # Strings are stripped and converted to an int or a list of ints; other values pass through unchanged.
    parsed = _parse_tpu_cores_str(tpu_cores.strip()) if isinstance(tpu_cores, str) else tpu_cores

    if _tpu_cores_valid(parsed):
        return parsed

    raise TypeError("`tpu_cores` can only be 1, 8 or [<1-8>]")
89+
90+
91+
def _tpu_cores_valid(tpu_cores: Any) -> bool:
92+
# allow 1 or 8 cores
93+
if tpu_cores in (1, 8, None):
94+
return True
95+
96+
# allow picking 1 of 8 indexes
97+
if isinstance(tpu_cores, (list, tuple, set)):
98+
has_1_tpu_idx = len(tpu_cores) == 1
99+
is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8
100+
101+
is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
102+
return is_valid_tpu_core_choice
103+
104+
return False
105+
106+
107+
def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
108+
if tpu_cores in ("1", "8"):
109+
return int(tpu_cores)
110+
return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]

src/lightning_lite/connector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
XLAStrategy,
5353
)
5454
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
55-
from lightning_lite.utilities import _StrategyType, device_parser, rank_zero_deprecation, rank_zero_info, rank_zero_warn
55+
from lightning_lite.utilities import _StrategyType, rank_zero_deprecation, rank_zero_info, rank_zero_warn
56+
from lightning_lite.utilities.device_parser import determine_root_gpu_device
5657
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE
5758

5859
_PLUGIN = Union[Strategy, Precision, ClusterEnvironment, CheckpointIO]
@@ -429,7 +430,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
429430
if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
430431
isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
431432
):
432-
device = device_parser.determine_root_gpu_device(self._parallel_devices)
433+
device = determine_root_gpu_device(self._parallel_devices)
433434
else:
434435
device = "cpu"
435436
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"

0 commit comments

Comments
 (0)