
Commit 8dcb9c7

Authored by Jessica Lin

Merge pull request #743 from drdarshan/master

Add example for distributed launcher

2 parents b9f3b2e + a5fdab9, commit 8dcb9c7

File tree

2 files changed: +265, -6 lines changed

distributed/ddp/README.md

Lines changed: 188 additions & 6 deletions
@@ -1,10 +1,192 @@
# Launching and configuring distributed data parallel applications

In this tutorial we will demonstrate how to structure a distributed
model training application so it can be launched conveniently on
multiple nodes, each with multiple GPUs, using PyTorch's distributed
[launcher script](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py).

# Prerequisites
We assume you are familiar with [PyTorch](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html), the primitives it provides for [writing distributed applications](https://pytorch.org/tutorials/intermediate/dist_tuto.html), and with training [distributed models](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).

The example program in this tutorial uses the
[`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) class for training models
in a _data parallel_ fashion: multiple workers train the same global
model by processing different portions of a large dataset, computing
local gradients (aka _sub_-gradients) independently and then
collectively synchronizing gradients using the AllReduce primitive. In
HPC terminology, this model of execution is called _Single Program
Multiple Data_ or SPMD, since the same application runs on all
processes but each one operates on different portions of the
training dataset.

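To make the AllReduce step concrete, the following sketch shows roughly what DDP automates during the backward pass: summing each parameter's local gradient across all workers and dividing by the world size. It is illustrative only (DDP actually buckets gradients and overlaps communication with computation), and the helper name is ours, not part of the example.

```py
import torch.distributed as dist

def average_gradients(model):
    # Illustrative only: average each parameter's gradient across all workers.
    # DistributedDataParallel performs an optimized, bucketed version of this
    # automatically inside backward().
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum the local gradients from every process ...
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # ... and divide to obtain the average gradient.
            param.grad /= world_size
```
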
# Application process topologies
A Distributed Data Parallel (DDP) application can be executed on
multiple nodes where each node can consist of multiple GPU
devices. Each node in turn can run multiple copies of the DDP
application, each of which processes its models on multiple GPUs.

Let _N_ be the number of nodes on which the application is running and
_G_ be the number of GPUs per node. The total number of application
processes running across all the nodes at one time is called the
**World Size**, _W_, and the number of processes running on each node
is referred to as the **Local World Size**, _L_.

Each application process is assigned two IDs: a _local_ rank in \[0,
_L_-1\] and a _global_ rank in \[0, _W_-1\].

To illustrate the terminology defined above, consider the case where a
DDP application is launched on two nodes, each of which has four
GPUs. We would then like each process to span two GPUs. The
mapping of processes to nodes is shown in the figure below:

![ProcessMapping](https://user-images.githubusercontent.com/875518/77676984-4c81e400-6f4c-11ea-87d8-f2ff505a99da.png)

While there are quite a few ways to map processes to nodes, a good
rule of thumb is to have one process span a single GPU. This enables
the DDP application to have as many parallel reader streams as there
are GPUs and in practice provides a good balance between I/O and
computational costs. In the rest of this tutorial, we assume that the
application follows this heuristic.

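Under this heuristic the bookkeeping reduces to simple arithmetic. The following illustrative snippet (values chosen for two nodes with four GPUs each, one process per GPU) shows how the quantities defined above relate and how a global rank can be derived from a node rank and a local rank:

```py
# Illustrative values only: 2 nodes (N) with 4 GPUs each (G), one process per GPU.
num_nodes = 2                           # N
gpus_per_node = 4                       # G
local_world_size = gpus_per_node        # L = G under the one-process-per-GPU heuristic
world_size = num_nodes * gpus_per_node  # W = N * G = 8

# Each process can derive its global rank from its node rank and local rank.
for node_rank in range(num_nodes):
    for local_rank in range(local_world_size):
        global_rank = node_rank * local_world_size + local_rank
        print(f"node {node_rank}, local rank {local_rank} -> global rank {global_rank}")
```
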
# Preparing and launching a DDP application
Independent of how a DDP application is launched, each process needs a
mechanism to know its global and local ranks. Once this is known, all
processes create a `ProcessGroup` that enables them to participate in
collective communication operations such as AllReduce.

A convenient way to start multiple DDP processes and initialize all
values needed to create a `ProcessGroup` is to use the distributed
`launch.py` script provided with PyTorch. The launcher can be found
under the `distributed` subdirectory of the local `torch`
installation directory. Here is a quick way to get the path of
`launch.py` on any operating system:

```sh
python -c "from os import path; import torch; print(path.join(path.dirname(torch.__file__), 'distributed', 'launch.py'))"
```

This will print something like this:
```sh
/home/username/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/distributed/launch.py
```

When the DDP application is started via `launch.py`, it passes the world size, global rank, master address and master port via environment variables and the local rank as a command-line parameter to each instance.
To use the launcher, an application needs to adhere to the following convention:
1. It must provide an entry-point function for a _single worker_. For example, it should not itself launch subprocesses using `torch.multiprocessing.spawn`.
2. It must use environment variables for initializing the process group.

For simplicity, the application can assume that each process maps to a single GPU, but in the next section we also show how a more general process-to-GPU mapping can be performed.

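As a preview of the simple case, when each process maps to exactly one GPU the device index is just the local rank. Here is a minimal sketch, assuming the process group has already been initialized and `local_rank` has been parsed (the helper name is ours, not part of the example):

```py
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_one_gpu_per_process(model, local_rank):
    # One process <-> one GPU: the CUDA device index is simply the local rank.
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    # DDP only needs to be told about this single device.
    return DDP(model, device_ids=[local_rank])
```
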
# Sample application
The sample DDP application in this repo is based on the "Hello, World" [DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).

## Argument passing convention
The DDP application takes two command-line arguments:
1. `--local_rank`: This is passed in via `launch.py`.
2. `--local_world_size`: This is passed in explicitly and is typically either 1 or the number of GPUs per node.

The application parses these and calls the `spmd_main` entrypoint:
```py
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--local_world_size", type=int, default=1)
    args = parser.parse_args()
    spmd_main(args.local_world_size, args.local_rank)
```
In `spmd_main`, the process group is initialized with just the backend (NCCL or Gloo). The rest of the information needed for rendezvous comes from environment variables set by `launch.py`:
```py
def spmd_main(local_world_size, local_rank):
    # These are the parameters used to initialize the process group
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    print(
        f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
    )

    demo_basic(local_world_size, local_rank)

    # Tear down the process group
    dist.destroy_process_group()
```
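
Since no `init_method` is passed, `init_process_group` uses the default `env://` rendezvous, which reads exactly the variables printed above. An equivalent, more explicit form (not what the example file itself uses) would be:

```py
# Equivalent to the call above: the env:// init method reads MASTER_ADDR,
# MASTER_PORT, RANK and WORLD_SIZE from the environment set by launch.py.
dist.init_process_group(backend="nccl", init_method="env://")
```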

Given the local rank and world size, the training function `demo_basic` initializes the `DistributedDataParallel` model across a set of GPUs local to the node via `device_ids`:
```py
def demo_basic(local_world_size, local_rank):

    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
    # local rank 0 uses GPUs [0, 1, 2, 3] and
    # local rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids}"
    )

    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()
```
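
The toy example trains on a single batch of random tensors. In a real application each rank would process a different shard of the training data; one common way to do that (not part of this example, and assuming the process group has already been initialized) is a `DistributedSampler`, sketched here with a random stand-in dataset:

```py
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# A stand-in dataset of random examples, matching ToyModel's input/output sizes.
train_dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 5))

# The sampler partitions the dataset by global rank, so each process sees a
# different subset of the examples; by default it reads the world size and
# rank from the initialized process group.
sampler = DistributedSampler(train_dataset)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle the partition each epoch
    for inputs, labels in loader:
        pass  # forward/backward/step as in demo_basic
```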

The application can be launched via `launch.py` as follows on an 8-GPU node with one process per GPU:
```sh
python /path/to/launch.py --nnodes=1 --node_rank=0 --nproc_per_node=8 example.py --local_world_size=8
```
and produces an output similar to the one shown below:
```sh
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
[238627] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '0', 'WORLD_SIZE': '8'}
[238630] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '3', 'WORLD_SIZE': '8'}
[238628] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '1', 'WORLD_SIZE': '8'}
[238634] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '7', 'WORLD_SIZE': '8'}
[238631] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '4', 'WORLD_SIZE': '8'}
[238632] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '5', 'WORLD_SIZE': '8'}
[238629] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '2', 'WORLD_SIZE': '8'}
[238633] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '6', 'WORLD_SIZE': '8'}
[238633] world_size = 8, rank = 6, backend=nccl
[238628] world_size = 8, rank = 1, backend=nccl
[238629] world_size = 8, rank = 2, backend=nccl
[238631] world_size = 8, rank = 4, backend=nccl
[238630] world_size = 8, rank = 3, backend=nccl
[238632] world_size = 8, rank = 5, backend=nccl
[238634] world_size = 8, rank = 7, backend=nccl
[238627] world_size = 8, rank = 0, backend=nccl
[238633] rank = 6, world_size = 8, n = 1, device_ids = [6]
[238628] rank = 1, world_size = 8, n = 1, device_ids = [1]
[238632] rank = 5, world_size = 8, n = 1, device_ids = [5]
[238634] rank = 7, world_size = 8, n = 1, device_ids = [7]
[238629] rank = 2, world_size = 8, n = 1, device_ids = [2]
[238630] rank = 3, world_size = 8, n = 1, device_ids = [3]
[238631] rank = 4, world_size = 8, n = 1, device_ids = [4]
[238627] rank = 0, world_size = 8, n = 1, device_ids = [0]
```
Similarly, it can be launched with a single process that spans all 8 GPUs using:
```sh
python /path/to/launch.py --nnodes=1 --node_rank=0 --nproc_per_node=1 example.py --local_world_size=1
```
which in turn produces the following output:
```sh
[262816] Initializing process group with: {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500', 'RANK': '0', 'WORLD_SIZE': '1'}
[262816]: world_size = 1, rank = 0, backend=nccl
[262816] rank = 0, world_size = 1, n = 8, device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
```
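
Although both commands above target a single node, the same application can be launched across several nodes by invoking the launcher once per node with a shared rendezvous point. The commands below are only a sketch: the master address and port are placeholders, not values taken from this example.

```sh
# On node 0 (placeholder master address/port shared by both nodes):
python /path/to/launch.py --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=29500 \
    --nproc_per_node=8 example.py --local_world_size=8

# On node 1:
python /path/to/launch.py --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" --master_port=29500 \
    --nproc_per_node=8 example.py --local_world_size=8
```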

# Conclusions
As the author of a distributed data parallel application, you need to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using the launcher, the mechanics of setting up distributed training can be significantly simplified.

distributed/ddp/example.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import argparse
import os
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(local_world_size, local_rank):

    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
    # local rank 0 uses GPUs [0, 1, 2, 3] and
    # local rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids}"
    )

    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()


def spmd_main(local_world_size, local_rank):
    # These are the parameters used to initialize the process group
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    print(
        f"[{os.getpid()}]: world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
    )

    demo_basic(local_world_size, local_rank)

    # Tear down the process group
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # This is passed in via launch.py
    parser.add_argument("--local_rank", type=int, default=0)
    # This needs to be explicitly passed in
    parser.add_argument("--local_world_size", type=int, default=1)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess
    spmd_main(args.local_world_size, args.local_rank)
