
Commit 86b5c4a

Shen Li (mrshenli) authored and committed
Adding distributed pipeline parallel tutorial
1 parent 43e3eb7 commit 86b5c4a

File tree: 3 files changed, +378 −0 lines changed

_static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png (binary image, 34.9 KB)
index.rst

Lines changed: 8 additions & 0 deletions

@@ -332,6 +332,13 @@ Welcome to PyTorch Tutorials
       :link: intermediate/rpc_param_server_tutorial.html
       :tags: Parallel-and-Distributed-Training

+   .. customcarditem::
+      :header: Distributed Pipeline Parallelism Using RPC
+      :card_description: Demonstrate how to implement distributed pipeline parallelism using RPC
+      :image: _static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png
+      :link: intermediate/dist_pipeline_parallel_tutorial.html
+      :tags: Parallel-and-Distributed-Training
+
   .. End of tutorial card section

   .. raw:: html

@@ -497,3 +504,4 @@ Additional Resources
   intermediate/rpc_tutorial
   beginner/aws_distributed_training_tutorial
   intermediate/rpc_param_server_tutorial
+  intermediate/dist_pipeline_parallel_tutorial
intermediate/dist_pipeline_parallel_tutorial.rst

Lines changed: 370 additions & 0 deletions

@@ -0,0 +1,370 @@
Distributed Pipeline Parallelism Using RPC
==========================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Prerequisites:

-  `Single-Machine Model Parallel Best Practices <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
-  `Getting Started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
-  RRef helper functions:
   `RRef.rpc_sync() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__,
   `RRef.rpc_async() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_async>`__, and
   `RRef.remote() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.remote>`__

This tutorial uses a ResNet50 model to demonstrate implementing distributed
pipeline parallelism with `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__
APIs. This can be viewed as the distributed counterpart of the multi-GPU
pipeline parallelism discussed in
`Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_.

.. note:: This tutorial requires PyTorch v1.6.0 or above.

.. note:: The full source code of this tutorial can be found at
   `pytorch/examples <https://github.com/pytorch/examples/tree/master/distributed/rpc/pipeline>`__.

Basics
------

The previous tutorial, `Getting Started with Distributed RPC Framework <rpc_tutorial.html>`_,
shows how to use `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`_
to implement distributed model parallelism for an RNN model. That tutorial uses
one GPU to host the ``EmbeddingTable``, and the provided code works fine.
However, if a model lives on multiple GPUs, some extra steps are needed to
increase the amortized utilization of all GPUs. Pipeline parallelism is one
paradigm that can help in this case.

In this tutorial, we use ``ResNet50`` as an example model, which is also used by
the `Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_
tutorial. Similarly, the ``ResNet50`` model is divided into two shards, and
the input batch is partitioned into multiple splits that are fed into the two model
shards in a pipelined fashion. The difference is that, instead of parallelizing
the execution using CUDA streams, this tutorial invokes asynchronous RPCs. So
the solution presented here also works across machine boundaries.
The remainder of this tutorial presents the implementation in four steps.

Step 1: Partition ResNet50 Model
--------------------------------

This is the preparation step, which implements ``ResNet50`` in two model shards.
The code below is borrowed from the
`ResNet implementation in torchvision <https://github.com/pytorch/vision/blob/7c077f6a986f05383bcb86b535aedb5a63dd5c4b/torchvision/models/resnet.py#L124>`_.
The ``ResNetBase`` module contains the common building blocks and attributes for
the two ResNet shards.

.. code:: python

    import threading

    import torch
    import torch.nn as nn
    # rpc and RRef are used below and in the following steps
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    from torchvision.models.resnet import Bottleneck

    num_classes = 1000


    def conv1x1(in_planes, out_planes, stride=1):
        # 1x1 convolution used by the downsample path of Bottleneck blocks
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    class ResNetBase(nn.Module):
        def __init__(self, block, inplanes, num_classes=1000,
                     groups=1, width_per_group=64, norm_layer=None):
            super(ResNetBase, self).__init__()

            self._lock = threading.Lock()
            self._block = block
            self._norm_layer = nn.BatchNorm2d
            self.inplanes = inplanes
            self.dilation = 1
            self.groups = groups
            self.base_width = width_per_group

        def _make_layer(self, planes, blocks, stride=1):
            norm_layer = self._norm_layer
            downsample = None
            previous_dilation = self.dilation
            if stride != 1 or self.inplanes != planes * self._block.expansion:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * self._block.expansion, stride),
                    norm_layer(planes * self._block.expansion),
                )

            layers = []
            layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                      self.base_width, previous_dilation, norm_layer))
            self.inplanes = planes * self._block.expansion
            for _ in range(1, blocks):
                layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                          base_width=self.base_width, dilation=self.dilation,
                                          norm_layer=norm_layer))

            return nn.Sequential(*layers)

        def parameter_rrefs(self):
            # Wrap each local parameter in an RRef so it can be handed to DistributedOptimizer
            return [RRef(p) for p in self.parameters()]

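As a quick sanity check (this snippet is not part of the tutorial script, and it
assumes the ``ResNetBase`` definition above has already been executed), you can
build the base and one Bottleneck layer group locally to confirm the expected
output shape:

.. code:: python

    # Hypothetical sanity check: construct the shared base and one layer group,
    # then verify the channel expansion performed by the Bottleneck blocks.
    import torch
    from torchvision.models.resnet import Bottleneck

    base = ResNetBase(Bottleneck, 64)
    layer1 = base._make_layer(64, 3)      # 3 Bottleneck blocks, 64 -> 256 channels
    x = torch.randn(2, 64, 56, 56)        # a fake feature map
    print(layer1(x).shape)                # torch.Size([2, 256, 56, 56])
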
Now we are ready to define the two model shards. For the constructor, we
simply split all ResNet50 layers into two parts and move each part onto the
provided device. The ``forward`` functions of both shards take an ``RRef`` of
the input data, fetch the data locally, and then move it to the expected device.
After applying all layers to the input, each shard moves the output to CPU and
returns it. This is because the RPC API requires tensors to reside on CPU, which
avoids invalid device errors when the numbers of devices on the caller and the
callee do not match.

.. code:: python

    class ResNetShard1(ResNetBase):
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard1, self).__init__(
                Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                self._norm_layer(self.inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                self._make_layer(64, 3),
                self._make_layer(128, 4, stride=2)
            ).to(self.device)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x_rref):
            x = x_rref.to_here().to(self.device)
            # serialize forward passes, as multiple RPC threads may call into this shard
            with self._lock:
                out = self.seq(x)
            return out.cpu()


    class ResNetShard2(ResNetBase):
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard2, self).__init__(
                Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                self._make_layer(256, 6, stride=2),
                self._make_layer(512, 3, stride=2),
                nn.AdaptiveAvgPool2d((1, 1)),
            ).to(self.device)

            self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

        def forward(self, x_rref):
            x = x_rref.to_here().to(self.device)
            with self._lock:
                out = self.fc(torch.flatten(self.seq(x), 1))
            return out.cpu()

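Both ``forward`` functions above rely on ``RRef.to_here()`` to fetch the input
tensor. If that helper is new to you, the following self-contained sketch (a
single-process RPC group with a hypothetical worker name ``"solo"`` and port
``29501``; not part of the tutorial script) shows what it does:

.. code:: python

    # Minimal, hypothetical illustration of RRef.to_here(): in a one-process RPC
    # group, an RRef wrapping a local tensor simply returns that tensor when fetched.
    import os

    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    rpc.init_rpc("solo", rank=0, world_size=1)

    rref = RRef(torch.ones(2, 3))
    print(rref.to_here())   # the wrapped tensor; no copy is needed on the owner

    rpc.shutdown()
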
Step 2: Stitch ResNet50 Model Shards Into One Module
----------------------------------------------------

Then, we create a ``DistResNet50`` module to assemble the two shards and
implement the pipeline parallel logic. In the constructor, we use two
``rpc.remote`` calls to put the two shards on two different RPC workers
respectively, and hold on to the ``RRef`` of each model part so that they
can be referenced in the forward pass. The ``forward`` function
splits the input batch into multiple micro-batches, and feeds these
micro-batches to the two model parts in a pipelined fashion. It first uses an
``rpc.remote`` call to apply the first shard to a micro-batch, and then forwards
the returned intermediate output ``RRef`` to the second model shard. After that,
it collects the ``Future`` of every micro-batch output, and waits for all of them
after the loop. Note that both ``remote()`` and ``rpc_async()`` return immediately
and run asynchronously. Therefore, the entire loop is non-blocking and launches
multiple RPCs concurrently. The execution order of one micro-batch on the two model
parts is preserved by the intermediate output ``y_rref``. The execution order
across micro-batches does not matter. In the end, the ``forward`` function
concatenates the outputs of all micro-batches into a single output tensor and
returns it. The ``parameter_rrefs`` function is a helper that
simplifies distributed optimizer construction, which will be used later.

.. code:: python

    class DistResNet50(nn.Module):
        def __init__(self, num_split, workers, *args, **kwargs):
            super(DistResNet50, self).__init__()

            # number of micro-batches each input batch is partitioned into
            self.num_split = num_split

            # Put the first part of the ResNet50 on workers[0]
            self.p1_rref = rpc.remote(
                workers[0],
                ResNetShard1,
                args = ("cuda:0",) + args,
                kwargs = kwargs
            )

            # Put the second part of the ResNet50 on workers[1]
            self.p2_rref = rpc.remote(
                workers[1],
                ResNetShard2,
                args = ("cuda:1",) + args,
                kwargs = kwargs
            )

        def forward(self, xs):
            out_futures = []
            # partition the batch into num_split micro-batches and pipeline them
            for x in iter(xs.chunk(self.num_split, dim=0)):
                x_rref = RRef(x)
                # run shard1 remotely; returns an RRef to the intermediate activation
                y_rref = self.p1_rref.remote().forward(x_rref)
                # chain shard2 asynchronously on the intermediate RRef
                z_fut = self.p2_rref.rpc_async().forward(y_rref)
                out_futures.append(z_fut)

            # wait for all micro-batch outputs and concatenate them
            return torch.cat(torch.futures.wait_all(out_futures))

        def parameter_rrefs(self):
            remote_params = []
            remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
            remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
            return remote_params

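To make the last line of ``forward`` concrete, here is a tiny standalone sketch
(illustrative values only, not part of the tutorial) of how
``torch.futures.wait_all`` gathers results from several pending futures:

.. code:: python

    # Standalone illustration: wait_all blocks until every future is complete and
    # returns their results in order, which forward() then concatenates.
    import torch

    futs = []
    for i in range(3):
        fut = torch.futures.Future()
        fut.set_result(torch.full((2,), float(i)))   # pretend an RPC just finished
        futs.append(fut)

    print(torch.futures.wait_all(futs))
    # [tensor([0., 0.]), tensor([1., 1.]), tensor([2., 2.])]
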
Step 3: Define The Training Loop
--------------------------------

After defining the model, let us implement the training loop. We use a
dedicated "master" worker to prepare random inputs and labels, and to control the
distributed backward pass and distributed optimizer step. It first creates an
instance of the ``DistResNet50`` module, specifying the number of
micro-batches for each batch and providing the names of the two RPC workers
(i.e., "worker1" and "worker2"). Then it defines the loss function and creates
a ``DistributedOptimizer``, using the ``parameter_rrefs()`` helper to acquire a
list of parameter ``RRefs``. The main training loop is very similar to
regular local training, except that it uses ``dist_autograd`` to launch
the backward pass and provides the ``context_id`` for both the backward pass and
the optimizer ``step()``.

.. code:: python

    import torch.distributed.autograd as dist_autograd
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer

    num_batches = 3
    batch_size = 120
    image_w = 128
    image_h = 128


    def run_master(num_split):
        # put the two model parts on worker1 and worker2 respectively
        model = DistResNet50(num_split, ["worker1", "worker2"])
        loss_fn = nn.MSELoss()
        opt = DistributedOptimizer(
            optim.SGD,
            model.parameter_rrefs(),
            lr=0.05,
        )

        # random class index for each sample, used to build one-hot targets
        one_hot_indices = torch.LongTensor(batch_size) \
                               .random_(0, num_classes) \
                               .view(batch_size, 1)

        for i in range(num_batches):
            print(f"Processing batch {i}")
            # generate random inputs and labels
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                          .scatter_(1, one_hot_indices, 1)

            with dist_autograd.context() as context_id:
                outputs = model(inputs)
                dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
                opt.step(context_id)

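The ``scatter_`` call above builds one-hot targets for ``MSELoss``. If that
pattern is unfamiliar, this small standalone snippet (made-up values, not part of
the tutorial) shows the effect:

.. code:: python

    # Each row of the zero matrix gets a 1 written at the column given by its class index.
    import torch

    indices = torch.tensor([[2], [0]])                 # class ids for a batch of two samples
    one_hot = torch.zeros(2, 4).scatter_(1, indices, 1)
    print(one_hot)
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.]])
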
Step 4: Launch RPC Processes
----------------------------

Finally, the code below shows the target function for all processes. The main
logic is defined in ``run_master``. The workers passively wait for
commands from the master, and hence simply run ``init_rpc`` and ``shutdown``,
where ``shutdown`` by default blocks until all RPC participants finish.

.. code:: python

    import os
    import time

    import torch.multiprocessing as mp


    def run_worker(rank, world_size, num_split):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        # use more RPC threads so concurrent micro-batch RPCs are not serialized
        options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)

        if rank == 0:
            rpc.init_rpc(
                "master",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            run_master(num_split)
        else:
            rpc.init_rpc(
                f"worker{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            # workers do nothing actively; they only serve RPC requests

        # block until all rpcs finish
        rpc.shutdown()


    if __name__ == "__main__":
        world_size = 3
        for num_split in [1, 2, 4, 8]:
            tik = time.time()
            mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
            tok = time.time()
            print(f"number of splits = {num_split}, execution time = {tok - tik}")

The output below shows the speedup attained by increasing the number of splits
in each batch.

::

    $ python main.py
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 1, execution time = 16.45062756538391
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 2, execution time = 12.329529762268066
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 4, execution time = 10.164430618286133
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 8, execution time = 9.076049566268921
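
Because RPC works across machine boundaries, the same model could also be trained
with the master and the two workers on different hosts. Below is a hedged sketch
(not part of the tutorial; the argument names, the ``--master-addr`` flag, and the
``main`` module name are assumptions) that launches one process per node instead
of using ``mp.spawn``, reusing ``run_master`` from Step 3:

.. code:: python

    # Hypothetical per-node launcher: run this script once on every host, passing a
    # unique rank (0 = master, 1 and 2 = workers) and the master node's address.
    import argparse
    import os

    import torch.distributed.rpc as rpc

    from main import run_master   # hypothetical module name for the tutorial script

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--rank", type=int, required=True)
        parser.add_argument("--world-size", type=int, default=3)
        parser.add_argument("--master-addr", type=str, default="localhost")
        parser.add_argument("--num-split", type=int, default=8)
        args = parser.parse_args()

        # every process must agree on the rendezvous address and port
        os.environ["MASTER_ADDR"] = args.master_addr
        os.environ["MASTER_PORT"] = "29500"
        options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)

        name = "master" if args.rank == 0 else f"worker{args.rank}"
        rpc.init_rpc(name, rank=args.rank, world_size=args.world_size,
                     rpc_backend_options=options)
        if args.rank == 0:
            run_master(args.num_split)
        rpc.shutdown()   # blocks until all participants finish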
