Skip to content
22 changes: 8 additions & 14 deletions pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,13 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
gpus = _normalize_parse_gpu_input_to_list(gpus)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")

if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
elif TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
# omit sanity check on torchelastic as by default shows one visible GPU per process
return gpus
return _sanitize_gpu_ids(gpus)

gpus = _sanitize_gpu_ids(gpus)

return gpus


def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int], int]]:
def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[int, List[int]]]:
"""
Parses the tpu_cores given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.
Expand Down Expand Up @@ -209,25 +205,23 @@ def _check_data_type(device_ids: Any) -> None:
raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.")


def _tpu_cores_valid(tpu_cores):
def _tpu_cores_valid(tpu_cores: Any) -> bool:
# allow 1 or 8 cores
if tpu_cores in (1, 8, None):
return True

# allow picking 1 of 8 indexes
if isinstance(tpu_cores, (list, tuple, set)):
has_1_tpu_idx = len(tpu_cores) == 1
is_valid_tpu_idx = tpu_cores[0] in range(1, 9)
is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8

is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
return is_valid_tpu_core_choice

return False


def _parse_tpu_cores_str(tpu_cores):
def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
if tpu_cores in ('1', '8'):
tpu_cores = int(tpu_cores)
else:
tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0]
return tpu_cores
return int(tpu_cores)
return [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0]
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ ignore_errors = True
ignore_errors = True
[mypy-pytorch_lightning.utilities.cli]
ignore_errors = False
[mypy-pytorch_lightning.utilities.device_parser]
ignore_errors = False

# todo: add proper typing to this module...
[mypy-pl_examples.*]
Expand Down