Skip to content

Commit 0adeab1

Browse files
Shen Limrshenli
authored andcommitted
Adding distributed pipeline parallel tutorial
1 parent 43e3eb7 commit 0adeab1

File tree

3 files changed

+375
-0
lines changed

3 files changed

+375
-0
lines changed
34.9 KB
Loading

index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,13 @@ Welcome to PyTorch Tutorials
332332
:link: intermediate/rpc_param_server_tutorial.html
333333
:tags: Parallel-and-Distributed-Training
334334

335+
.. customcarditem::
336+
:header: Distributed Pipeline Parallelism Using RPC
337+
:card_description: Demonstrate how to implement distributed pipeline parallelism using RPC
338+
:image: _static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png
339+
:link: intermediate/dist_pipeline_parallel_tutorial.html
340+
:tags: Parallel-and-Distributed-Training
341+
335342
.. End of tutorial card section
336343
337344
.. raw:: html
@@ -497,3 +504,4 @@ Additional Resources
497504
intermediate/rpc_tutorial
498505
beginner/aws_distributed_training_tutorial
499506
intermediate/rpc_param_server_tutorial
507+
intermediate/dist_pipeline_parallel_tutorial
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
Distributed Pipeline Parallelism Using RPC
2+
==========================================
3+
**Author**: `Shen Li <https://mrshenli.github.io/>`_
4+
5+
Prerequisite:
6+
7+
- `Single-Machine Model Parallel Best Practices <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
8+
- `Getting started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
9+
- RRef helper functions:
10+
`RRef.rpc_sync() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__,
11+
`RRef.rpc_async() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_async>`__, and
12+
`RRef.remote() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.remote>`__
13+
14+
15+
This tutorial uses a Resnet50 model to demonstrate implementing distributed
16+
pipeline parallelism with `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__
17+
APIs. This can be viewed as the distributed counterpart of the multi-GPU
18+
pipeline parallelism discussed in
19+
`Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_.
20+
21+
22+
Basics
23+
------
24+
25+
26+
The previous tutorial, `Getting Started with Distributed RPC Framework <rpc_tutorial.html>`_
27+
shows how to use `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`_
28+
to implement distributed model parallelism for an RNN model. That tutorial uses
29+
one GPU to host the ``EmbeddingTable``, and the provided code works fine.
30+
However, if a model lives on multiple GPUs, it would require some extra steps to
31+
increase the amortized utilization of all GPUs. Pipeline parallelism is one type
32+
of technique that can help in this case.
33+
34+
In this tutorial, we use ``ResNet50`` as an example model which is also used by
35+
the `Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_
36+
tutorial. Similarly, the ``ResNet50`` model is divided into two shards and
37+
the input batch is partitioned into multiple splits and fed iinto the two model
38+
shards in a pipelined fashion. The difference is that, instead of parallelize
39+
the execution using CUDA streams, this tutorial invokes asynchronous RPCs. So,
40+
the solution presented in this tutorial also works across machine boundaries.
41+
The remainder of this tutorial presents the implementation in four steps.
42+
43+
44+
45+
Step 1: Partition ResNet50 Model
46+
--------------------------------
47+
48+
This is the preparation step which implements ``ResNet50`` in two model shards.
49+
The code below is borrowed from the
50+
`ResNet implemention in torchvision <https://github.com/pytorch/vision/blob/7c077f6a986f05383bcb86b535aedb5a63dd5c4b/torchvision/models/resnet.py#L124>`_.
51+
The ``ResNetBase`` module contains the common building blocks and attributes for
52+
the two ResNet shards.
53+
54+
55+
.. code:: python
56+
import threading
57+
58+
import torch
59+
import torch.nn as nn
60+
61+
from torchvision.models.resnet import Bottleneck
62+
63+
num_classes = 1000
64+
65+
66+
def conv1x1(in_planes, out_planes, stride=1):
67+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
68+
69+
70+
class ResNetBase(nn.Module):
71+
def __init__(self, block, inplanes, num_classes=1000,
72+
groups=1, width_per_group=64, norm_layer=None):
73+
super(ResNetBase, self).__init__()
74+
75+
self._lock = threading.Lock()
76+
self._block = block
77+
self._norm_layer = nn.BatchNorm2d
78+
self.inplanes = inplanes
79+
self.dilation = 1
80+
self.groups = groups
81+
self.base_width = width_per_group
82+
83+
def _make_layer(self, planes, blocks, stride=1):
84+
norm_layer = self._norm_layer
85+
downsample = None
86+
previous_dilation = self.dilation
87+
if stride != 1 or self.inplanes != planes * self._block.expansion:
88+
downsample = nn.Sequential(
89+
conv1x1(self.inplanes, planes * self._block.expansion, stride),
90+
norm_layer(planes * self._block.expansion),
91+
)
92+
93+
layers = []
94+
layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
95+
self.base_width, previous_dilation, norm_layer))
96+
self.inplanes = planes * self._block.expansion
97+
for _ in range(1, blocks):
98+
layers.append(self._block(self.inplanes, planes, groups=self.groups,
99+
base_width=self.base_width, dilation=self.dilation,
100+
norm_layer=norm_layer))
101+
102+
return nn.Sequential(*layers)
103+
104+
def parameter_rrefs(self):
105+
return [RRef(p) for p in self.parameters()]
106+
107+
108+
Now, we are ready to define the two model shards. For the constructor, we
109+
simply split all ResNet50 layers into two parts and move each part into the
110+
provided device. The ``forward`` functions of both shards take an ``RRef`` of
111+
the input data, fetch the data locally, and then move it to the expected device.
112+
After applying all layers to the input, it moves the output to CPU and return.
113+
It is because the RPC API requires tensors to reside on CPU to avoid invalid
114+
device errors when the numbers of devices in the caller and the callee do not
115+
match.
116+
117+
118+
.. code:: python
119+
120+
class ResNetShard1(ResNetBase):
121+
def __init__(self, device, *args, **kwargs):
122+
super(ResNetShard1, self).__init__(
123+
Bottleneck, 64, num_classes=num_classes, *args, **kwargs)
124+
125+
self.device = device
126+
self.seq = nn.Sequential(
127+
nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
128+
self._norm_layer(self.inplanes),
129+
nn.ReLU(inplace=True),
130+
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
131+
self._make_layer(64, 3),
132+
self._make_layer(128, 4, stride=2)
133+
).to(self.device)
134+
135+
for m in self.modules():
136+
if isinstance(m, nn.Conv2d):
137+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
138+
elif isinstance(m, nn.BatchNorm2d):
139+
nn.init.constant_(m.weight, 1)
140+
nn.init.constant_(m.bias, 0)
141+
142+
def forward(self, x_rref):
143+
x = x_rref.to_here().to(self.device)
144+
with self._lock:
145+
out = self.seq(x)
146+
return out.cpu()
147+
148+
149+
class ResNetShard2(ResNetBase):
150+
def __init__(self, device, *args, **kwargs):
151+
super(ResNetShard2, self).__init__(
152+
Bottleneck, 512, num_classes=num_classes, *args, **kwargs)
153+
154+
self.device = device
155+
self.seq = nn.Sequential(
156+
self._make_layer(256, 6, stride=2),
157+
self._make_layer(512, 3, stride=2),
158+
nn.AdaptiveAvgPool2d((1, 1)),
159+
).to(self.device)
160+
161+
self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)
162+
163+
def forward(self, x_rref):
164+
x = x_rref.to_here().to(self.device)
165+
with self._lock:
166+
out = self.fc(torch.flatten(self.seq(x), 1))
167+
return out.cpu()
168+
169+
170+
Step 2: Stitch ResNet50 Model Shards Into One Module
171+
----------------------------------------------------
172+
173+
174+
Then, we create an ``DistResNet50`` module to assemble the two shards and
175+
implement the pipeline parallel logic. In the constructor, we use two
176+
``rpc.remote`` calls to put the two shards on two different RPC workers
177+
respectively, and hold on to the ``RRef`` to the two model parts so that they
178+
can be referenced in the forward pass. The ``forward`` function
179+
splits the input batch into multiple micro-batches, and feeds these
180+
micro-batches to the two model parts in a pipelined fashion. It first uses an
181+
``rpc.remote`` call to apply the first shard to a micro-batch and then forwards
182+
the returned intermediate output ``RRef`` to the second model shard. After that,
183+
it collects the ``Future`` of all micro-outputs, and waits for all of them after
184+
the loop. Note that both ``remote()`` and ``rpc_async()`` return immediately and
185+
run asynchronously. Therefore, the entire loop is non-blocking, and will launch
186+
multiple RPCs concurrently. The execution order of one micro-batch on two model
187+
parts are preserved by intermediate output ``y_rref``. The execution order
188+
across micro-batches does not matter. In the end, the forward function
189+
concatenates outputs of all micro-batches into one single output tensor and
190+
returns. The ``parameter_rrefs`` function is a helper to
191+
simplify distributed optimizer construction, which will be used later.
192+
193+
194+
195+
.. code:: python
196+
197+
class DistResNet50(nn.Module):
198+
"""
199+
Assemble two parts as an nn.Module and define pipelining logic
200+
"""
201+
def __init__(self, num_split, workers, *args, **kwargs):
202+
super(DistResNet50, self).__init__()
203+
204+
self.num_split = num_split
205+
206+
# Put the first part of the ResNet50 on workers[0]
207+
self.p1_rref = rpc.remote(
208+
workers[0],
209+
ResNetShard1,
210+
args = ("cuda:0",) + args,
211+
kwargs = kwargs
212+
)
213+
214+
# Put the second part of the ResNet50 on workers[1]
215+
self.p2_rref = rpc.remote(
216+
workers[1],
217+
ResNetShard2,
218+
args = ("cuda:1",) + args,
219+
kwargs = kwargs
220+
)
221+
222+
def forward(self, xs):
223+
out_futures = []
224+
for x in iter(xs.split(self.split_size, dim=0)):
225+
x_rref = RRef(x)
226+
y_rref = self.p1_rref.remote().forward(x_rref)
227+
z_fut = self.p2_rref.rpc_async().forward(y_rref)
228+
out_futures.append(z_fut)
229+
230+
return torch.cat(torch.futures.wait_all(out_futures))
231+
232+
def parameter_rrefs(self):
233+
remote_params = []
234+
remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
235+
remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
236+
return remote_params
237+
238+
239+
Step 3: Define The Training Loop
240+
--------------------------------
241+
242+
243+
After defining the model, let us implement the training loop. We use a
244+
dedicated master worker to prepare random inputs and labels, and control the
245+
distributed backward pass and distributed optimizer step. It first creates an
246+
instance of the ``DistResNet50`` module. It specifies the number of
247+
micro-batches for each batch, and also provides the name of the two RPC workers
248+
(i.e., "worker1", and "worker2"). Then it defines the loss function and creates
249+
a ``DistributedOptimizer`` using the ``parameter_rrefs()`` helper to acquire a
250+
list of parameter ``RRefs``. Then, the main training loop is very similar to
251+
regular local trainings, except that it uses ``dist_autograd`` to launch
252+
backward and provides the ``context_id`` for both backward and optimizer step.
253+
254+
255+
.. code:: python
256+
257+
import torch.distributed.autograd as dist_autograd
258+
import torch.optim as optim
259+
from torch.distributed.optim import DistributedOptimizer
260+
261+
num_batches = 3
262+
batch_size = 120
263+
image_w = 128
264+
image_h = 128
265+
266+
267+
def run_master(num_split):
268+
# put the two model parts on worker1 and worker2 respectively
269+
model = DistResNet50(num_split, ["worker1", "worker2"])
270+
loss_fn = nn.MSELoss()
271+
opt = DistributedOptimizer(
272+
optim.SGD,
273+
model.parameter_rrefs(),
274+
lr=0.05,
275+
)
276+
277+
one_hot_indices = torch.LongTensor(batch_size) \
278+
.random_(0, num_classes) \
279+
.view(batch_size, 1)
280+
281+
for i in range(num_batches):
282+
print(f"Processing batch {i}")
283+
# generate random inputs and labels
284+
inputs = torch.randn(batch_size, 3, image_w, image_h)
285+
labels = torch.zeros(batch_size, num_classes) \
286+
.scatter_(1, one_hot_indices, 1)
287+
288+
with dist_autograd.context() as context_id:
289+
outputs = model(inputs)
290+
dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
291+
opt.step(context_id)
292+
293+
294+
Step 4: Launch RPC Processes
295+
----------------------------
296+
297+
298+
Finally, the code below shows the target function for all processes. The main
299+
logic is all defined in ``run_master``. The workers passively waiting for
300+
commands from the master, and hence simply runs ``init_rpc`` and ``shutdwown``,
301+
where the ``shutdown`` by default will block until all RPC participants finish.
302+
303+
.. code:: python
304+
305+
import os
306+
import time
307+
308+
import torch.multiprocessing as mp
309+
310+
311+
def run_worker(rank, world_size, num_split):
312+
os.environ['MASTER_ADDR'] = 'localhost'
313+
os.environ['MASTER_PORT'] = '29500'
314+
options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)
315+
316+
if rank == 0:
317+
rpc.init_rpc(
318+
"master",
319+
rank=rank,
320+
world_size=world_size,
321+
rpc_backend_options=options
322+
)
323+
run_master(num_split)
324+
else:
325+
rpc.init_rpc(
326+
f"worker{rank}",
327+
rank=rank,
328+
world_size=world_size,
329+
rpc_backend_options=options
330+
)
331+
pass
332+
333+
# block until all rpcs finish
334+
rpc.shutdown()
335+
336+
337+
if __name__=="__main__":
338+
world_size = 3
339+
for num_split in [1, 2, 4, 8]:
340+
tik = time.time()
341+
mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
342+
tok = time.time()
343+
print(f"number of splits = {num_split}, execution time = {tok - tik}")
344+
345+
346+
The output below shows the speedup attained by increasing the number of splits
347+
in each batch.
348+
349+
::
350+
351+
$ python main.py
352+
Processing batch 0
353+
Processing batch 1
354+
Processing batch 2
355+
number of splits = 1, execution time = 16.45062756538391
356+
Processing batch 0
357+
Processing batch 1
358+
Processing batch 2
359+
number of splits = 2, execution time = 12.329529762268066
360+
Processing batch 0
361+
Processing batch 1
362+
Processing batch 2
363+
number of splits = 4, execution time = 10.164430618286133
364+
Processing batch 0
365+
Processing batch 1
366+
Processing batch 2
367+
number of splits = 8, execution time = 9.076049566268921

0 commit comments

Comments
 (0)