
Commit 3b8563a

Shen Li (mrshenli) authored and committed
Adding distributed pipeline parallel tutorial

1 parent 43e3eb7 commit 3b8563a

File tree

3 files changed: +375 −0 lines changed

[New image thumbnail — 34.9 KB binary file; preview not shown]
index.rst

Lines changed: 8 additions & 0 deletions

@@ -332,6 +332,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/rpc_param_server_tutorial.html
    :tags: Parallel-and-Distributed-Training

+ .. customcarditem::
+    :header: Distributed Pipeline Parallelism Using RPC
+    :card_description: Demonstrate how to implement distributed pipeline parallelism using RPC
+    :image: _static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png
+    :link: intermediate/dist_pipeline_parallel_tutorial.html
+    :tags: Parallel-and-Distributed-Training
+
 .. End of tutorial card section

 .. raw:: html

@@ -497,3 +504,4 @@ Additional Resources
    intermediate/rpc_tutorial
    beginner/aws_distributed_training_tutorial
    intermediate/rpc_param_server_tutorial
+   intermediate/dist_pipeline_parallel_tutorial
intermediate_source/dist_pipeline_parallel_tutorial.rst

Lines changed: 367 additions & 0 deletions

@@ -0,0 +1,367 @@
Distributed Pipeline Parallelism Using RPC
==========================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Prerequisites:

- `Single-Machine Model Parallel Best Practices <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
- `Getting Started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
- RRef helper functions: `RRef.rpc_sync() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__,
  `RRef.rpc_async() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_async>`__, and
  `RRef.remote() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.remote>`__
  (a short refresher sketch follows this list)
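
Since the pipeline below leans heavily on these RRef helpers, here is a minimal
refresher sketch. It is not part of the tutorial code: the ``Adder`` class and the
``"worker1"`` name are made up for illustration, and the sketch assumes
``rpc.init_rpc`` has already been called on this process and on a peer named
``"worker1"``.

.. code:: python

    import torch
    import torch.distributed.rpc as rpc


    class Adder:
        def __init__(self, offset):
            self.offset = offset

        def add(self, x):
            return x + self.offset


    # create an Adder instance on "worker1" and hold an RRef to it
    adder_rref = rpc.remote("worker1", Adder, args=(1,))

    # RRef.rpc_sync(): run a method on the owner and block for the result
    y = adder_rref.rpc_sync().add(torch.ones(2))

    # RRef.rpc_async(): run a method on the owner and get a Future immediately
    fut = adder_rref.rpc_async().add(torch.ones(2))
    y = fut.wait()

    # RRef.remote(): run a method on the owner and get another RRef,
    # without fetching the value until to_here() is called
    y_rref = adder_rref.remote().add(torch.ones(2))
    y = y_rref.to_here()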

This tutorial uses a ResNet50 model to demonstrate how to implement distributed
pipeline parallelism with the `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__
APIs. This can be viewed as the distributed counterpart of the multi-GPU
pipeline parallelism discussed in
`Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_.


Basics
------

The previous tutorial, `Getting Started with Distributed RPC Framework <rpc_tutorial.html>`_,
shows how to use `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`_
to implement distributed model parallelism for an RNN model. That tutorial uses
one GPU to host the ``EmbeddingTable``, and the provided code works fine.
However, when a model lives on multiple GPUs, some extra steps are needed to
increase the amortized utilization of all of them. Pipeline parallelism is one
paradigm that can help in this case.

In this tutorial, we use ``ResNet50`` as the example model, which is also used by
the `Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_
tutorial. Similarly, the ``ResNet50`` model is divided into two shards, and
the input batch is partitioned into multiple splits that are fed into the two model
shards in a pipelined fashion. The difference is that, instead of parallelizing
the execution with CUDA streams, this tutorial invokes asynchronous RPCs, so
the solution presented here also works across machine boundaries.
The remainder of this tutorial presents the implementation in four steps.
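
Before diving into the steps, the condensed sketch below previews the core
pipelining pattern that Step 2 implements in full. Here ``shard1_rref`` and
``shard2_rref`` are placeholder names for ``RRef`` handles to two model shards
owned by different workers; they are not part of the tutorial code.

.. code:: python

    import torch
    from torch.distributed.rpc import RRef


    def pipelined_forward(shard1_rref, shard2_rref, xs, num_split):
        out_futures = []
        for x in xs.chunk(num_split, dim=0):
            # stage 1: returns an RRef to the intermediate activation (non-blocking)
            y_rref = shard1_rref.remote().forward(RRef(x))
            # stage 2: consumes that RRef; rpc_async returns a Future (non-blocking)
            out_futures.append(shard2_rref.rpc_async().forward(y_rref))
        # all micro-batches are now in flight; wait for them and concatenate
        return torch.cat(torch.futures.wait_all(out_futures))
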

Step 1: Partition ResNet50 Model
--------------------------------

This is the preparation step, which implements ``ResNet50`` in two model shards.
The code below is borrowed from the
`ResNet implementation in torchvision <https://github.com/pytorch/vision/blob/7c077f6a986f05383bcb86b535aedb5a63dd5c4b/torchvision/models/resnet.py#L124>`_.
The ``ResNetBase`` module contains the common building blocks and attributes for
the two ResNet shards.

.. code:: python

    import threading

    import torch
    import torch.nn as nn
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    from torchvision.models.resnet import Bottleneck

    num_classes = 1000


    def conv1x1(in_planes, out_planes, stride=1):
        # 1x1 convolution, used in the downsample path of Bottleneck blocks
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    class ResNetBase(nn.Module):
        def __init__(self, block, inplanes, num_classes=1000,
                     groups=1, width_per_group=64, norm_layer=None):
            super(ResNetBase, self).__init__()

            self._lock = threading.Lock()
            self._block = block
            self._norm_layer = nn.BatchNorm2d
            self.inplanes = inplanes
            self.dilation = 1
            self.groups = groups
            self.base_width = width_per_group

        def _make_layer(self, planes, blocks, stride=1):
            norm_layer = self._norm_layer
            downsample = None
            previous_dilation = self.dilation
            if stride != 1 or self.inplanes != planes * self._block.expansion:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * self._block.expansion, stride),
                    norm_layer(planes * self._block.expansion),
                )

            layers = []
            layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                      self.base_width, previous_dilation, norm_layer))
            self.inplanes = planes * self._block.expansion
            for _ in range(1, blocks):
                layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                          base_width=self.base_width, dilation=self.dilation,
                                          norm_layer=norm_layer))

            return nn.Sequential(*layers)

        def parameter_rrefs(self):
            # wrap each local parameter in an RRef so that the DistributedOptimizer
            # on the master can reference parameters living on remote workers
            return [RRef(p) for p in self.parameters()]


Now we are ready to define the two model shards. In the constructor, we
simply split all ResNet50 layers into two parts and move each part onto the
provided device. The ``forward`` functions of both shards take an ``RRef`` of
the input data, fetch the data locally, and then move it to the expected device.
After applying all layers to the input, each shard moves its output to CPU and
returns it. This is because the RPC API requires tensors to reside on CPU, which
avoids invalid device errors when the numbers of devices on the caller and the
callee do not match.

.. code:: python

    class ResNetShard1(ResNetBase):
        """
        The first half of ResNet50: the stem plus layer1 and layer2.
        """
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard1, self).__init__(
                Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                self._norm_layer(self.inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                self._make_layer(64, 3),
                self._make_layer(128, 4, stride=2)
            ).to(self.device)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x_rref):
            # fetch the input from the RRef and move it to this shard's GPU
            x = x_rref.to_here().to(self.device)
            with self._lock:
                out = self.seq(x)
            # RPC requires CPU tensors, so move the output back before returning
            return out.cpu()


    class ResNetShard2(ResNetBase):
        """
        The second half of ResNet50: layer3, layer4, avgpool, and the fc layer.
        """
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard2, self).__init__(
                Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                self._make_layer(256, 6, stride=2),
                self._make_layer(512, 3, stride=2),
                nn.AdaptiveAvgPool2d((1, 1)),
            ).to(self.device)

            self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

        def forward(self, x_rref):
            x = x_rref.to_here().to(self.device)
            with self._lock:
                out = self.fc(torch.flatten(self.seq(x), 1))
            return out.cpu()


Step 2: Stitch ResNet50 Model Shards Into One Module
----------------------------------------------------

Then, we create a ``DistResNet50`` module to assemble the two shards and
implement the pipeline parallel logic. In the constructor, we use two
``rpc.remote`` calls to put the two shards on two different RPC workers
and hold on to the ``RRef`` of each model part so that they
can be referenced in the forward pass. The ``forward`` function
splits the input batch into multiple micro-batches and feeds them
to the two model parts in a pipelined fashion. It first uses an
``rpc.remote`` call to apply the first shard to a micro-batch and then forwards
the returned intermediate output ``RRef`` to the second model shard. After that,
it collects the ``Future`` of every micro-batch output and waits for all of them
after the loop. Note that both ``remote()`` and ``rpc_async()`` return immediately
and run asynchronously. Therefore, the entire loop is non-blocking and launches
multiple RPCs concurrently. The execution order of one micro-batch on the two
model parts is preserved by the intermediate output ``y_rref``. The execution
order across micro-batches does not matter. In the end, the ``forward`` function
concatenates the outputs of all micro-batches into one single output tensor and
returns it. The ``parameter_rrefs`` function is a helper that simplifies
distributed optimizer construction, and it will be used later.

.. code:: python

    class DistResNet50(nn.Module):
        """
        Assemble two parts as an nn.Module and define pipelining logic
        """
        def __init__(self, num_split, workers, *args, **kwargs):
            super(DistResNet50, self).__init__()

            self.num_split = num_split

            # Put the first part of the ResNet50 on workers[0]
            self.p1_rref = rpc.remote(
                workers[0],
                ResNetShard1,
                args=("cuda:0",) + args,
                kwargs=kwargs
            )

            # Put the second part of the ResNet50 on workers[1]
            self.p2_rref = rpc.remote(
                workers[1],
                ResNetShard2,
                args=("cuda:1",) + args,
                kwargs=kwargs
            )

        def forward(self, xs):
            out_futures = []
            # split the input batch into num_split micro-batches
            for x in xs.chunk(self.num_split, dim=0):
                x_rref = RRef(x)
                # shard 1 runs asynchronously on workers[0] and returns an RRef
                y_rref = self.p1_rref.remote().forward(x_rref)
                # shard 2 consumes the intermediate RRef; rpc_async returns a Future
                z_fut = self.p2_rref.rpc_async().forward(y_rref)
                out_futures.append(z_fut)

            # wait for all micro-batch outputs and concatenate them
            return torch.cat(torch.futures.wait_all(out_futures))

        def parameter_rrefs(self):
            remote_params = []
            remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
            remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
            return remote_params


Step 3: Define The Training Loop
--------------------------------

After defining the model, let us implement the training loop. We use a
dedicated "master" worker to prepare random inputs and labels and to control the
distributed backward pass and distributed optimizer step. It first creates an
instance of the ``DistResNet50`` module, specifying the number of
micro-batches per batch and the names of the two RPC workers
(i.e., "worker1" and "worker2"). Then it defines the loss function and creates
a ``DistributedOptimizer``, using the ``parameter_rrefs()`` helper to acquire a
list of parameter ``RRefs``. The main training loop is very similar to
regular local training, except that it uses ``dist_autograd`` to launch
the backward pass and provides the ``context_id`` to both ``backward()`` and the
optimizer's ``step()``.

.. code:: python

    import torch.distributed.autograd as dist_autograd
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer

    num_batches = 3
    batch_size = 120
    image_w = 128
    image_h = 128


    def run_master(num_split):
        # put the two model parts on worker1 and worker2 respectively
        model = DistResNet50(num_split, ["worker1", "worker2"])
        loss_fn = nn.MSELoss()
        opt = DistributedOptimizer(
            optim.SGD,
            model.parameter_rrefs(),
            lr=0.05,
        )

        # random class indices used to build one-hot labels below
        one_hot_indices = torch.LongTensor(batch_size) \
                               .random_(0, num_classes) \
                               .view(batch_size, 1)

        for i in range(num_batches):
            print(f"Processing batch {i}")
            # generate random inputs and labels
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                          .scatter_(1, one_hot_indices, 1)

            # run forward, distributed backward, and the distributed optimizer
            # step within one distributed autograd context
            with dist_autograd.context() as context_id:
                outputs = model(inputs)
                dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
                opt.step(context_id)


Step 4: Launch RPC Processes
----------------------------

Finally, the code below shows the target function for all processes. The main
logic is defined in ``run_master``. The workers passively wait for
commands from the master, and hence simply run ``init_rpc`` and ``shutdown``,
where ``shutdown`` by default blocks until all RPC participants finish.

.. code:: python

    import os
    import time

    import torch.multiprocessing as mp


    def run_worker(rank, world_size, num_split):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)

        if rank == 0:
            rpc.init_rpc(
                "master",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            run_master(num_split)
        else:
            rpc.init_rpc(
                f"worker{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            # workers passively wait for commands from the master

        # block until all rpcs finish
        rpc.shutdown()


    if __name__ == "__main__":
        world_size = 3
        for num_split in [1, 2, 4, 8]:
            tik = time.time()
            mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
            tok = time.time()
            print(f"number of splits = {num_split}, execution time = {tok - tik}")


The output below shows the speedup attained by increasing the number of splits
in each batch.

::

    $ python main.py
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 1, execution time = 16.45062756538391
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 2, execution time = 12.329529762268066
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 4, execution time = 10.164430618286133
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 8, execution time = 9.076049566268921
