     TorchElasticEnvironment,
 )
 from pytorch_lightning.utilities import (
+    _AcceleratorType,
     _StrategyType,
     AMPType,
     device_parser,
-    DeviceType,
     rank_zero_deprecation,
     rank_zero_info,
     rank_zero_warn,
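This commit renames the private `DeviceType` enum to `_AcceleratorType` and swaps the import accordingly. As a point of reference, here is a minimal standalone sketch of what the renamed enum presumably looks like, assuming it follows the package's existing LightningEnum pattern (string-valued members with case-insensitive comparison); the `__eq__`/`__hash__` bodies below are illustrative, not copied from the library:

from enum import Enum


class _AcceleratorType(str, Enum):
    # Hypothetical sketch: members mirror the four accelerators dispatched on below.
    CPU = "CPU"
    GPU = "GPU"
    IPU = "IPU"
    TPU = "TPU"

    def __eq__(self, other: object) -> bool:
        # Case-insensitive comparison, so a user string like "gpu" matches GPU.
        other = other.value if isinstance(other, Enum) else str(other)
        return self.value.lower() == other.lower()

    def __hash__(self) -> int:
        # Overriding __eq__ disables hashing by default; re-enable it.
        return hash(self.value.lower())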
@@ -106,7 +106,7 @@ def __init__(
         plugins,
     ):
         # initialization
-        self._device_type = DeviceType.CPU
+        self._device_type = _AcceleratorType.CPU
         self._distrib_type = None
         self._accelerator_type = None

@@ -199,32 +199,32 @@ def _init_deterministic(self, deterministic: bool) -> None:
     def select_accelerator_type(self) -> None:
         if self.distributed_backend == "auto":
             if self.has_tpu:
-                self._accelerator_type = DeviceType.TPU
+                self._accelerator_type = _AcceleratorType.TPU
             elif self.has_ipu:
-                self._accelerator_type = DeviceType.IPU
+                self._accelerator_type = _AcceleratorType.IPU
             elif self.has_gpu:
-                self._accelerator_type = DeviceType.GPU
+                self._accelerator_type = _AcceleratorType.GPU
             else:
                 self._set_devices_to_cpu_num_processes()
-                self._accelerator_type = DeviceType.CPU
-        elif self.distributed_backend == DeviceType.TPU:
+                self._accelerator_type = _AcceleratorType.CPU
+        elif self.distributed_backend == _AcceleratorType.TPU:
             if not self.has_tpu:
                 msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
                 raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.")
-            self._accelerator_type = DeviceType.TPU
-        elif self.distributed_backend == DeviceType.IPU:
+            self._accelerator_type = _AcceleratorType.TPU
+        elif self.distributed_backend == _AcceleratorType.IPU:
             if not self.has_ipu:
                 msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
                 raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.")
-            self._accelerator_type = DeviceType.IPU
-        elif self.distributed_backend == DeviceType.GPU:
+            self._accelerator_type = _AcceleratorType.IPU
+        elif self.distributed_backend == _AcceleratorType.GPU:
             if not self.has_gpu:
                 msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available"
                 raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.")
-            self._accelerator_type = DeviceType.GPU
-        elif self.distributed_backend == DeviceType.CPU:
+            self._accelerator_type = _AcceleratorType.GPU
+        elif self.distributed_backend == _AcceleratorType.CPU:
             self._set_devices_to_cpu_num_processes()
-            self._accelerator_type = DeviceType.CPU
+            self._accelerator_type = _AcceleratorType.CPU

         if self.distributed_backend in self.accelerator_types:
             self.distributed_backend = None
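Note that `self.distributed_backend` holds the raw string from `Trainer(accelerator=...)`, so each `elif` above compares a plain string against an enum member. This works only because the enum compares case-insensitively; a quick check, reusing the hypothetical `_AcceleratorType` sketch above:

distributed_backend = "tpu"  # what Trainer(accelerator="tpu") stores
assert distributed_backend == _AcceleratorType.TPU
assert _AcceleratorType.TPU != "gpu"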
@@ -250,29 +250,29 @@ def _warn_if_devices_flag_ignored(self) -> None:
         if self.devices is None:
             return
         devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set"
-        if self.distributed_backend in ("auto", DeviceType.TPU):
+        if self.distributed_backend in ("auto", _AcceleratorType.TPU):
             if self.tpu_cores is not None:
                 rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`")
-        elif self.distributed_backend in ("auto", DeviceType.IPU):
+        elif self.distributed_backend in ("auto", _AcceleratorType.IPU):
             if self.ipus is not None:
                 rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`")
-        elif self.distributed_backend in ("auto", DeviceType.GPU):
+        elif self.distributed_backend in ("auto", _AcceleratorType.GPU):
             if self.gpus is not None:
                 rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`")
-        elif self.distributed_backend in ("auto", DeviceType.CPU):
+        elif self.distributed_backend in ("auto", _AcceleratorType.CPU):
             if self.num_processes != 1:
                 rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`")

     def _set_devices_if_none(self) -> None:
         if self.devices is not None:
             return
-        if self._accelerator_type == DeviceType.TPU:
+        if self._accelerator_type == _AcceleratorType.TPU:
             self.devices = self.tpu_cores
-        elif self._accelerator_type == DeviceType.IPU:
+        elif self._accelerator_type == _AcceleratorType.IPU:
             self.devices = self.ipus
-        elif self._accelerator_type == DeviceType.GPU:
+        elif self._accelerator_type == _AcceleratorType.GPU:
             self.devices = self.gpus
-        elif self._accelerator_type == DeviceType.CPU:
+        elif self._accelerator_type == _AcceleratorType.CPU:
             self.devices = self.num_processes

     def _handle_accelerator_and_strategy(self) -> None:
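`_set_devices_if_none` above back-fills the unified `devices` flag from whichever legacy per-accelerator flag the user set. The chain amounts to a lookup from accelerator type to attribute name; a hypothetical, behavior-equivalent condensation (the real method keeps the explicit if/elif chain, and the mapping below reuses the sketch enum):

_DEVICES_SOURCE = {
    _AcceleratorType.TPU: "tpu_cores",
    _AcceleratorType.IPU: "ipus",
    _AcceleratorType.GPU: "gpus",
    _AcceleratorType.CPU: "num_processes",
}


def _set_devices_if_none(self) -> None:
    if self.devices is not None:
        return
    # Copy e.g. self.tpu_cores into self.devices when devices was not given.
    self.devices = getattr(self, _DEVICES_SOURCE[self._accelerator_type])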
@@ -386,7 +386,7 @@ def handle_given_plugins(self) -> None:

     @property
     def accelerator_types(self) -> List[str]:
-        return ["auto"] + list(DeviceType)
+        return ["auto"] + list(_AcceleratorType)

     @property
     def precision_plugin(self) -> PrecisionPlugin:
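With the sketch enum above, `accelerator_types` evaluates to the literal "auto" plus the four members, and the plain-string membership test used in `select_accelerator_type` still works:

accelerator_types = ["auto"] + list(_AcceleratorType)
assert "gpu" in accelerator_types   # matches _AcceleratorType.GPU case-insensitively
assert "auto" in accelerator_types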
@@ -424,7 +424,7 @@ def has_cpu(self) -> bool:

     @property
     def use_cpu(self) -> bool:
-        return self._accelerator_type == DeviceType.CPU
+        return self._accelerator_type == _AcceleratorType.CPU

     @property
     def has_gpu(self) -> bool:
@@ -433,23 +433,23 @@ def has_gpu(self) -> bool:
         gpus = self.parallel_device_ids
         if gpus is not None and len(gpus) > 0:
             return True
-        return self._map_devices_to_accelerator(DeviceType.GPU)
+        return self._map_devices_to_accelerator(_AcceleratorType.GPU)

     @property
     def use_gpu(self) -> bool:
-        return self._accelerator_type == DeviceType.GPU and self.has_gpu
+        return self._accelerator_type == _AcceleratorType.GPU and self.has_gpu

     @property
     def has_tpu(self) -> bool:
         # Here, we are not checking for TPU availability, but instead if User has passed
         # `tpu_cores` to Trainer for training.
         if self.tpu_cores is not None:
             return True
-        return self._map_devices_to_accelerator(DeviceType.TPU)
+        return self._map_devices_to_accelerator(_AcceleratorType.TPU)

     @property
     def use_tpu(self) -> bool:
-        return self._accelerator_type == DeviceType.TPU and self.has_tpu
+        return self._accelerator_type == _AcceleratorType.TPU and self.has_tpu

     @property
     def tpu_id(self) -> Optional[int]:
@@ -463,36 +463,36 @@ def has_ipu(self) -> bool:
         # `ipus` to Trainer for training.
         if self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin):
             return True
-        return self._map_devices_to_accelerator(DeviceType.IPU)
+        return self._map_devices_to_accelerator(_AcceleratorType.IPU)

     @property
     def use_ipu(self) -> bool:
-        return self._accelerator_type == DeviceType.IPU and self.has_ipu
+        return self._accelerator_type == _AcceleratorType.IPU and self.has_ipu

     def _set_devices_to_cpu_num_processes(self) -> None:
         if self.num_processes == 1:
-            self._map_devices_to_accelerator(DeviceType.CPU)
+            self._map_devices_to_accelerator(_AcceleratorType.CPU)

     def _map_devices_to_accelerator(self, accelerator: str) -> bool:
         if self.devices is None:
             return False
-        if accelerator == DeviceType.TPU and _TPU_AVAILABLE:
+        if accelerator == _AcceleratorType.TPU and _TPU_AVAILABLE:
             if self.devices == "auto":
                 self.devices = TPUAccelerator.auto_device_count()
             self.tpu_cores = device_parser.parse_tpu_cores(self.devices)
             return True
-        if accelerator == DeviceType.IPU and _IPU_AVAILABLE:
+        if accelerator == _AcceleratorType.IPU and _IPU_AVAILABLE:
             if self.devices == "auto":
                 self.devices = IPUAccelerator.auto_device_count()
             self.ipus = self.devices
             return True
-        if accelerator == DeviceType.GPU and torch.cuda.is_available():
+        if accelerator == _AcceleratorType.GPU and torch.cuda.is_available():
             if self.devices == "auto":
                 self.devices = GPUAccelerator.auto_device_count()
             self.gpus = self.devices
             self.parallel_device_ids = device_parser.parse_gpu_ids(self.devices)
             return True
-        if accelerator == DeviceType.CPU:
+        if accelerator == _AcceleratorType.CPU:
             if self.devices == "auto":
                 self.devices = CPUAccelerator.auto_device_count()
             if not isinstance(self.devices, int):
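Each `_map_devices_to_accelerator` branch first resolves `devices="auto"` to the accelerator's own device count, then mirrors the result into the legacy flag (`tpu_cores`, `ipus`, `gpus`, or `num_processes`). A standalone sketch of the GPU branch, where `auto_device_count` and `parse_gpu_ids` are stand-ins for the real `GPUAccelerator.auto_device_count` and `device_parser.parse_gpu_ids` helpers:

from typing import List, Optional, Union


def auto_device_count() -> int:
    return 2  # stand-in: pretend two GPUs are visible


def parse_gpu_ids(gpus: Union[int, str, List[int]]) -> Optional[List[int]]:
    # Stand-in covering only the int case used here.
    return list(range(gpus)) if isinstance(gpus, int) and gpus > 0 else None


devices: Union[int, str] = "auto"
if devices == "auto":
    devices = auto_device_count()
gpus = devices                                # legacy flag mirrored from `devices`
parallel_device_ids = parse_gpu_ids(devices)  # [0, 1]
assert parallel_device_ids == [0, 1]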
@@ -829,7 +829,7 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
         if isinstance(self.distributed_backend, Accelerator):
             return

-        is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU
+        is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == _AcceleratorType.CPU
         _use_cpu = is_cpu_accelerator_type or self.distributed_backend and "cpu" in self.distributed_backend

         if self.distributed_backend is None:
@@ -867,16 +867,16 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
                 self.num_processes = os.cpu_count()
         # special case with TPUs
         elif self.has_tpu and not _use_cpu:
-            self._device_type = DeviceType.TPU
+            self._device_type = _AcceleratorType.TPU
             if isinstance(self.tpu_cores, int):
                 self._distrib_type = _StrategyType.TPU_SPAWN
         elif self.has_ipu and not _use_cpu:
-            self._device_type = DeviceType.IPU
+            self._device_type = _AcceleratorType.IPU
         elif self.distributed_backend and self._distrib_type is None:
             self._distrib_type = _StrategyType(self.distributed_backend)

         if self.num_gpus > 0 and not _use_cpu:
-            self._device_type = DeviceType.GPU
+            self._device_type = _AcceleratorType.GPU

         _gpu_distrib_types = (_StrategyType.DP, _StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2)
         # DP and DDP2 cannot run without GPU
@@ -896,13 +896,13 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             self.check_interactive_compatibility()

         # for DDP overwrite nb processes by requested GPUs
-        if self._device_type == DeviceType.GPU and self._distrib_type in (
+        if self._device_type == _AcceleratorType.GPU and self._distrib_type in (
             _StrategyType.DDP,
             _StrategyType.DDP_SPAWN,
         ):
             self.num_processes = self.num_gpus

-        if self._device_type == DeviceType.GPU and self._distrib_type == _StrategyType.DDP2:
+        if self._device_type == _AcceleratorType.GPU and self._distrib_type == _StrategyType.DDP2:
             self.num_processes = self.num_nodes

         # Horovod is an extra case...
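For DDP-style strategies the connector overwrites `num_processes` from the requested devices: one process per GPU for DDP and DDP-spawn, one per node for DDP2. A toy check of that rule (hypothetical values, not library code):

num_gpus, num_nodes = 4, 2
distrib_type = "ddp"
num_processes = num_gpus if distrib_type in ("ddp", "ddp_spawn") else num_nodes
assert num_processes == 4  # DDP: one process per requested GPU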
@@ -965,27 +965,27 @@ def has_horovodrun() -> bool:
     def update_device_type_if_ipu_plugin(self) -> None:
         # This allows the poptorch.Options that are passed into the IPUPlugin to be the source of truth,
         # which gives users the flexibility to not have to pass `ipus` flag directly to Trainer
-        if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
-            self._device_type = DeviceType.IPU
+        if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != _AcceleratorType.IPU:
+            self._device_type = _AcceleratorType.IPU

     def update_device_type_if_training_type_plugin_passed(self) -> None:
         if isinstance(self.strategy, TrainingTypePlugin) or any(
             isinstance(plug, TrainingTypePlugin) for plug in self.plugins
         ):
             if self._accelerator_type is not None:
                 if self.use_ipu:
-                    self._device_type = DeviceType.IPU
+                    self._device_type = _AcceleratorType.IPU
                 elif self.use_tpu:
-                    self._device_type = DeviceType.TPU
+                    self._device_type = _AcceleratorType.TPU
                 elif self.use_gpu:
-                    self._device_type = DeviceType.GPU
+                    self._device_type = _AcceleratorType.GPU
             else:
                 if self.has_ipu:
-                    self._device_type = DeviceType.IPU
+                    self._device_type = _AcceleratorType.IPU
                 elif self.has_tpu:
-                    self._device_type = DeviceType.TPU
+                    self._device_type = _AcceleratorType.TPU
                 elif self.has_gpu:
-                    self._device_type = DeviceType.GPU
+                    self._device_type = _AcceleratorType.GPU

     def _set_distrib_type_if_training_type_plugin_passed(self):
         # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
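When a `TrainingTypePlugin` instance is passed directly, the device type is re-derived with a fixed precedence: IPU first, then TPU, then GPU (via the `use_*` flags when an accelerator type was selected, otherwise the `has_*` flags). A hypothetical condensed form of that fall-through:

def derive_device_type(ipu: bool, tpu: bool, gpu: bool) -> str:
    # Hypothetical helper; the real code assigns self._device_type in place.
    for available, name in ((ipu, "IPU"), (tpu, "TPU"), (gpu, "GPU")):
        if available:
            return name
    return "CPU"  # otherwise the CPU default from __init__ is kept


assert derive_device_type(False, True, True) == "TPU"  # TPU outranks GPU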