Commit 5ce3240

Update pipeline parallel example to use RRef helpers
1 parent 13acec6 commit 5ce3240

3 files changed: 45 additions & 77 deletions

distributed/rpc/pipeline/README.md

Lines changed: 7 additions & 1 deletion
@@ -1,7 +1,13 @@
 Distributed Pipeline Parallel Example
 
 This example shows how to distribute a ResNet50 model on two RPC workers and
-then implement distributed pipeline parallelism using RPC.
+then implement distributed pipeline parallelism using RPC. With pipeline
+parallelism, every input batch is divided into micro-batches and these
+micro-batches are fed into the model in a pipelined fashion to increase the
+amortized device utilization. Note that this example only parallelizes the
+forward pass, which can be viewed as the distributed counterpart of the
+[single machine pipeline parallel](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html#speed-up-by-pipelining-inputs)
+example.
 
 ```
 pip install -r requirements.txt
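
The new README paragraph describes how an input batch is divided into micro-batches that are fed through the two shards in a pipelined fashion. As a minimal sketch of that mechanism (the tensor sizes here are hypothetical and not part of the commit), this is the `Tensor.split` call that the forward pass in main.py relies on with its `split_size` argument:

```
import torch

# Hypothetical batch: 120 images split into micro-batches of 20. The pipeline
# pushes one micro-batch at a time so both model shards can stay busy at once.
batch = torch.randn(120, 3, 128, 128)
split_size = 20
micro_batches = batch.split(split_size, dim=0)
assert len(micro_batches) == 6
assert micro_batches[0].shape == (split_size, 3, 128, 128)
```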

distributed/rpc/pipeline/main.py

Lines changed: 36 additions & 74 deletions
@@ -15,59 +15,15 @@
 from torchvision.models.resnet import Bottleneck
 
 
-#########################################################
-#                   helper functions                    #
-#########################################################
-
-
-def _call_method(method, rref, *args, **kwargs):
-    r"""
-    a helper function to call a method on the given RRef
-    """
-    return method(rref.local_value(), *args, **kwargs)
-
-
-def _remote_on_rref(method, rref, *args, **kwargs):
-    r"""
-    a helper function to run method on the owner of rref and return an RRef
-    of the result.
-    """
-    return rpc.remote(
-        rref.owner(),
-        _call_method,
-        args=[method, rref] + list(args),
-        kwargs=kwargs
-    )
-
-
-def _async_on_rref(method, rref, *args, **kwargs):
-    r"""
-    a helper function to run method on the owner of rref and fetch back the
-    result using RPC
-    """
-    return rpc.rpc_async(
-        rref.owner(),
-        _call_method,
-        args=[method, rref] + list(args),
-        kwargs=kwargs
-    )
-
-
-def _parameter_rrefs(module):
-    r"""
-    Create one RRef for each parameter in the given local module, and return a
-    list of RRefs.
-    """
-    param_rrefs = []
-    for param in module.parameters():
-        param_rrefs.append(RRef(param))
-    return param_rrefs
-
-
 #########################################################
 #           Define Model Parallel ResNet50              #
 #########################################################
 
+# In order to split the ResNet50 and place it on two different workers, we
+# implement it in two model shards. The ResNetBase class defines common
+# attributes and methods shared by two shards. ResNetShard1 and ResNetShard2
+# contain two partitions of the model layers respectively.
+
 
 num_classes = 1000
 
@@ -76,9 +32,8 @@ def conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 
-
 class ResNetBase(nn.Module):
-    def __init__(self, block, inplanes, num_classes=1000,
+    def __init__(self, block, inplanes, num_classes=1000,
                  groups=1, width_per_group=64, norm_layer=None):
         super(ResNetBase, self).__init__()
 
@@ -111,13 +66,20 @@ def _make_layer(self, planes, blocks, stride=1):
 
         return nn.Sequential(*layers)
 
+    def parameter_rrefs(self):
+        r"""
+        Create one RRef for each parameter in the given local module, and return a
+        list of RRefs.
+        """
+        return [RRef(p) for p in self.parameters()]
+
 
-class ResNetPart1(ResNetBase):
+class ResNetShard1(ResNetBase):
     """
     The first part of ResNet.
     """
     def __init__(self, device, *args, **kwargs):
-        super(ResNetPart1, self).__init__(
+        super(ResNetShard1, self).__init__(
             Bottleneck, 64, num_classes=num_classes, *args, **kwargs)
 
         self.device = device
@@ -144,12 +106,12 @@ def forward(self, x_rref):
         return out.cpu()
 
 
-class ResNetPart2(ResNetBase):
+class ResNetShard2(ResNetBase):
     """
     The second part of ResNet.
     """
     def __init__(self, device, *args, **kwargs):
-        super(ResNetPart2, self).__init__(
+        super(ResNetShard2, self).__init__(
             Bottleneck, 512, num_classes=num_classes, *args, **kwargs)
 
         self.device = device
@@ -180,15 +142,15 @@ def __init__(self, split_size, workers, *args, **kwargs):
         # Put the first part of the ResNet50 on workers[0]
         self.p1_rref = rpc.remote(
             workers[0],
-            ResNetPart1,
+            ResNetShard1,
             args = ("cuda:0",) + args,
             kwargs = kwargs
         )
 
         # Put the second part of the ResNet50 on workers[1]
         self.p2_rref = rpc.remote(
             workers[1],
-            ResNetPart2,
+            ResNetShard2,
             args = ("cuda:1",) + args,
             kwargs = kwargs
         )
@@ -199,22 +161,19 @@ def forward(self, xs):
         out_futures = []
         for x in iter(xs.split(self.split_size, dim=0)):
             x_rref = RRef(x)
-            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
-            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
+            y_rref = self.p1_rref.remote().forward(x_rref)
+            z_fut = self.p2_rref.rpc_async().forward(y_rref)
             out_futures.append(z_fut)
 
-        # wait for all RPC to finish
-        outs = [fut.wait() for fut in out_futures]
-        # cat all tensors into one tensor.
-        out = torch.cat(outs)
-        return out
-
+        # collect and cat all output tensors into one tensor.
+        return torch.cat(torch.futures.wait_all(out_futures))
+
     def parameter_rrefs(self):
         remote_params = []
-        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
-        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
+        remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
+        remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
         return remote_params
-
+
 
 #########################################################
@@ -248,6 +207,9 @@ def run_master(split_size):
         labels = torch.zeros(batch_size, num_classes) \
                  .scatter_(1, one_hot_indices, 1)
 
+        # The distributed autograd context is the dedicated scope for the
+        # distributed backward pass to store gradients, which can later be
+        # retrieved using the context_id by the distributed optimizer.
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
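
The added comment explains that gradients from the distributed backward pass live in a per-iteration autograd context rather than in `param.grad`. As a hedged sketch of that pattern (assuming `model`, `loss_fn`, `inputs`, and `labels` as defined elsewhere in `run_master`; the real example creates the optimizer once outside the training loop), the same `context_id` ties the backward pass and the optimizer step together:

```
import torch.distributed.autograd as dist_autograd
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer

def train_step(model, loss_fn, inputs, labels):
    # The optimizer holds RRefs to parameters that live on the remote workers.
    opt = DistributedOptimizer(optim.SGD, model.parameter_rrefs(), lr=0.05)

    with dist_autograd.context() as context_id:
        outputs = model(inputs)
        # Gradients are accumulated inside this context, keyed by context_id.
        dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
        # The optimizer looks those gradients up using the same context_id.
        opt.step(context_id)
```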
@@ -261,17 +223,17 @@ def run_worker(rank, world_size, num_split):
 
     if rank == 0:
         rpc.init_rpc(
-            "master",
-            rank=rank,
-            world_size=world_size,
+            "master",
+            rank=rank,
+            world_size=world_size,
             rpc_backend_options=options
         )
         run_master(num_split)
     else:
         rpc.init_rpc(
-            f"worker{rank}",
-            rank=rank,
-            world_size=world_size,
+            f"worker{rank}",
+            rank=rank,
+            world_size=world_size,
             rpc_backend_options=options
         )
         pass
distributed/rpc/pipeline/requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -1,2 +1,2 @@
-torch==1.5.0
-torchvision==0.6.0
+torch==1.6.0
+torchvision==0.7.0
