Commit b72ed71

Refactor: clean trainer device & distrib setters (#5297)
* naive replace * simplify * clean * . * fix * . * fix * fix
1 parent 9575835 commit b72ed71
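
In short, the refactor replaces the trainer's scattered boolean device/distributed flags with two enum-valued attributes, trainer._device_type and trainer._distrib_type. A minimal sketch of the idea, using simplified stand-ins for the DeviceType and DistributedType enums that the diff below imports from pytorch_lightning.utilities (member names are taken from the usages in the diff; the real enums may define more members):

from enum import Enum


class DeviceType(Enum):
    # simplified stand-in for pytorch_lightning.utilities.DeviceType
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class DistributedType(Enum):
    # simplified stand-in for pytorch_lightning.utilities.DistributedType
    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    DDP2 = "ddp2"
    HOROVOD = "horovod"


# Before: the mode was spread over boolean flags such as
#   trainer.use_dp / use_ddp / use_ddp2 / use_horovod / use_single_gpu / on_gpu / on_tpu
# After: the same information lives in two attributes
#   trainer._distrib_type  -> a DistributedType member (or None)
#   trainer._device_type   -> a DeviceType member

The diffs below show the flag assignments being deleted and the enum assignments taking their place.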

File tree: 15 files changed (+169 -162 lines)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))
+
 
 
 ## [1.1.0] - 2020-12-09
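
The entry above is about how the trainer decides whether a `ddp_cpu` run is actually distributed: with this change, DDP is kept only when more than one process (or node) is requested. A small illustration of the intended behaviour, assuming the 1.1-era Trainer arguments (the affected flags are internal, so only the constructor calls are shown):

from pytorch_lightning import Trainer

# ddp_cpu with a single process: after the fix the trainer warns that there is
# no parallelization and does not switch into a distributed mode.
trainer_single = Trainer(accelerator="ddp_cpu", num_processes=1)

# ddp_cpu with several processes: DDP over CPU processes is kept.
trainer_multi = Trainer(accelerator="ddp_cpu", num_processes=2)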

benchmarks/test_sharded_parity.py

Lines changed: 18 additions & 36 deletions
@@ -1,7 +1,7 @@
 import os
 import platform
 import time
-from typing import Union
+from typing import Type, Union
 
 import pytest
 import torch
@@ -14,64 +14,48 @@
 from tests.base.boring_model import BoringModel, RandomDataset
 
 
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-def test_ddp_sharded_plugin_correctness_one_device():
-    plugin_parity_test(
-        accelerator='ddp_cpu',
-        max_percent_speed_diff=0.15,  # slower speed due to one CPU doing additional sequential memory saving calls
-        plugin=DDPShardedPlugin(),
-        model_cls=SeedTrainLoaderModel
-    )
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_gpu():
     plugin_parity_test(
         gpus=1,
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
-        model_cls=SeedTrainLoaderModel
+        model_cls=SeedTrainLoaderModel,
     )
 
 
 @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
     plugin_parity_test(
         gpus=1,
         precision=16,
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
-        model_cls=SeedTrainLoaderModel
+        model_cls=SeedTrainLoaderModel,
     )
 
 
 @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu():
     plugin_parity_test(
         gpus=2,
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25
+        max_percent_speed_diff=0.25,
     )
 
 
 @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
@@ -81,13 +65,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25
+        max_percent_speed_diff=0.25,
     )
 
 
 @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
@@ -97,7 +80,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
         accelerator='ddp_spawn',
         plugin='ddp_sharded',
         model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25
+        max_percent_speed_diff=0.25,
     )
 
 
@@ -133,8 +116,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
@@ -145,14 +127,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
         gpus=2,
         accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderMultipleOptimizersModel,
-        max_percent_speed_diff=0.25  # Increase speed diff since only 2 GPUs sharding 2 optimizers
+        max_percent_speed_diff=0.25,  # Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
@@ -163,7 +144,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
         gpus=2,
         accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderManualModel,
-        max_percent_speed_diff=0.25  # Increase speed diff since only 2 GPUs sharding 2 optimizers
+        max_percent_speed_diff=0.25,  # Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
 
 
@@ -259,13 +240,14 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda):
 
 
 def plugin_parity_test(
-        model_cls: SeedTrainLoaderModel,
+        model_cls: Type[SeedTrainLoaderModel],
         plugin: Union[str, DDPPlugin],
        seed: int = 42,
         accelerator: str = 'ddp_spawn',
         gpus: int = 0,
         precision: int = 32,
-        max_percent_speed_diff: float = 0.1):
+        max_percent_speed_diff: float = 0.1,
+):
     """
     Ensures that the trained model is identical to the standard DDP implementation.
     Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate.
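
Beyond the decorator reflow and trailing commas, the functional change in this file is the annotation of model_cls: the tests pass the class itself (SeedTrainLoaderModel, not an instance), so Type[SeedTrainLoaderModel] is the accurate hint. A stripped-down sketch of that calling convention (the real helper also takes the plugin/accelerator/gpus/precision arguments shown above):

from typing import Type


class SeedTrainLoaderModel:
    """Placeholder standing in for the benchmark's seeded model class."""


def plugin_parity_test(model_cls: Type[SeedTrainLoaderModel], max_percent_speed_diff: float = 0.1) -> None:
    # the helper receives the class and builds the instance itself,
    # hence Type[...] rather than an instance annotation
    model = model_cls()
    print(f"parity test with {type(model).__name__}, tolerance {max_percent_speed_diff}")


# callers pass the class, exactly as the tests above do
plugin_parity_test(model_cls=SeedTrainLoaderModel, max_percent_speed_diff=0.25)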

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 74 additions & 67 deletions
@@ -15,7 +15,7 @@
 
 import torch
 
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, DeviceType, DistributedType
 from pytorch_lightning import _logger as log
 from pytorch_lightning import accelerators
 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -81,10 +81,7 @@ def on_trainer_init(
         # sync-bn backend
         self.trainer.sync_batchnorm = sync_batchnorm
 
-        self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
-        self.trainer.on_tpu = self.trainer.tpu_cores is not None
-
-        self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None
+        self._parse_tpu_device_details(tpu_cores)
 
         if num_processes != 1 and distributed_backend != "ddp_cpu":
             rank_zero_warn("num_processes is only used for `accelerator='ddp_cpu'`. Ignoring it.")
@@ -100,23 +97,10 @@ def on_trainer_init(
 
         self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus)
         self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids)
-        self.trainer.root_device = torch.device("cpu")
-
-        self.trainer.on_gpu = True if (self.trainer.data_parallel_device_ids and torch.cuda.is_available()) else False
-
-        # tpu state flags
-        self.trainer.use_tpu = False
-        self.trainer.tpu_local_core_rank = None
-        self.trainer.tpu_global_core_rank = None
 
         # distributed backend choice
         self.set_distributed_mode()
 
-        # override dist backend when using tpus
-        if self.trainer.on_tpu:
-            self.trainer.distributed_backend = "tpu"
-            self.trainer.use_tpu = True
-
         # init flags for SLURM+DDP to work
         self.trainer.world_size = 1
         self.trainer.interactive_ddp_procs = []
@@ -135,10 +119,29 @@ def on_trainer_init(
 
         self.trainer.replace_sampler_ddp = replace_sampler_ddp
 
+    def _parse_tpu_device_details(self, tpu_cores):
+        self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
+        if self.trainer.tpu_cores is not None:
+            if _TPU_AVAILABLE:
+                self.trainer._device_type = DeviceType.TPU
+                self.trainer.distributed_backend = "tpu"
+            else:
+                raise MisconfigurationException(
+                    f"You have requested {self.trainer.tpu_cores} TPU cores but none is available."
+                )
+
+        self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None
+
+        # tpu state flags
+        self.trainer.tpu_local_core_rank = None
+        self.trainer.tpu_global_core_rank = None
+
     def _map_deprecated_dist_backend(self, accelerator, distributed_backend):
         if distributed_backend is not None:
-            rank_zero_warn(DeprecationWarning('distributed_backend has been renamed to accelerator. '
-                                              'Deprecated in 1.0.0, will be removed in 1.2.0'))
+            rank_zero_warn(
+                '`distributed_backend` has been renamed to accelerator. Deprecated in 1.0.0, will be removed in 1.2.0',
+                DeprecationWarning
+            )
 
         # temporary mapping until we remove all the distributed_backend references
         if accelerator is not None:
@@ -276,71 +279,75 @@ def select_accelerator(self):
             accelerator_backend = accelerators.CPUAccelerator(self.trainer, cluster_env)
         else:
             raise MisconfigurationException(
-                f'Trainer(accelerator={self.trainer.distributed_backend} is not a supported backend'
+                f'`Trainer(accelerator={self.trainer.distributed_backend}, num_nodes={self.trainer.num_nodes},'
+                f' num_processes={self.trainer.num_processes}, ...)` is not a supported backend for'
+                f' num_gpus={self.trainer.num_gpus}'
             )
 
         return accelerator_backend
 
     def set_distributed_mode(self):
-        self.trainer.use_dp = False
-        self.trainer.use_ddp = False
-        self.trainer.use_ddp2 = False
-        self.trainer.use_horovod = False
-        self.trainer.use_single_gpu = False
 
         if self.trainer.distributed_backend is None:
             if self.has_horovodrun():
                 self._set_horovod_backend()
-            elif self.trainer.num_gpus == 0:
-                if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
-                    self.trainer.use_ddp = True  # ddp_cpu
-            elif self.trainer.num_gpus == 1:
-                self.trainer.use_single_gpu = True
+            elif self.trainer.num_gpus == 0 and (self.trainer.num_nodes > 1 or self.trainer.num_processes > 1):
+                self.trainer._distrib_type = DistributedType.DDP
             elif self.trainer.num_gpus > 1:
                 rank_zero_warn(
                     'You requested multiple GPUs but did not specify a backend, e.g.'
-                    ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`.'
-                    ' Setting `accelerator="ddp_spawn"` for you.'
+                    ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
                 )
                 self.trainer.distributed_backend = "ddp_spawn"
 
-        if self.trainer.distributed_backend == "dp":
-            # do nothing if num_gpus == 0
-            if self.trainer.num_gpus == 1:
-                self.trainer.use_single_gpu = True
-                self.trainer.use_dp = True
-            elif self.trainer.num_gpus > 1:
-                self.trainer.use_dp = True
-
-        elif self.trainer.distributed_backend in ("ddp", "ddp_spawn"):
-            if self.trainer.num_gpus == 0:
-                if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
-                    self.trainer.use_ddp = True  # ddp_cpu
-            elif self.trainer.num_gpus == 1:
-                self.trainer.use_single_gpu = True
-                self.trainer.use_ddp = True
-            elif self.trainer.num_gpus > 1:
-                self.trainer.use_ddp = True
-                self.trainer.num_processes = self.trainer.num_gpus
-
-        elif self.trainer.distributed_backend == "ddp2":
-            # do nothing if num_gpus == 0
-            if self.trainer.num_gpus >= 1:
-                self.trainer.use_ddp2 = True
-        elif self.trainer.distributed_backend == "ddp_cpu":
+        # special case with DDP on CPUs
+        if self.trainer.distributed_backend == "ddp_cpu":
+            self.trainer._distrib_type = DistributedType.DDP
+            self.trainer.data_parallel_device_ids = None
             if self.trainer.num_gpus > 0:
                 rank_zero_warn(
                     'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
                )
-            self.trainer.use_ddp = True
-            self.trainer.data_parallel_device_ids = None
-            self.trainer.on_gpu = False
-            self.trainer.on_cpu = True
-        elif self.trainer.distributed_backend == "horovod":
+            if self.trainer.num_processes is None:
+                # define the max CPU available
+                self.trainer.num_processes = os.cpu_count()
+        # special case with TPUs
+        elif self.trainer.distributed_backend == 'tpu':
+            self.trainer._device_type = DeviceType.TPU
+        # set all other requested distrib. types if not already set
+        elif self.trainer.distributed_backend and self.trainer._distrib_type is None:
+            self.trainer._distrib_type = DistributedType(self.trainer.distributed_backend)
+
+        # unless CPU was explicitly requested, use any available GPUs
+        _on_cpu = self.trainer.distributed_backend and 'cpu' in self.trainer.distributed_backend
+        if (self.trainer.num_gpus > 0 and not _on_cpu):
+            self.trainer._device_type = DeviceType.GPU
+
+        _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        # DP and DDP2 cannot run without GPU
+        if (self.trainer.num_gpus == 0 and self.trainer._distrib_type in _distrib_types):
+            rank_zero_warn(
+                'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
+            )
+            # todo: in some cases this comparison can be between None and int
+            if ((self.trainer.num_nodes and self.trainer.num_nodes > 1)
+                    or (self.trainer.num_processes and self.trainer.num_processes > 1)):
+                self.trainer._distrib_type = DistributedType.DDP
+            else:
+                rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
+                self.trainer._distrib_type = None
+
+        # for DDP overwrite nb processes by requested GPUs
+        if (self.trainer._device_type == DeviceType.GPU
+                and self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)):
+            self.trainer.num_processes = self.trainer.num_gpus
+
+        # Horovod is an extra case...
+        if self.trainer.distributed_backend == "horovod":
             self._set_horovod_backend()
 
         # throw error to force user ddp or ddp2 choice
-        if self.trainer.num_nodes > 1 and not (self.trainer.use_ddp2 or self.trainer.use_ddp):
+        if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
            raise MisconfigurationException(
                 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
                 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
@@ -350,20 +357,20 @@ def set_distributed_mode(self):
             num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
             rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
 
-        if torch.cuda.is_available() and not self.trainer.on_gpu:
+        if torch.cuda.is_available() and self.trainer._device_type != DeviceType.GPU:
             rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')
 
     def _set_horovod_backend(self):
-        self.check_horovod()
-        self.trainer.use_horovod = True
+        self._check_horovod()
+        self.trainer._distrib_type = DistributedType.HOROVOD
 
         # Initialize Horovod to get rank / size info
         hvd.init()
         if self.trainer.on_gpu:
             # Horovod assigns one local GPU per process
             self.trainer.root_gpu = hvd.local_rank()
 
-    def check_horovod(self):
+    def _check_horovod(self):
         """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
         if not _HOROVOD_AVAILABLE:
             raise MisconfigurationException(
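
The new set_distributed_mode above folds the per-backend branches into one pass that fills trainer._device_type and trainer._distrib_type. A condensed, self-contained re-expression of that decision flow, to be read as a sketch only: the enums are simplified stand-ins, Horovod/SLURM handling and the user-facing warnings are omitted, and the real method mutates trainer state rather than returning a tuple:

from enum import Enum
from typing import Optional, Tuple


class DeviceType(Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


class DistributedType(Enum):
    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    DDP2 = "ddp2"


def resolve_mode(backend: Optional[str], num_gpus: int = 0, num_nodes: int = 1,
                 num_processes: int = 1) -> Tuple[DeviceType, Optional[DistributedType]]:
    """Map user-facing flags to (device type, distributed type), mirroring the hunk above."""
    device, distrib = DeviceType.CPU, None

    if backend == "ddp_cpu":
        # special case: DDP over CPU processes, GPUs are ignored
        distrib = DistributedType.DDP
    elif backend == "tpu":
        device = DeviceType.TPU
    elif backend is not None:
        # any other explicit backend maps straight onto the enum
        distrib = DistributedType(backend)
    elif num_gpus > 1:
        # no backend given but several GPUs requested: default to ddp_spawn
        distrib = DistributedType.DDP_SPAWN
    elif num_nodes > 1 or num_processes > 1:
        distrib = DistributedType.DDP

    # unless CPU was explicitly requested, use the available GPUs
    if num_gpus > 0 and backend != "ddp_cpu":
        device = DeviceType.GPU

    # distributed types without any GPU: keep DDP only if several nodes/processes are requested
    if num_gpus == 0 and distrib in (DistributedType.DP, DistributedType.DDP,
                                     DistributedType.DDP_SPAWN, DistributedType.DDP2):
        distrib = DistributedType.DDP if (num_nodes > 1 or num_processes > 1) else None

    return device, distrib


print(resolve_mode("ddp_cpu", num_processes=2))  # (DeviceType.CPU, DistributedType.DDP)
print(resolve_mode("ddp", num_gpus=2))           # (DeviceType.GPU, DistributedType.DDP)
print(resolve_mode(None))                        # (DeviceType.CPU, None)

The final fallback is the behaviour the CHANGELOG entry refers to: ddp_cpu (or a GPU backend with no GPUs) keeps DDP only when more than one node or process is requested.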

pytorch_lightning/plugins/plugin_connector.py

Lines changed: 0 additions & 2 deletions
@@ -31,8 +31,6 @@ def __init__(self, trainer):
         self.plugins = []
         self.ddp_plugin = DDPPlugin()
         self.cloud_environment = None
-        self.amp_plugin = NativeAMPPlugin(trainer)
-        self.apex_plugin = ApexPlugin(trainer)
 
     def on_trainer_init(self, plugins: Optional[Union[str, list]]):
         self.plugins = plugins
