@@ -185,14 +185,21 @@ def select_accelerator(self):
         # ----------------------------------
         # choose an accelerator for the user
         # ----------------------------------
-        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
+        use_slurm_ddp = (
+            self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+            and self.trainer.is_slurm_managing_tasks
+        )
 
         # torchelastic or general non_slurm ddp
         te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
-        use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed
+        use_torchelastic_ddp = (
+            self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed
+        )
 
-        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
-        use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"
+        use_ddp_cpu_spawn = (
+            self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+            and self.trainer._device_type == DeviceType.CPU
+        )
 
         use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic()
         use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks
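
Note (illustrative, not part of the commit): the flags above now use membership tests against the DistributedType / DeviceType enums instead of the old use_ddp boolean. A minimal self-contained sketch with stand-in enums (the real ones are Lightning internals) shows how such a check behaves:

from enum import Enum

class DistributedType(Enum):  # stand-in for the Lightning enum of the same name
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    DDP2 = "ddp2"

class DeviceType(Enum):  # stand-in for the Lightning enum of the same name
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"

def uses_ddp_cpu_spawn(distrib_type, device_type):
    # mirrors the `use_ddp_cpu_spawn` expression in the hunk above
    return (
        distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
        and device_type == DeviceType.CPU
    )

print(uses_ddp_cpu_spawn(DistributedType.DDP_SPAWN, DeviceType.CPU))  # True
print(uses_ddp_cpu_spawn(DistributedType.DDP2, DeviceType.CPU))       # False
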
@@ -204,8 +211,9 @@ def select_accelerator(self):
 
         cluster_env = self._select_environment()
 
+        # TODO: clean-up this branching as most just select class and uses the very same arguments
         # choose the appropriate accelerator backend
-        if self.trainer.use_ddp2:
+        if self.trainer._distrib_type == DistributedType.DDP2:
             accelerator_backend = accelerators.DDP2Accelerator(
                 self.trainer,
                 cluster_env,
@@ -240,7 +248,7 @@ def select_accelerator(self):
                 self.trainer.plugin_connector.ddp_plugin
             )
 
-        elif use_ddp_spawn:
+        elif self.trainer._distrib_type == DistributedType.DDP_SPAWN:
             accelerator_backend = accelerators.DDPSpawnAccelerator(
                 self.trainer,
                 nprocs=self.trainer.num_processes,
@@ -263,16 +271,16 @@ def select_accelerator(self):
                 ddp_plugin=self.trainer.plugin_connector.ddp_plugin
             )
 
-        elif self.trainer.use_dp:
+        elif self.trainer._distrib_type == DistributedType.DP:
             accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env)
 
-        elif self.trainer.use_horovod:
+        elif self.trainer._distrib_type == DistributedType.HOROVOD:
             accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env)
 
-        elif self.trainer.use_single_gpu:
+        elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1:
             accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env)
 
-        elif self.trainer.use_tpu:
+        elif self.trainer._device_type == DeviceType.TPU:
             accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env)
 
         elif self.trainer.distributed_backend is None:
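
Note (illustrative, not part of the commit): the TODO added above points out that most branches only differ in which accelerator class they pick while passing essentially the same arguments, so the chain could eventually collapse into a lookup. A rough sketch of that idea with placeholder classes, not the commit's code:

class DDP2Accelerator: ...          # placeholders standing in for the
class DDPSpawnAccelerator: ...      # pytorch_lightning.accelerators classes
class DataParallelAccelerator: ...
class HorovodAccelerator: ...

_ACCELERATOR_FOR_TYPE = {
    "ddp2": DDP2Accelerator,
    "ddp_spawn": DDPSpawnAccelerator,
    "dp": DataParallelAccelerator,
    "horovod": HorovodAccelerator,
}

def pick_accelerator(distrib_type):
    # look the class up instead of walking an if/elif chain;
    # the real classes take (trainer, cluster_env, ...) as shown in the diff
    return _ACCELERATOR_FOR_TYPE[distrib_type]()

print(type(pick_accelerator("horovod")).__name__)  # HorovodAccelerator
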
@@ -347,13 +355,16 @@ def set_distributed_mode(self):
             self._set_horovod_backend()
 
         # throw error to force user ddp or ddp2 choice
-        if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
+        _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp):
             raise MisconfigurationException(
                 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
                 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
             )
 
-        rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
+        rank_zero_info(
+            f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}'
+        )
         num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
         rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
 
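Note (illustrative, not part of the commit): the multi-node guard now also treats DDP_SPAWN as a valid DDP variant. A tiny stand-alone check mirroring the condition, with plain strings in place of the enum:

_DDP_TYPES = ("ddp", "ddp_spawn", "ddp2")  # stand-in strings for the enum members

def check_multi_node(num_nodes, distrib_type):
    # same shape as the guard above: only DDP variants may span multiple nodes
    if num_nodes > 1 and distrib_type not in _DDP_TYPES:
        raise ValueError("num_nodes > 1 requires accelerator='ddp' or 'ddp2'")

check_multi_node(2, "ddp_spawn")  # accepted after this change; the old tuple lacked DDP_SPAWN
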
@@ -366,7 +377,7 @@ def _set_horovod_backend(self):
 
         # Initialize Horovod to get rank / size info
         hvd.init()
-        if self.trainer.on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
             # Horovod assigns one local GPU per process
             self.trainer.root_gpu = hvd.local_rank()
 
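
Note (illustrative, not part of the commit): the Horovod branch relies on one local GPU per process, i.e. the local rank doubles as the GPU index. A common stand-alone Horovod pattern for the same idea (requires Horovod installed; this is not Lightning's code):

import horovod.torch as hvd
import torch

hvd.init()
if torch.cuda.is_available():
    # one GPU per process: pin this process to the GPU matching its local rank
    torch.cuda.set_device(hvd.local_rank())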