@@ -140,7 +140,6 @@ def __init__(
140140 # --Parsing_flags------------------------------------------------------
141141 # Get registered strategies, existing accelerators and precision plugins
142142 self ._existing_strategies_str = StrategyRegistry .available_strategies ()
143- # print(self._existing_strategies_str)
144143 self ._existing_accelerator_type = ["tpu" , "ipu" , "gpu" , "cpu" ]
145144 self ._supported_precision = PrecisionType .supported_types ()
146145
@@ -156,7 +155,7 @@ def __init__(
156155 # --Accelerator-------------------------------------------------------------
157156 # handle `auto` and `None`
158157 if self ._accelerator_flag == "auto" or self ._accelerator_flag is None :
159- self ._choose_accelerator ()
158+ self ._accelerator_flag = self . _choose_accelerator ()
160159 # else:
161160 # # [RFC] move to XAccelerator class init?
162161 # self._check_device_availibility()
@@ -388,20 +387,20 @@ def _mapping_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
388387 self ._accelerator_flag = "cpu"
389388
390389 def _choose_accelerator (self ):
390+ if _TPU_AVAILABLE :
391+ return "tpu"
392+ if _IPU_AVAILABLE :
393+ return "ipu"
391394 if self ._accelerator_flag == "auto" :
392- if _TPU_AVAILABLE :
393- self ._accelerator_flag = "tpu"
394- elif _IPU_AVAILABLE :
395- self ._accelerator_flag = "ipu"
396- elif torch .cuda .is_available () and torch .cuda .device_count () > 0 :
397- self ._accelerator_flag = "gpu"
395+ if torch .cuda .is_available () and torch .cuda .device_count () > 0 :
396+ return "gpu"
398397 else :
399- self ._accelerator_flag = "cpu"
400398 if self ._device_flag == "auto" :
401399 self ._device_flag = 1
400+ return "cpu"
402401 # [RFC] this is current logic, if accelerator not set, default cpu?
403402 else :
404- self . _accelerator_flag = "cpu"
403+ return "cpu"
405404
406405 # TODO move this to xAccelerator
407406 # def _check_device_availibility(self):
@@ -485,8 +484,8 @@ def _is_slurm_managing_tasks(self):
485484 return num_slurm_tasks == total_requested_devices
486485
487486 def _choose_strategy (self ):
488- if self ._accelerator_flag == "ipu " :
489- self ._strategy_flag = "ipu "
487+ if self ._accelerator_flag == "ipu_strategy " :
488+ self ._strategy_flag = "ipu_strategy "
490489 elif self ._accelerator_flag == "tpu" :
491490 if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
492491 self ._strategy_flag = "tpu_spawn"
@@ -755,29 +754,31 @@ def devices(self):
755754 return 1
756755 elif isinstance (self .strategy , ParallelStrategy ):
757756 return len (self .strategy .parallel_devices )
758- else :
759- return 0
757+ return 0
760758
761759 @property
762760 def tpu_cores (self ) -> int :
763761 if isinstance (self .accelerator , TPUAccelerator ):
764762 return self .devices
765- else :
766- return 0
763+ return 0
764+
765+ @property
766+ def tpu_id (self ) -> Optional [int ]:
767+ if isinstance (self .accelerator , TPUAccelerator ):
768+ return self .parallel_devices [0 ]
769+ return None
767770
768771 @property
769772 def num_ipus (self ) -> int :
770773 if isinstance (self .accelerator , IPUAccelerator ):
771774 return self .devices
772- else :
773- return 0
775+ return 0
774776
775777 @property
776778 def num_gpus (self ) -> int :
777779 if isinstance (self .accelerator , GPUAccelerator ):
778780 return self .devices
779- else :
780- return 0
781+ return 0
781782
782783 # def parallel_device_ids():
783784 @property
0 commit comments