
Commit d431037

mrshenli and Shen Li authored
Adding a distributed pipeline parallelism example for RPC (#749)
Co-authored-by: Shen Li <[email protected]>
1 parent 391be73 commit d431037

File tree

distributed/rpc/pipeline/README.md
distributed/rpc/pipeline/main.py
distributed/rpc/pipeline/requirements.txt

3 files changed: +299 -0 lines changed

distributed/rpc/pipeline/README.md

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
Distributed Pipeline Parallel Example

This example shows how to distribute a ResNet50 model on two RPC workers and
then implement distributed pipeline parallelism using RPC.

```
pip install -r requirements.txt
python main.py
```
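The core idea in main.py below is that the master never waits on one micro-batch before launching the next: it splits the input batch, fires stage 1 on one worker, chains stage 2 on the other asynchronously, and only joins the futures at the end. The following self-contained sketch (not part of this commit) mimics that control flow with two thread-backed stages standing in for the two RPC workers; `stage1`, `stage2`, and `pipelined_forward` are illustrative names, and the real example uses `rpc.remote` / `rpc.rpc_async` instead of thread pools.

```python
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn

# toy stand-ins for ResNetPart1 / ResNetPart2 on the two workers
stage1 = nn.Linear(16, 32)
stage2 = nn.Linear(32, 10)
pool1, pool2 = ThreadPoolExecutor(1), ThreadPoolExecutor(1)


def pipelined_forward(xs, split_size):
    futs = []
    for x in xs.split(split_size, dim=0):
        f1 = pool1.submit(stage1, x)                                 # "worker1" runs stage 1
        futs.append(pool2.submit(lambda f=f1: stage2(f.result())))   # "worker2" picks up its output
    # wait only after every micro-batch has been queued, so the stages overlap
    return torch.cat([f.result() for f in futs])


print(pipelined_forward(torch.randn(8, 16), split_size=2).shape)  # torch.Size([8, 10])
```

In main.py the same join happens in `DistResNet50.forward`, which collects the `rpc_async` futures and calls `fut.wait()` on each.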

distributed/rpc/pipeline/main.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
#                   helper functions                    #
#########################################################


def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and return an RRef
    of the result.
    """
    return rpc.remote(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _async_on_rref(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    return rpc.rpc_async(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs


#########################################################
#           Define Model Parallel ResNet50              #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        # serializes concurrent forward calls arriving over RPC
        self._lock = threading.Lock()
        self._block = block
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if stride != 1 or self.inplanes != planes * self._block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * self._block.expansion, stride),
                norm_layer(planes * self._block.expansion),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * self._block.expansion
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)


class ResNetPart1(ResNetBase):
    """
    The first part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart1, self).__init__(
            Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2)
        ).to(self.device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.seq(x)
        return out.cpu()


class ResNetPart2(ResNetBase):
    """
    The second part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart2, self).__init__(
            Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ).to(self.device)

        self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.fc(torch.flatten(self.seq(x), 1))
        return out.cpu()


class DistResNet50(nn.Module):
    """
    Assemble two parts as an nn.Module and define pipelining logic
    """
    def __init__(self, split_size, workers, *args, **kwargs):
        super(DistResNet50, self).__init__()

        self.split_size = split_size

        # Put the first part of the ResNet50 on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetPart1,
            args=("cuda:0",) + args,
            kwargs=kwargs
        )

        # Put the second part of the ResNet50 on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetPart2,
            args=("cuda:1",) + args,
            kwargs=kwargs
        )

    def forward(self, xs):
        # Split the input batch xs into micro-batches, and collect async RPC
        # futures into a list
        out_futures = []
        for x in iter(xs.split(self.split_size, dim=0)):
            x_rref = RRef(x)
            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
            out_futures.append(z_fut)

        # wait for all RPC to finish
        outs = [fut.wait() for fut in out_futures]
        # cat all tensors into one tensor.
        out = torch.cat(outs)
        return out

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
        return remote_params


#########################################################
#                   Run RPC Processes                   #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def run_master(num_split):
    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(num_split, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)


def run_worker(rank, world_size, num_split):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(num_split)
    else:
        rpc.init_rpc(
            f"worker{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # workers just serve RPC requests from the master
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 3
    for num_split in [1, 2, 4, 8]:
        tik = time.time()
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
        tok = time.time()
        print(f"number of splits = {num_split}, execution time = {tok - tik}")
distributed/rpc/pipeline/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
torch==1.5.0
torchvision==0.6.0
