Commit c2c3ab0

fix ipus and cli tests
1 parent d117c66 commit c2c3ab0

3 files changed: +26 -22 lines changed

pytorch_lightning/strategies/ipu.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ def _move_float_tensors_to_half(self, batch: Any) -> Any:
 class IPUStrategy(ParallelStrategy):
     """Plugin for training on IPU devices."""
 
-    distributed_backend = "ipu"
+    distributed_backend = "ipu_strategy"
 
     def __init__(
         self,
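
A plausible reading of this rename: `distributed_backend` serves as the name under which the strategy is registered, and giving it a value distinct from the accelerator name "ipu" keeps the strategy and accelerator namespaces from sharing a key. A minimal sketch of that idea, using a plain dict as an illustrative stand-in for Lightning's `StrategyRegistry` (none of this is the library's actual implementation):

class Strategy:
    distributed_backend = "base"


class IPUStrategy(Strategy):
    # The class-level key now differs from the accelerator flag "ipu",
    # so strategy lookup and accelerator lookup use distinct names.
    distributed_backend = "ipu_strategy"


# Illustrative stand-in registry keyed by distributed_backend.
registry = {cls.distributed_backend: cls for cls in (Strategy, IPUStrategy)}

assert registry["ipu_strategy"] is IPUStrategy
assert "ipu" not in registry  # the accelerator name stays a separate namespace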

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 21 additions & 20 deletions

@@ -140,7 +140,6 @@ def __init__(
         # --Parsing_flags------------------------------------------------------
         # Get registered strategies, existing accelerators and precision plugins
         self._existing_strategies_str = StrategyRegistry.available_strategies()
-        # print(self._existing_strategies_str)
         self._existing_accelerator_type = ["tpu", "ipu", "gpu", "cpu"]
         self._supported_precision = PrecisionType.supported_types()

@@ -156,7 +155,7 @@ def __init__(
         # --Accelerator-------------------------------------------------------------
         # handle `auto` and `None`
         if self._accelerator_flag == "auto" or self._accelerator_flag is None:
-            self._choose_accelerator()
+            self._accelerator_flag = self._choose_accelerator()
         # else:
         #     # [RFC] move to XAccelerator class init?
         #     self._check_device_availibility()

@@ -388,20 +387,20 @@ def _mapping_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
             self._accelerator_flag = "cpu"
 
     def _choose_accelerator(self):
+        if _TPU_AVAILABLE:
+            return "tpu"
+        if _IPU_AVAILABLE:
+            return "ipu"
         if self._accelerator_flag == "auto":
-            if _TPU_AVAILABLE:
-                self._accelerator_flag = "tpu"
-            elif _IPU_AVAILABLE:
-                self._accelerator_flag = "ipu"
-            elif torch.cuda.is_available() and torch.cuda.device_count() > 0:
-                self._accelerator_flag = "gpu"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+                return "gpu"
             else:
-                self._accelerator_flag = "cpu"
                 if self._device_flag == "auto":
                     self._device_flag = 1
+                return "cpu"
         # [RFC] this is current logic, if accelerator not set, default cpu?
         else:
-            self._accelerator_flag = "cpu"
+            return "cpu"
 
         # TODO move this to xAccelerator
         # def _check_device_availibility(self):
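
The `_choose_accelerator` hunk changes the method from mutating `self._accelerator_flag` in place to returning the chosen name, with the TPU/IPU availability checks hoisted ahead of the "auto" branch; the caller in `__init__` now assigns the return value. A standalone sketch of the resulting control flow, with module-level booleans standing in for `_TPU_AVAILABLE`, `_IPU_AVAILABLE`, and the `torch.cuda` checks:

_TPU_AVAILABLE = False   # stand-in for pytorch_lightning's _TPU_AVAILABLE
_IPU_AVAILABLE = False   # stand-in for _IPU_AVAILABLE
_CUDA_AVAILABLE = False  # stand-in for torch.cuda.is_available() and device_count() > 0


def choose_accelerator(accelerator_flag: str) -> str:
    # Hardware checks short-circuit first, mirroring the new ordering.
    if _TPU_AVAILABLE:
        return "tpu"
    if _IPU_AVAILABLE:
        return "ipu"
    if accelerator_flag == "auto":
        if _CUDA_AVAILABLE:
            return "gpu"
        return "cpu"
    # current logic: if accelerator is not set, default to cpu
    return "cpu"


# The caller assigns the result instead of relying on side effects:
accelerator_flag = choose_accelerator("auto")
assert accelerator_flag == "cpu"
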
@@ -485,8 +484,8 @@ def _is_slurm_managing_tasks(self):
         return num_slurm_tasks == total_requested_devices
 
     def _choose_strategy(self):
-        if self._accelerator_flag == "ipu":
-            self._strategy_flag = "ipu"
+        if self._accelerator_flag == "ipu_strategy":
+            self._strategy_flag = "ipu_strategy"
         elif self._accelerator_flag == "tpu":
             if self._parallel_devices and len(self._parallel_devices) > 1:
                 self._strategy_flag = "tpu_spawn"

@@ -755,29 +754,31 @@ def devices(self):
             return 1
         elif isinstance(self.strategy, ParallelStrategy):
             return len(self.strategy.parallel_devices)
-        else:
-            return 0
+        return 0
 
     @property
     def tpu_cores(self) -> int:
         if isinstance(self.accelerator, TPUAccelerator):
             return self.devices
-        else:
-            return 0
+        return 0
+
+    @property
+    def tpu_id(self) -> Optional[int]:
+        if isinstance(self.accelerator, TPUAccelerator):
+            return self.parallel_devices[0]
+        return None
 
     @property
     def num_ipus(self) -> int:
         if isinstance(self.accelerator, IPUAccelerator):
             return self.devices
-        else:
-            return 0
+        return 0
 
     @property
     def num_gpus(self) -> int:
         if isinstance(self.accelerator, GPUAccelerator):
             return self.devices
-        else:
-            return 0
+        return 0
 
     # def parallel_device_ids():
     @property
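
The property hunk applies the same early-return cleanup four times: when the `if` branch always returns, the trailing `else:` wrapper is redundant and the fallback becomes a bare `return`. It also adds a `tpu_id` property alongside `tpu_cores`. A small self-contained sketch of the pattern (`FakeTPUAccelerator` and `Connector` are illustrative stand-ins, not Lightning classes):

from typing import Optional


class FakeTPUAccelerator:
    """Illustrative stand-in for TPUAccelerator."""


class Connector:
    def __init__(self, accelerator, parallel_devices):
        self.accelerator = accelerator
        self.parallel_devices = parallel_devices

    @property
    def tpu_id(self) -> Optional[int]:
        if isinstance(self.accelerator, FakeTPUAccelerator):
            return self.parallel_devices[0]
        return None  # guard clause replaces the old `else:` block


assert Connector(object(), [0]).tpu_id is None
assert Connector(FakeTPUAccelerator(), [3]).tpu_id == 3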

tests/utilities/test_cli.py

Lines changed: 4 additions & 1 deletion

@@ -577,7 +577,10 @@ def on_fit_start(self):
 @pytest.mark.parametrize(
     "trainer_kwargs",
     (
-        dict(strategy="ddp_spawn"),
+        # dict(strategy="ddp_spawn")
+        # !! old accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
+        # this test never worked with DDPSpawnStrategy
+        dict(strategy="single_device"),
         dict(strategy="ddp"),
         pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
     ),
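
Per the inline comment, the old accelerator connector resolved both `ddp` and `ddp_spawn` to a single-device strategy, so the `ddp_spawn` entry never actually exercised `DDPSpawnStrategy`; the parametrization now names `single_device` explicitly. For reference, a minimal self-contained example of how `pytest.mark.parametrize` feeds one kwargs dict per case (the test body is illustrative, not the real test):

import pytest


@pytest.mark.parametrize(
    "trainer_kwargs",
    (
        dict(strategy="single_device"),
        dict(strategy="ddp"),
    ),
)
def test_strategy_kwarg_present(trainer_kwargs):
    # pytest runs this test once per tuple entry above.
    assert "strategy" in trainer_kwargs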
