Commit 3019414

sync accelerator connector changes from dev1.2
1 parent: df0900c

File tree: 1 file changed (+93 additions, -81 deletions)


pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 93 additions & 81 deletions
@@ -39,7 +39,16 @@
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
-from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, device_parser, rank_zero_only
+from pytorch_lightning.utilities import (
+    _APEX_AVAILABLE,
+    _NATIVE_AMP_AVAILABLE,
+    _TPU_AVAILABLE,
+    AMPType,
+    device_parser,
+    DeviceType,
+    DistributedType,
+    rank_zero_only,
+)
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -77,13 +86,9 @@ def __init__(
         amp_level,
         cluster_environment,
     ):
-
         # initialization
-        self.use_dp = False
-        self.use_ddp = False
-        self.use_ddp2 = False
-        self.use_horovod = False
-        self.use_single_gpu = False
+        self._device_type = DeviceType.CPU
+        self._distrib_type = None

         self.num_processes = num_processes
         self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
@@ -149,6 +154,10 @@ def __init__(

         self.replace_sampler_ddp = replace_sampler_ddp

+    @property
+    def on_cpu(self):
+        return self._device_type == DeviceType.CPU
+
     @property
     def on_tpu(self):
         return self.tpu_cores is not None
@@ -165,6 +174,22 @@ def on_gpu(self):
         gpus = self.parallel_device_ids
         return gpus is not None and len(gpus) > 0 and torch.cuda.is_available()

+    @property
+    def use_dp(self):
+        return self._distrib_type == DistributedType.DP
+
+    @property
+    def use_ddp(self):
+        return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+
+    @property
+    def use_ddp2(self):
+        return self._distrib_type == DistributedType.DDP2
+
+    @property
+    def use_horovod(self):
+        return self._distrib_type == DistributedType.HOROVOD
+
     @property
     def num_gpus(self) -> int:
         gpus = self.parallel_device_ids
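These hunks replace the five independently tracked `use_*` booleans with a single `_device_type`/`_distrib_type` pair and re-expose the old flags as read-only properties derived from those fields. A minimal sketch of that pattern, using simplified stand-in enums rather than the real `pytorch_lightning.utilities` definitions:

    from enum import Enum

    class DeviceType(Enum):          # simplified stand-in, not the real class
        CPU = "cpu"
        GPU = "gpu"
        TPU = "tpu"

    class DistributedType(Enum):     # simplified stand-in, not the real class
        DP = "dp"
        DDP = "ddp"
        DDP_SPAWN = "ddp_spawn"
        DDP2 = "ddp2"
        HOROVOD = "horovod"

    class Connector:
        def __init__(self):
            # one source of truth instead of five independent booleans
            self._device_type = DeviceType.CPU
            self._distrib_type = None

        @property
        def use_ddp(self):
            # derived on read, so it can never drift out of sync with the enum
            return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)

    c = Connector()
    c._distrib_type = DistributedType.DDP_SPAWN
    assert c.use_ddp  # spawn still counts as DDP for callers that only check the flag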
@@ -236,8 +261,8 @@ def select_training_type_plugin(self):
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
-            use_ddp_spawn = self.use_ddp and self.distributed_backend == "ddp_spawn"
-            use_ddp_cpu_spawn = self.use_ddp and self.distributed_backend == "ddp_cpu"
+            use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
+            use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
             use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
             # use_ddp_sharded = self.distributed_backend == "ddp_sharded"
@@ -273,11 +298,10 @@ def select_training_type_plugin(self):
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
         elif self.use_horovod:
             plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
+        elif self.on_tpu:
+            plugin = SingleTPUPlugin(self.tpu_id)
         else:
-            if self.on_tpu:
-                plugin = SingleTPUPlugin(self.tpu_id)
-            else:
-                plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
         return plugin

     def select_accelerator(self):
@@ -287,7 +311,7 @@ def select_accelerator(self):

         if self.on_gpu:
             acc_cls = GPUAccelerator
-        elif self.on_gpu:
+        elif self.on_tpu:
             acc_cls = TPUAccelerator
         else:
             acc_cls = CPUAccelerator
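The one-word change above fixes a dead branch: the old `elif self.on_gpu:` repeated the first condition, so the TPU branch could never be taken. A tiny, self-contained illustration with stand-in strings instead of the real accelerator classes:

    def pick_accelerator(on_gpu, on_tpu, fixed=True):
        # stand-in for select_accelerator(); returns strings instead of classes
        if on_gpu:
            return "GPU"
        elif (on_tpu if fixed else on_gpu):  # old code repeated the GPU check here
            return "TPU"
        return "CPU"

    assert pick_accelerator(on_gpu=False, on_tpu=True, fixed=False) == "CPU"  # bug: TPU unreachable
    assert pick_accelerator(on_gpu=False, on_tpu=True, fixed=True) == "TPU"   # after the fix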
@@ -313,96 +337,84 @@ def select_cluster_environment(self):
         return env

     def set_distributed_mode(self):
-        # No distributed backend
+
         if self.distributed_backend is None:
-            # horovod multi GPU
             if self.has_horovodrun():
                 self._set_horovod_backend()
-
-            # DDP CPU
-            elif self.num_gpus == 0:
-                if self.num_nodes > 1 or self.num_processes > 1:
-                    self.use_ddp = True
-
-            # Single GPU
-            elif self.num_gpus == 1:
-                self.use_single_gpu = True
-
-            # Default: DDP-Spawn
+            elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1):
+                self._distrib_type = DistributedType.DDP
             elif self.num_gpus > 1:
                 rank_zero_warn(
-                    "You requested multiple GPUs but did not specify a backend, e.g."
-                    ' (distributed_backend="dp"|"ddp"|"ddp2").'
-                    ' Setting distributed_backend="ddp_spawn" for you.'
+                    'You requested multiple GPUs but did not specify a backend, e.g.'
+                    ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
                 )
                 self.distributed_backend = "ddp_spawn"

-        # DP
-        if self.distributed_backend == "dp":
-            # do nothing if num_gpus == 0
-            if self.num_gpus == 1:
-                self.use_single_gpu = True
-                self.use_dp = True
-            elif self.num_gpus > 1:
-                self.use_dp = True
-
-        # DDP, DDP-Spawn
-        elif self.distributed_backend in ("ddp", "ddp_spawn"):
-            if self.num_gpus == 0:
-                # DDP CPU
-                if self.num_nodes > 1 or self.num_processes > 1:
-                    self.use_ddp = True
-
-            # DDP Single GPU
-            elif self.num_gpus == 1:
-                self.use_single_gpu = True
-                self.use_ddp = True
-
-            # DDP Multi GPU
-            elif self.num_gpus > 1:
-                self.use_ddp = True
-                self.num_processes = self.num_gpus
-
-        # DDP2
-        elif self.distributed_backend == "ddp2":
-            # do nothing if num_gpus == 0
-            if self.num_gpus >= 1:
-                self.use_ddp2 = True
-
-        # DDP CPU
-        elif self.distributed_backend == "ddp_cpu":
+        # special case with DDP on CPUs
+        if self.distributed_backend == "ddp_cpu":
+            self._distrib_type = DistributedType.DDP
+            self.data_parallel_device_ids = None
             if self.num_gpus > 0:
                 rank_zero_warn(
-                    "You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs."
+                    'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
                 )
-            self.parallel_device_ids = None
-            self.use_ddp = True
-
-        # Sharded DDP
-        elif self.distributed_backend in ("ddp_sharded", "ddp_sharded_spawn"):
-            self.use_ddp = True
-
-        # HOROVOD
-        elif self.distributed_backend == "horovod":
+            if self.num_processes is None:
+                # define the max CPU available
+                self.num_processes = os.cpu_count()
+        # special case with TPUs
+        elif self.distributed_backend == 'tpu':
+            self._device_type = DeviceType.TPU
+        # set all other requested distrib. types, if it was not already set
+        elif self.distributed_backend and self._distrib_type is None:
+            self._distrib_type = DistributedType(self.distributed_backend)
+
+        # unless you request explicitly for CPU and some GPU are available use them
+        _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend
+        if (self.num_gpus > 0 and not _on_cpu):
+            self._device_type = DeviceType.GPU
+
+        _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        # DP and DDP2 cannot run without GPU
+        if (self.num_gpus == 0 and self._distrib_type in _distrib_types):
+            rank_zero_warn(
+                'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
+            )
+            # todo: in some cases this compares None and int
+            if ((self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1)):
+                self._distrib_type = DistributedType.DDP
+            else:
+                rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
+                self._distrib_type = None
+
+        # for DDP overwrite nb processes by requested GPUs
+        if (
+            self._device_type == DeviceType.GPU
+            and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+        ):
+            self.num_processes = self.num_gpus
+
+        # Horovod is an extra case...
+        if self.distributed_backend == "horovod":
             self._set_horovod_backend()

         # throw error to force user ddp or ddp2 choice
-        if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
+        _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        if (self.num_nodes > 1 and self._distrib_type not in _ddp):
             raise MisconfigurationException(
-                "DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. "
-                "To silence this warning set distributed_backend=ddp or distributed_backend=ddp2"
+                'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
+                'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
             )

-        rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}")
+        rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
         num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-        rank_zero_info(f"TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores")
+        rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

-        if torch.cuda.is_available() and not self.on_gpu:
+        if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
             rank_zero_warn("GPU available but not used. Set the --gpus flag when calling the script.")

     def _set_horovod_backend(self):
         self.check_horovod()
-        self.use_horovod = True
+        self._distrib_type = DistributedType.HOROVOD

         # Initialize Horovod to get rank / size info
         hvd.init()
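The rewritten `set_distributed_mode` leans on `DistributedType(self.distributed_backend)` to map the remaining backend strings straight to enum members, which relies on the enum members' values matching the user-facing backend strings so that a lookup by value returns the member. A minimal sketch with a simplified stand-in, not the real `pytorch_lightning` definition:

    from enum import Enum

    class DistributedType(str, Enum):  # simplified stand-in, not the real class
        DP = "dp"
        DDP = "ddp"
        DDP_SPAWN = "ddp_spawn"
        DDP2 = "ddp2"
        HOROVOD = "horovod"

    # lookup-by-value turns the user-facing backend string into an enum member
    assert DistributedType("ddp_spawn") is DistributedType.DDP_SPAWN

    try:
        DistributedType("not_a_backend")
    except ValueError:
        pass  # unknown backend strings raise ValueError instead of passing silently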
