
Commit 3bacac7

accelerator refactor - add parallel plugins (#5714)
Co-authored-by: Justus Schock <[email protected]>
1 parent 692f77b commit 3bacac7
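
For orientation, here is a minimal sketch (not part of this commit) of constructing the new DDP plugin directly. At this stage of the refactor the plugins are not yet wired into the Trainer, so the snippet only exercises the constructor added below; the device list and argument values are illustrative assumptions, and the import path follows the import used by the new DDP2 file in this diff.

import torch

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin

# minimal sketch, assuming two local GPUs; only the constructor and its
# bookkeeping are exercised here, since the Trainer integration lands later
plugin = DDPPlugin(
    parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
    num_nodes=1,
    sync_batchnorm=False,
)

print(plugin.num_processes)        # 2 -- one process per entry in parallel_devices
print(plugin.distributed_backend)  # "ddp"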

File tree: 7 files changed, +873 −3 lines changed

CHANGELOG.md

Lines changed: 3 additions & 3 deletions
@@ -107,9 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
 
 
-- Refactored Accelerators and Plugins (
-  [#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715),
-  )
+- Refactored Accelerators and Plugins
+  * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
+  * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
 
 
 ### Deprecated
Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from time import sleep
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
    find_free_network_port,
    rank_zero_only,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import get_original_cwd, to_absolute_path


class DDPPlugin(ParallelPlugin):
    """
    Plugin for multi-process single-device training on one or multiple nodes.

    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
    where N is the number of devices (e.g. GPU) per node.
    It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices,
        num_nodes=1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm=False,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices)

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def lightning_module(self):
        # the model may not be wrapped with DistributedDataParallel if calling this too early
        # fixme: uncomment when this class will actually be used
        # return unwrap_lightning_module(self._model)
        pass

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, model):
        self._model = model

        # start the other scripts
        # TODO: make sure this works, in torchelastic we should not launch child processes!
        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

        # allow the user to pass the node rank
        node_rank = "0"
        node_rank = os.environ.get("NODE_RANK", node_rank)
        node_rank = os.environ.get("GROUP_RANK", node_rank)
        os.environ["NODE_RANK"] = node_rank
        os.environ["LOCAL_RANK"] = "0"

        # when the user is using hydra, find the absolute path
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the commands used to run the script and resolve the absolute file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception as e:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # run the command with the same Python interpreter that launched this script
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # when the trainer script was called, the devices have already been scoped by the time
        # the code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone
        # but forward the GPUs selected via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distributed_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so delay each launch by 1-5 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

    def configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
        self._model = LightningDistributedDataParallel(
            self.model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        # TODO: From where to get cluster environment?
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
        torch_backend = "nccl" if self.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

    def pre_training(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_training
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # on global rank 0, let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_training(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
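
To make the rank bookkeeping above concrete, here is a standalone sketch of the arithmetic in `set_world_ranks`: every node runs `num_processes` workers, and the global rank is derived from the node rank and the local rank. The helper function below is hypothetical; its names simply mirror the plugin's attributes.

# standalone sketch of the DDPPlugin.set_world_ranks arithmetic
def world_ranks(node_rank: int, local_rank: int, num_nodes: int, num_processes: int):
    global_rank = node_rank * num_processes + local_rank
    world_size = num_nodes * num_processes
    return global_rank, world_size

# e.g. 2 nodes x 4 GPUs: the second process on the second node is global rank 5 of 8
assert world_ranks(node_rank=1, local_rank=1, num_nodes=2, num_processes=4) == (5, 8)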
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class DDP2Plugin(DDPPlugin):

    def setup(self, model):
        self._model = model
        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()
        # the difference from DDP is that we don't launch child processes here

    def reduce(self, output, *args, **kwargs):
        if isinstance(output, Result):
            output.dp_reduce()

        elif isinstance(output, torch.Tensor):
            output = output.mean()

        return output

    @property
    def root_device(self):
        return self.parallel_devices[0]

    def model_to_device(self):
        # no need to do anything when the model is wrapped in torch.nn.DataParallel
        pass

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
        return distributed_sampler_kwargs

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank
        self.world_size = self.num_nodes
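
The practical difference between the two plugins shows up in `distributed_sampler_kwargs`: DDP treats every process as a replica, while DDP2 treats every node as a single replica, because its processes reduce DP-style within the node and `global_rank` equals `node_rank`. Below is a sketch of the values the two properties would report, assuming 2 nodes with 4 devices each (illustrative numbers, not taken from the commit).

# illustrative values only, assuming num_nodes=2 and num_processes=4
num_nodes, num_processes = 2, 4

# DDPPlugin: one replica per process, rank is the global rank (0..7 here)
ddp_sampler_kwargs = dict(num_replicas=num_nodes * num_processes, rank=5)

# DDP2Plugin: one replica per node, rank is the node rank (0..1 here)
ddp2_sampler_kwargs = dict(num_replicas=num_nodes, rank=1)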
