 
 import torch
 
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, DeviceType, DistributedType
 from pytorch_lightning import _logger as log
 from pytorch_lightning import accelerators
 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -81,10 +81,7 @@ def on_trainer_init(
         # sync-bn backend
         self.trainer.sync_batchnorm = sync_batchnorm
 
-        self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
-        self.trainer.on_tpu = self.trainer.tpu_cores is not None
-
-        self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None
+        self._parse_tpu_device_details(tpu_cores)
 
         if num_processes != 1 and distributed_backend != "ddp_cpu":
             rank_zero_warn("num_processes is only used for `accelerator='ddp_cpu'`. Ignoring it.")
@@ -100,23 +97,10 @@ def on_trainer_init(
 
         self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus)
         self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids)
-        self.trainer.root_device = torch.device("cpu")
-
-        self.trainer.on_gpu = True if (self.trainer.data_parallel_device_ids and torch.cuda.is_available()) else False
-
-        # tpu state flags
-        self.trainer.use_tpu = False
-        self.trainer.tpu_local_core_rank = None
-        self.trainer.tpu_global_core_rank = None
 
         # distributed backend choice
         self.set_distributed_mode()
 
-        # override dist backend when using tpus
-        if self.trainer.on_tpu:
-            self.trainer.distributed_backend = "tpu"
-            self.trainer.use_tpu = True
-
         # init flags for SLURM+DDP to work
         self.trainer.world_size = 1
         self.trainer.interactive_ddp_procs = []
@@ -135,10 +119,29 @@ def on_trainer_init(
 
         self.trainer.replace_sampler_ddp = replace_sampler_ddp
 
+    def _parse_tpu_device_details(self, tpu_cores):
+        self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
+        if self.trainer.tpu_cores is not None:
+            if _TPU_AVAILABLE:
+                self.trainer._device_type = DeviceType.TPU
+                self.trainer.distributed_backend = "tpu"
+            else:
+                raise MisconfigurationException(
+                    f"You have requested {self.trainer.tpu_cores} TPU cores but none is available."
+                )
+
+        self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None
+
+        # tpu state flags
+        self.trainer.tpu_local_core_rank = None
+        self.trainer.tpu_global_core_rank = None
+
     def _map_deprecated_dist_backend(self, accelerator, distributed_backend):
         if distributed_backend is not None:
-            rank_zero_warn(DeprecationWarning('distributed_backend has been renamed to accelerator. '
-                                              'Deprecated in 1.0.0, will be removed in 1.2.0'))
+            rank_zero_warn(
+                '`distributed_backend` has been renamed to accelerator. Deprecated in 1.0.0, will be removed in 1.2.0',
+                DeprecationWarning
+            )
 
         # temporary mapping until we remove all the distributed_backend references
         if accelerator is not None:
@@ -276,71 +279,75 @@ def select_accelerator(self):
             accelerator_backend = accelerators.CPUAccelerator(self.trainer, cluster_env)
         else:
             raise MisconfigurationException(
-                f'Trainer(accelerator={self.trainer.distributed_backend} is not a supported backend'
+                f'`Trainer(accelerator={self.trainer.distributed_backend}, num_nodes={self.trainer.num_nodes},'
+                f' num_processes={self.trainer.num_processes}, ...)` is not a supported backend for'
+                f' num_gpus={self.trainer.num_gpus}'
             )
 
         return accelerator_backend
 
     def set_distributed_mode(self):
-        self.trainer.use_dp = False
-        self.trainer.use_ddp = False
-        self.trainer.use_ddp2 = False
-        self.trainer.use_horovod = False
-        self.trainer.use_single_gpu = False
 
         if self.trainer.distributed_backend is None:
             if self.has_horovodrun():
                 self._set_horovod_backend()
-            elif self.trainer.num_gpus == 0:
-                if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
-                    self.trainer.use_ddp = True  # ddp_cpu
-            elif self.trainer.num_gpus == 1:
-                self.trainer.use_single_gpu = True
+            elif self.trainer.num_gpus == 0 and (self.trainer.num_nodes > 1 or self.trainer.num_processes > 1):
+                self.trainer._distrib_type = DistributedType.DDP
             elif self.trainer.num_gpus > 1:
                 rank_zero_warn(
                     'You requested multiple GPUs but did not specify a backend, e.g.'
-                    ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`.'
-                    ' Setting `accelerator="ddp_spawn"` for you.'
+                    ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
                 )
                 self.trainer.distributed_backend = "ddp_spawn"
 
-        if self.trainer.distributed_backend == "dp":
-            # do nothing if num_gpus == 0
-            if self.trainer.num_gpus == 1:
-                self.trainer.use_single_gpu = True
-                self.trainer.use_dp = True
-            elif self.trainer.num_gpus > 1:
-                self.trainer.use_dp = True
-
-        elif self.trainer.distributed_backend in ("ddp", "ddp_spawn"):
-            if self.trainer.num_gpus == 0:
-                if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
-                    self.trainer.use_ddp = True  # ddp_cpu
-            elif self.trainer.num_gpus == 1:
-                self.trainer.use_single_gpu = True
-                self.trainer.use_ddp = True
-            elif self.trainer.num_gpus > 1:
-                self.trainer.use_ddp = True
-                self.trainer.num_processes = self.trainer.num_gpus
-
-        elif self.trainer.distributed_backend == "ddp2":
-            # do nothing if num_gpus == 0
-            if self.trainer.num_gpus >= 1:
-                self.trainer.use_ddp2 = True
-        elif self.trainer.distributed_backend == "ddp_cpu":
+        # special case with DDP on CPUs
+        if self.trainer.distributed_backend == "ddp_cpu":
+            self.trainer._distrib_type = DistributedType.DDP
+            self.trainer.data_parallel_device_ids = None
             if self.trainer.num_gpus > 0:
                 rank_zero_warn(
                     'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
                 )
-            self.trainer.use_ddp = True
-            self.trainer.data_parallel_device_ids = None
-            self.trainer.on_gpu = False
-            self.trainer.on_cpu = True
-        elif self.trainer.distributed_backend == "horovod":
+            if self.trainer.num_processes is None:
+                # define the max CPU available
+                self.trainer.num_processes = os.cpu_count()
+        # special case with TPUs
+        elif self.trainer.distributed_backend == 'tpu':
+            self.trainer._device_type = DeviceType.TPU
+        # set all other requested distrib. types and if it was not set in the meantime
+        elif self.trainer.distributed_backend and self.trainer._distrib_type is None:
+            self.trainer._distrib_type = DistributedType(self.trainer.distributed_backend)
+
+        # unless CPU is explicitly requested, use any available GPUs
+        _on_cpu = self.trainer.distributed_backend and 'cpu' in self.trainer.distributed_backend
+        if (self.trainer.num_gpus > 0 and not _on_cpu):
+            self.trainer._device_type = DeviceType.GPU
+
+        _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        # DP and DDP2 cannot run without GPU
+        if (self.trainer.num_gpus == 0 and self.trainer._distrib_type in _distrib_types):
+            rank_zero_warn(
+                'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
+            )
+            # todo: in some cases this yields a comparison between None and int
+            if ((self.trainer.num_nodes and self.trainer.num_nodes > 1)
+                    or (self.trainer.num_processes and self.trainer.num_processes > 1)):
+                self.trainer._distrib_type = DistributedType.DDP
+            else:
+                rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
+                self.trainer._distrib_type = None
+
+        # for DDP, overwrite the number of processes with the requested number of GPUs
+        if (self.trainer._device_type == DeviceType.GPU
+                and self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)):
+            self.trainer.num_processes = self.trainer.num_gpus
+
+        # Horovod is an extra case...
+        if self.trainer.distributed_backend == "horovod":
             self._set_horovod_backend()
 
         # throw error to force user ddp or ddp2 choice
-        if self.trainer.num_nodes > 1 and not (self.trainer.use_ddp2 or self.trainer.use_ddp):
+        if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
             raise MisconfigurationException(
                 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
                 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
@@ -350,20 +357,20 @@ def set_distributed_mode(self):
         num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
         rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
 
-        if torch.cuda.is_available() and not self.trainer.on_gpu:
+        if torch.cuda.is_available() and self.trainer._device_type != DeviceType.GPU:
             rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')
 
     def _set_horovod_backend(self):
-        self.check_horovod()
-        self.trainer.use_horovod = True
+        self._check_horovod()
+        self.trainer._distrib_type = DistributedType.HOROVOD
 
         # Initialize Horovod to get rank / size info
         hvd.init()
         if self.trainer.on_gpu:
             # Horovod assigns one local GPU per process
             self.trainer.root_gpu = hvd.local_rank()
 
-    def check_horovod(self):
+    def _check_horovod(self):
         """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
         if not _HOROVOD_AVAILABLE:
             raise MisconfigurationException(
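
For illustration only, a minimal sketch of how the refactored connector resolves the new enum flags. It assumes a hypothetical host with two CUDA GPUs and no TPU; `_device_type` and `_distrib_type` are internal `Trainer` attributes introduced by this change and are inspected here purely to show the resolution logic.

# hypothetical host: 2 CUDA GPUs, no TPU installed
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# multiple GPUs without an explicit backend: set_distributed_mode() warns and
# falls back to "ddp_spawn"; the old use_dp/use_ddp/on_gpu booleans are replaced
# by the two enum attributes
trainer = Trainer(gpus=2)
assert trainer._device_type == DeviceType.GPU
assert trainer._distrib_type == DistributedType.DDP_SPAWN
assert trainer.num_processes == 2  # for DDP, num_processes is overwritten with num_gpus

# requesting TPU cores on a machine without TPUs now fails fast in
# _parse_tpu_device_details() instead of silently setting the old on_tpu flag
try:
    Trainer(tpu_cores=8)
except MisconfigurationException as err:
    print(err)  # "You have requested 8 TPU cores but none is available."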