Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed setting of the distributed mode so that `ddp_cpu` is only used with `num_processes > 1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))


## [1.1.0] - 2020-12-09

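For context, a minimal sketch of the kind of configuration this changelog entry is about; the argument values here are hypothetical and not taken from the PR:

# Hypothetical usage sketch: distributed CPU training via `ddp_cpu`.
# The fix concerns this path being taken only when num_processes > 1;
# with num_processes=1 the run should stay single-process.
from pytorch_lightning import Trainer

trainer = Trainer(accelerator="ddp_cpu", num_processes=2)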
54 changes: 18 additions & 36 deletions benchmarks/test_sharded_parity.py
@@ -1,7 +1,7 @@
import os
import platform
import time
from typing import Union
from typing import Type, Union

import pytest
import torch
@@ -14,64 +14,48 @@
from tests.base.boring_model import BoringModel, RandomDataset


@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
plugin_parity_test(
accelerator='ddp_cpu',
max_percent_speed_diff=0.15, # slower speed due to one CPU doing additional sequential memory saving calls
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
gpus=1,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
)


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
gpus=1,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
)


@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
gpus=2,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25,
)


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
@@ -81,13 +65,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25,
)


@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
@@ -97,7 +80,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
accelerator='ddp_spawn',
plugin='ddp_sharded',
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25,
)


@@ -133,8 +116,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):

@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
@@ -145,14 +127,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderMultipleOptimizersModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
max_percent_speed_diff=0.25, # Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
@@ -163,7 +144,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
max_percent_speed_diff=0.25, # Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@@ -259,13 +240,14 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda):


def plugin_parity_test(
model_cls: SeedTrainLoaderModel,
model_cls: Type[SeedTrainLoaderModel],
plugin: Union[str, DDPPlugin],
seed: int = 42,
accelerator: str = 'ddp_spawn',
gpus: int = 0,
precision: int = 32,
max_percent_speed_diff: float = 0.1):
max_percent_speed_diff: float = 0.1,
):
"""
Ensures that the trained model is identical to the standard DDP implementation.
Also checks for speed/memory regressions; we should always expect less memory, but performance may fluctuate.
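As a usage sketch mirroring the tests above: with the tightened `Type[SeedTrainLoaderModel]` annotation, `model_cls` receives the class itself rather than an instance, and `plugin` accepts either a `DDPPlugin` instance or a string alias:

# Sketch assembled from the calls above (not an additional test in the PR).
plugin_parity_test(
    gpus=2,
    accelerator='ddp_spawn',
    plugin=DDPShardedPlugin(),          # or the string alias 'ddp_sharded'
    model_cls=SeedTrainLoaderModel,     # the class, matching Type[SeedTrainLoaderModel]
    max_percent_speed_diff=0.25,
)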
141 changes: 74 additions & 67 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -15,7 +15,7 @@

import torch

from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, DeviceType, DistributedType
from pytorch_lightning import _logger as log
from pytorch_lightning import accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator
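Since `DeviceType` and `DistributedType` are only imported here, a sketch of what they presumably look like, with members inferred from how they are used later in this diff (the exact values are assumptions, not taken from the PR):

from enum import Enum

class DeviceType(str, Enum):
    # members inferred from DeviceType.TPU / DeviceType.GPU below
    CPU = 'CPU'
    GPU = 'GPU'
    TPU = 'TPU'

class DistributedType(str, Enum):
    # DistributedType(self.trainer.distributed_backend) below implies
    # the values mirror the accelerator strings
    DP = 'dp'
    DDP = 'ddp'
    DDP2 = 'ddp2'
    DDP_SPAWN = 'ddp_spawn'
    HOROVOD = 'horovod'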
@@ -81,10 +81,7 @@ def on_trainer_init(
# sync-bn backend
self.trainer.sync_batchnorm = sync_batchnorm

self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
self.trainer.on_tpu = self.trainer.tpu_cores is not None

self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None
self._parse_tpu_device_details(tpu_cores)

if num_processes != 1 and distributed_backend != "ddp_cpu":
rank_zero_warn("num_processes is only used for `accelerator='ddp_cpu'`. Ignoring it.")
@@ -100,23 +97,10 @@ def on_trainer_init(

self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus)
self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids)
self.trainer.root_device = torch.device("cpu")

self.trainer.on_gpu = True if (self.trainer.data_parallel_device_ids and torch.cuda.is_available()) else False

# tpu state flags
self.trainer.use_tpu = False
self.trainer.tpu_local_core_rank = None
self.trainer.tpu_global_core_rank = None

# distributed backend choice
self.set_distributed_mode()

# override dist backend when using tpus
if self.trainer.on_tpu:
self.trainer.distributed_backend = "tpu"
self.trainer.use_tpu = True

# init flags for SLURM+DDP to work
self.trainer.world_size = 1
self.trainer.interactive_ddp_procs = []
@@ -135,10 +119,29 @@ def on_trainer_init(

self.trainer.replace_sampler_ddp = replace_sampler_ddp

def _parse_tpu_device_details(self, tpu_cores):
self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
if self.trainer.tpu_cores is not None:
if _TPU_AVAILABLE:
self.trainer._device_type = DeviceType.TPU
self.trainer.distributed_backend = "tpu"
else:
raise MisconfigurationException(
f"You have requested {self.trainer.tpu_cores} TPU cores but none is available."
)

self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None

# tpu state flags
self.trainer.tpu_local_core_rank = None
self.trainer.tpu_global_core_rank = None

def _map_deprecated_dist_backend(self, accelerator, distributed_backend):
if distributed_backend is not None:
rank_zero_warn(DeprecationWarning('distributed_backend has been renamed to accelerator. '
'Deprecated in 1.0.0, will be removed in 1.2.0'))
rank_zero_warn(
'`distributed_backend` has been renamed to accelerator. Deprecated in 1.0.0, will be removed in 1.2.0',
DeprecationWarning
)

# temporary mapping until we remove all the distributed_backend references
if accelerator is not None:
@@ -276,71 +279,75 @@ def select_accelerator(self):
accelerator_backend = accelerators.CPUAccelerator(self.trainer, cluster_env)
else:
raise MisconfigurationException(
f'Trainer(accelerator={self.trainer.distributed_backend} is not a supported backend'
f'`Trainer(accelerator={self.trainer.distributed_backend}, num_nodes={self.trainer.num_nodes},'
f' num_processes={self.trainer.num_processes}, ...)` is not a supported backend for'
f' num_gpus={self.trainer.num_gpus}'
)

return accelerator_backend

def set_distributed_mode(self):
self.trainer.use_dp = False
self.trainer.use_ddp = False
self.trainer.use_ddp2 = False
self.trainer.use_horovod = False
self.trainer.use_single_gpu = False

if self.trainer.distributed_backend is None:
if self.has_horovodrun():
self._set_horovod_backend()
elif self.trainer.num_gpus == 0:
if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
self.trainer.use_ddp = True # ddp_cpu
elif self.trainer.num_gpus == 1:
self.trainer.use_single_gpu = True
elif self.trainer.num_gpus == 0 and (self.trainer.num_nodes > 1 or self.trainer.num_processes > 1):
self.trainer._distrib_type = DistributedType.DDP
elif self.trainer.num_gpus > 1:
rank_zero_warn(
'You requested multiple GPUs but did not specify a backend, e.g.'
' `Trainer(accelerator="dp"|"ddp"|"ddp2")`.'
' Setting `accelerator="ddp_spawn"` for you.'
' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
)
self.trainer.distributed_backend = "ddp_spawn"

if self.trainer.distributed_backend == "dp":
# do nothing if num_gpus == 0
if self.trainer.num_gpus == 1:
self.trainer.use_single_gpu = True
self.trainer.use_dp = True
elif self.trainer.num_gpus > 1:
self.trainer.use_dp = True

elif self.trainer.distributed_backend in ("ddp", "ddp_spawn"):
if self.trainer.num_gpus == 0:
if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1:
self.trainer.use_ddp = True # ddp_cpu
elif self.trainer.num_gpus == 1:
self.trainer.use_single_gpu = True
self.trainer.use_ddp = True
elif self.trainer.num_gpus > 1:
self.trainer.use_ddp = True
self.trainer.num_processes = self.trainer.num_gpus

elif self.trainer.distributed_backend == "ddp2":
# do nothing if num_gpus == 0
if self.trainer.num_gpus >= 1:
self.trainer.use_ddp2 = True
elif self.trainer.distributed_backend == "ddp_cpu":
# special case with DDP on CPUs
if self.trainer.distributed_backend == "ddp_cpu":
self.trainer._distrib_type = DistributedType.DDP
self.trainer.data_parallel_device_ids = None
if self.trainer.num_gpus > 0:
rank_zero_warn(
'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.'
)
self.trainer.use_ddp = True
self.trainer.data_parallel_device_ids = None
self.trainer.on_gpu = False
self.trainer.on_cpu = True
elif self.trainer.distributed_backend == "horovod":
if self.trainer.num_processes is None:
# define the max CPU available
self.trainer.num_processes = os.cpu_count()
# special case with TPUs
elif self.trainer.distributed_backend == 'tpu':
self.trainer._device_type = DeviceType.TPU
# set any other requested distributed type if it was not already set above
elif self.trainer.distributed_backend and self.trainer._distrib_type is None:
self.trainer._distrib_type = DistributedType(self.trainer.distributed_backend)

# unless CPU was explicitly requested, use any GPUs that are available
_on_cpu = self.trainer.distributed_backend and 'cpu' in self.trainer.distributed_backend
if (self.trainer.num_gpus > 0 and not _on_cpu):
self.trainer._device_type = DeviceType.GPU

_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if (self.trainer.num_gpus == 0 and self.trainer._distrib_type in _distrib_types):
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
# todo: in some cases this results in comparing None with int
if ((self.trainer.num_nodes and self.trainer.num_nodes > 1)
or (self.trainer.num_processes and self.trainer.num_processes > 1)):
self.trainer._distrib_type = DistributedType.DDP
else:
rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
self.trainer._distrib_type = None

# for DDP, overwrite the number of processes with the number of requested GPUs
if (self.trainer._device_type == DeviceType.GPU
and self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)):
self.trainer.num_processes = self.trainer.num_gpus

# Horovod is an extra case...
if self.trainer.distributed_backend == "horovod":
self._set_horovod_backend()

# throw error to force user ddp or ddp2 choice
if self.trainer.num_nodes > 1 and not (self.trainer.use_ddp2 or self.trainer.use_ddp):
if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
@@ -350,20 +357,20 @@ def set_distributed_mode(self):
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and not self.trainer.on_gpu:
if torch.cuda.is_available() and self.trainer._device_type != DeviceType.GPU:
rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')

def _set_horovod_backend(self):
self.check_horovod()
self.trainer.use_horovod = True
self._check_horovod()
self.trainer._distrib_type = DistributedType.HOROVOD

# Initialize Horovod to get rank / size info
hvd.init()
if self.trainer.on_gpu:
# Horovod assigns one local GPU per process
self.trainer.root_gpu = hvd.local_rank()

def check_horovod(self):
def _check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
if not _HOROVOD_AVAILABLE:
raise MisconfigurationException(
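To make the resolution logic concrete, a hedged sketch of what the branches above imply on a CPU-only machine; attribute names are taken from this diff, but the expected run-time values are an inference, not something stated in the PR:

from pytorch_lightning import Trainer

# On a machine without GPUs, the branches above suggest:
#   accelerator='ddp_cpu', num_processes=2  -> _distrib_type=DistributedType.DDP, data_parallel_device_ids=None
#   gpus=2 (no accelerator given)           -> warns and sets accelerator='ddp_spawn', _device_type=DeviceType.GPU
trainer = Trainer(accelerator='ddp_cpu', num_processes=2)
print(trainer._distrib_type, trainer._device_type)  # expected: DDP on CPU, per set_distributed_mode above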
2 changes: 0 additions & 2 deletions pytorch_lightning/plugins/plugin_connector.py
@@ -31,8 +31,6 @@ def __init__(self, trainer):
self.plugins = []
self.ddp_plugin = DDPPlugin()
self.cloud_environment = None
self.amp_plugin = NativeAMPPlugin(trainer)
self.apex_plugin = ApexPlugin(trainer)

def on_trainer_init(self, plugins: Optional[Union[str, list]]):
self.plugins = plugins
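For reference, a hedged sketch of the user-facing side of this connector: `on_trainer_init` receives whatever is passed to the Trainer's `plugins` argument, either a string alias (as in the sharded benchmark above) or a list of plugin instances:

# Hypothetical usage; the alias 'ddp_sharded' appears in the benchmark file above.
from pytorch_lightning import Trainer

trainer = Trainer(plugins='ddp_sharded')
# or, equivalently, with an instance (import elided, as in the benchmark file):
# trainer = Trainer(plugins=[DDPShardedPlugin()])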