 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
-from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, device_parser, rank_zero_only
+from pytorch_lightning.utilities import (
+    _APEX_AVAILABLE,
+    _NATIVE_AMP_AVAILABLE,
+    _TPU_AVAILABLE,
+    AMPType,
+    device_parser,
+    DeviceType,
+    DistributedType,
+    rank_zero_only,
+)
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -77,13 +86,9 @@ def __init__(
         amp_level,
         cluster_environment,
     ):
-
         # initialization
-        self.use_dp = False
-        self.use_ddp = False
-        self.use_ddp2 = False
-        self.use_horovod = False
-        self.use_single_gpu = False
+        self._device_type = DeviceType.CPU
+        self._distrib_type = None
 
         self.num_processes = num_processes
         self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
@@ -149,6 +154,10 @@ def __init__(
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
+    @property
+    def on_cpu(self):
+        return self._device_type == DeviceType.CPU
+
     @property
     def on_tpu(self):
         return self.tpu_cores is not None
@@ -165,6 +174,22 @@ def on_gpu(self):
         gpus = self.parallel_device_ids
         return gpus is not None and len(gpus) > 0 and torch.cuda.is_available()
 
+    @property
+    def use_dp(self):
+        return self._distrib_type == DistributedType.DP
+
+    @property
+    def use_ddp(self):
+        return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+
+    @property
+    def use_ddp2(self):
+        return self._distrib_type == DistributedType.DDP2
+
+    @property
+    def use_horovod(self):
+        return self._distrib_type == DistributedType.HOROVOD
+
     @property
     def num_gpus(self) -> int:
         gpus = self.parallel_device_ids
@@ -236,8 +261,8 @@ def select_training_type_plugin(self):
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
-            use_ddp_spawn = self.use_ddp and self.distributed_backend == "ddp_spawn"
-            use_ddp_cpu_spawn = self.use_ddp and self.distributed_backend == "ddp_cpu"
+            use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
+            use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
             use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
             # use_ddp_sharded = self.distributed_backend == "ddp_sharded"
@@ -273,11 +298,10 @@ def select_training_type_plugin(self):
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
         elif self.use_horovod:
             plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
+        elif self.on_tpu:
+            plugin = SingleTPUPlugin(self.tpu_id)
         else:
-            if self.on_tpu:
-                plugin = SingleTPUPlugin(self.tpu_id)
-            else:
-                plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
         return plugin
 
     def select_accelerator(self):
@@ -287,7 +311,7 @@ def select_accelerator(self):
 
         if self.on_gpu:
             acc_cls = GPUAccelerator
-        elif self.on_gpu:
+        elif self.on_tpu:
             acc_cls = TPUAccelerator
         else:
             acc_cls = CPUAccelerator
@@ -313,96 +337,84 @@ def select_cluster_environment(self):
         return env
 
     def set_distributed_mode(self):
-        # No distributed backend
+
         if self.distributed_backend is None:
-            # horovod multi GPU
             if self.has_horovodrun():
                 self._set_horovod_backend()
-
-            # DDP CPU
-            elif self.num_gpus == 0:
-                if self.num_nodes > 1 or self.num_processes > 1:
-                    self.use_ddp = True
-
-            # Single GPU
-            elif self.num_gpus == 1:
-                self.use_single_gpu = True
-
-            # Default: DDP-Spawn
+            elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1):
+                self._distrib_type = DistributedType.DDP
             elif self.num_gpus > 1:
                 rank_zero_warn(
-                    "You requested multiple GPUs but did not specify a backend, e.g."
-                    ' (distributed_backend="dp"|"ddp"|"ddp2").'
-                    ' Setting distributed_backend="ddp_spawn" for you.'
+                    'You requested multiple GPUs but did not specify a backend, e.g.'
+                    ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
                 )
                 self.distributed_backend = "ddp_spawn"
 
-        # DP
-        if self.distributed_backend == "dp":
-            # do nothing if num_gpus == 0
-            if self.num_gpus == 1:
-                self.use_single_gpu = True
-                self.use_dp = True
-            elif self.num_gpus > 1:
-                self.use_dp = True
-
-        # DDP, DDP-Spawn
-        elif self.distributed_backend in ("ddp", "ddp_spawn"):
-            if self.num_gpus == 0:
-                # DDP CPU
-                if self.num_nodes > 1 or self.num_processes > 1:
-                    self.use_ddp = True
-
-            # DDP Single GPU
-            elif self.num_gpus == 1:
-                self.use_single_gpu = True
-                self.use_ddp = True
-
-            # DDP Multi GPU
-            elif self.num_gpus > 1:
-                self.use_ddp = True
-                self.num_processes = self.num_gpus
-
-        # DDP2
-        elif self.distributed_backend == "ddp2":
-            # do nothing if num_gpus == 0
-            if self.num_gpus >= 1:
-                self.use_ddp2 = True
-
-        # DDP CPU
-        elif self.distributed_backend == "ddp_cpu":
+        # special case with DDP on CPUs
+        if self.distributed_backend == "ddp_cpu":
+            self._distrib_type = DistributedType.DDP
+            self.data_parallel_device_ids = None
             if self.num_gpus > 0:
                 rank_zero_warn(
-                    "You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs."
+                    'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
                 )
-            self.parallel_device_ids = None
-            self.use_ddp = True
-
-        # Sharded DDP
-        elif self.distributed_backend in ("ddp_sharded", "ddp_sharded_spawn"):
-            self.use_ddp = True
-
-        # HOROVOD
-        elif self.distributed_backend == "horovod":
+            if self.num_processes is None:
+                # define the max CPU available
+                self.num_processes = os.cpu_count()
+        # special case with TPUs
+        elif self.distributed_backend == 'tpu':
+            self._device_type = DeviceType.TPU
+        # set all other requested distrib. types, if not already set
+        elif self.distributed_backend and self._distrib_type is None:
+            self._distrib_type = DistributedType(self.distributed_backend)
+
+        # unless CPU is explicitly requested, use any available GPUs
+        _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend
+        if (self.num_gpus > 0 and not _on_cpu):
+            self._device_type = DeviceType.GPU
+
+        _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        # DP and DDP2 cannot run without GPU
+        if (self.num_gpus == 0 and self._distrib_type in _distrib_types):
+            rank_zero_warn(
+                'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
+            )
+            # todo: in some cases this compares None and int
+            if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)):
+                self._distrib_type = DistributedType.DDP
+            else:
+                rank_zero_warn('You are running on a single node with no parallelization, so distributed training has no effect.')
+                self._distrib_type = None
+
+        # for DDP, overwrite the number of processes with the number of requested GPUs
+        if (
+            self._device_type == DeviceType.GPU
+            and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+        ):
+            self.num_processes = self.num_gpus
+
+        # Horovod is a special case
+        if self.distributed_backend == "horovod":
             self._set_horovod_backend()
 
         # throw error to force user ddp or ddp2 choice
-        if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
+        _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        if (self.num_nodes > 1 and self._distrib_type not in _ddp):
             raise MisconfigurationException(
-                "DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. "
-                "To silence this warning set distributed_backend=ddp or distributed_backend=ddp2"
+                'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
+                'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
             )
 
-        rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}")
+        rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
         num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-        rank_zero_info(f"TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores")
+        rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
 
-        if torch.cuda.is_available() and not self.on_gpu:
+        if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
             rank_zero_warn("GPU available but not used. Set the --gpus flag when calling the script.")
 
     def _set_horovod_backend(self):
         self.check_horovod()
-        self.use_horovod = True
+        self._distrib_type = DistributedType.HOROVOD
 
         # Initialize Horovod to get rank / size info
         hvd.init()
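
For context on the pattern this commit introduces: the boolean use_* flags are replaced by a single _distrib_type field, and the old flags become read-only properties derived from it. Below is a minimal, self-contained sketch of that idea. The DistributedType member values ('dp', 'ddp', ...) and the ConnectorSketch class are illustrative assumptions, not taken from this commit.

from enum import Enum


class DistributedType(str, Enum):
    # assumed values: they mirror the accelerator strings passed to the Trainer
    DP = 'dp'
    DDP = 'ddp'
    DDP_SPAWN = 'ddp_spawn'
    DDP2 = 'ddp2'
    HOROVOD = 'horovod'


class ConnectorSketch:
    """Stores one _distrib_type and exposes the old boolean flags as properties."""

    def __init__(self, distributed_backend=None):
        # DistributedType('ddp_spawn') looks the member up by its string value
        self._distrib_type = DistributedType(distributed_backend) if distributed_backend else None

    @property
    def use_ddp(self):
        return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)

    @property
    def use_dp(self):
        return self._distrib_type == DistributedType.DP


assert ConnectorSketch('ddp_spawn').use_ddp
assert ConnectorSketch('dp').use_dp and not ConnectorSketch('dp').use_ddp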