Skip to content

Commit b1cefec

Browse files
committed
Tutorial for DDP + RPC.
Summary: Based on example from pytorch/examples#800
1 parent 68c22a0 commit b1cefec

File tree

1 file changed

+308
-0
lines changed

1 file changed

+308
-0
lines changed
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
Combining Distributed DataParallel with Distributed RPC Framework
2+
=================================================================
3+
**Author**: `Pritam Damania <https://github.com/pritamdamania87>`_
4+
5+
6+
This tutorial uses a simple example to demonstrate how you can combine
7+
`DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
8+
with the `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
9+
to combine distributed data parallelism with distributed model parallelism to
10+
train a simple model. Source code of the example can be found `here <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__.
11+
12+
Previous tutorials,
13+
`Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
14+
and `Getting Started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__,
15+
described how to perform distributed data parallel and distributed model
16+
parallel training respectively. Although, there are several training paradigms
17+
where you might want to combine these two techniques. For example:
18+
19+
1) If we have a model with a sparse part (large embedding table) and a dense
20+
part (FC layers), we might want to put the embedding table on a parameter
21+
server and replicate the FC layer across multiple trainers using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
22+
The `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
23+
can be used to perform embedding lookups on the parameter server.
24+
2) Enable hybrid parallelism as described in the `PipeDream <https://arxiv.org/abs/1806.03377>`__ paper.
25+
We can use the `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
26+
to pipeline stages of the model across multiple workers and replicate each
27+
stage (if needed) using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
28+
29+
|
30+
In this tutorial we will cover case 1 mentioned above. We have a total of 4
31+
workers in our setup as follows:
32+
33+
34+
1) 1 Master, which is responsible for creating an embedding table
35+
(nn.EmbeddingBag) on the parameter server. The master also drives the
36+
training loop on the two trainers.
37+
2) 1 Parameter Server, which basically holds the embedding table in memory and
38+
responds to RPCs from the Master and Trainers.
39+
3) 2 Trainers, which store a FC layer (nn.Linear) which is replicated amongst
40+
themselves using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
41+
The trainers are also responsible for executing the forward pass, backward
42+
pass and optimizer step.
43+
44+
|
45+
The entire training process is executed as follows:
46+
47+
1) The master creates an embedding table on the Parameter Server and holds a
48+
`RRef <https://pytorch.org/docs/master/rpc.html#rref>`__ to it.
49+
2) The master, then kicks of the training loop on the trainers and passes the
50+
embedding table RRef to the trainers.
51+
3) The trainers create a ``HybridModel`` which first performs an embedding lookup
52+
using the embedding table RRef provided by the master and then executes the
53+
FC layer which is wrapped inside DDP.
54+
4) The trainer executes the forward pass of the model and uses the loss to
55+
execute the backward pass using `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__.
56+
5) As part of the backward pass, the gradients for the FC layer are computed
57+
first and synced to all trainers via allreduce in DDP.
58+
6) Next, Distributed Autograd propagates the gradients to the parameter server,
59+
where the gradients for the embedding table are updated.
60+
7) Finally, the `Distributed Optimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__ is used to update all the parameters.
61+
62+
63+
|
64+
**NOTE**: You should always use `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__ for the backward pass if you're combining DDP and RPC.
65+
66+
67+
Now, lets go through each part in detail. Firstly, we need to setup all of our
68+
workers before we can perform any training. We create 4 processes such that
69+
ranks 0 and 1 are our trainers, rank 2 is the master and rank 3 is the
70+
parameter server.
71+
72+
We initialize the RPC framework on all 4 workers using the TCP init_method.
73+
Once RPC initialization is done, the master creates an `EmbeddingBag <https://pytorch.org/docs/master/generated/torch.nn.EmbeddingBag.html>`__
74+
on the Parameter Server using `rpc.remote <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.remote>`__.
75+
The master then loops through each trainer and kicks of the training loop by
76+
calling ``_run_trainer`` on each trainer using `rpc_async <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.rpc_async>`__.
77+
Finally, the master waits for all training to finish before exiting.
78+
79+
The trainers first initialize a ``ProcessGroup`` for DDP with world_size=2
80+
(for two trainers) using `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
81+
Next, they initialize the RPC framework using the TCP init_method. Note that
82+
the ports are different in RPC initialization and ProcessGroup intialization.
83+
This is to avoid port conflicts between initialization of both frameworks.
84+
Once the initialization is done, the trainers just wait for the ``_run_trainer``
85+
RPC from the master.
86+
87+
The parameter server just initializes the RPC framework and waits for RPCs from
88+
the trainers and master.
89+
90+
.. code:: python
91+
92+
def run_worker(rank, world_size):
93+
r"""
94+
A wrapper function that initializes RPC, calls the function, and shuts down
95+
RPC.
96+
"""
97+
98+
# We need to use different port numbers in TCP init_method for init_rpc and
99+
# init_process_group to avoid port conflicts.
100+
rpc_backend_options = ProcessGroupRpcBackendOptions()
101+
rpc_backend_options.init_method='tcp://localhost:29501'
102+
103+
# Rank 2 is master, 3 is ps and 0 and 1 are trainers.
104+
if rank == 2:
105+
rpc.init_rpc(
106+
"master",
107+
rank=rank,
108+
world_size=world_size,
109+
rpc_backend_options=rpc_backend_options)
110+
111+
# Build the embedding table on the ps.
112+
emb_rref = rpc.remote(
113+
"ps",
114+
torch.nn.EmbeddingBag,
115+
args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
116+
kwargs={"mode": "sum"})
117+
118+
# Run the training loop on trainers.
119+
futs = []
120+
for trainer_rank in [0, 1]:
121+
trainer_name = "trainer{}".format(trainer_rank)
122+
fut = rpc.rpc_async(
123+
trainer_name, _run_trainer, args=(emb_rref, rank))
124+
futs.append(fut)
125+
126+
# Wait for all training to finish.
127+
for fut in futs:
128+
fut.wait()
129+
elif rank <= 1:
130+
# Initialize process group for Distributed DataParallel on trainers.
131+
dist.init_process_group(
132+
backend="gloo", rank=rank, world_size=2,
133+
init_method='tcp://localhost:29500')
134+
135+
# Initialize RPC.
136+
trainer_name = "trainer{}".format(rank)
137+
rpc.init_rpc(
138+
trainer_name,
139+
rank=rank,
140+
world_size=world_size,
141+
rpc_backend_options=rpc_backend_options)
142+
143+
# Trainer just waits for RPCs from master.
144+
else:
145+
rpc.init_rpc(
146+
"ps",
147+
rank=rank,
148+
world_size=world_size,
149+
rpc_backend_options=rpc_backend_options)
150+
# parameter server do nothing
151+
pass
152+
153+
# block until all rpcs finish
154+
rpc.shutdown()
155+
156+
157+
if __name__=="__main__":
158+
# 2 trainers, 1 parameter server, 1 master.
159+
world_size = 4
160+
mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
161+
162+
Before we discuss details of the Trainer, lets introduce the ``HybridModel`` that
163+
the trainer uses. As described below, the ``HybridModel`` is initialized using an
164+
RRef to the embedding table (emb_rref) on the parameter server and the ``device``
165+
to use for DDP. The initialization of the model wraps a
166+
`nn.Linear <https://pytorch.org/docs/master/generated/torch.nn.Linear.html>`__
167+
layer inside DDP to replicate and synchronize this layer across all trainers.
168+
169+
The forward method of the model is pretty straightforward. It performs an
170+
embedding lookup on the parameter server using an
171+
`RRef helper <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__
172+
and passes its output onto the FC layer.
173+
174+
175+
.. code:: python
176+
177+
class HybridModel(torch.nn.Module):
178+
r"""
179+
The model consists of a sparse part and a dense part. The dense part is an
180+
nn.Linear module that is replicated across all trainers using
181+
DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
182+
stored on the parameter server.
183+
184+
The model holds a Remote Reference to the embedding table on the parameter
185+
server.
186+
"""
187+
188+
def __init__(self, emb_rref, device):
189+
super(HybridModel, self).__init__()
190+
self.emb_rref = emb_rref
191+
self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
192+
self.device = device
193+
194+
def forward(self, indices, offsets):
195+
emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
196+
return self.fc(emb_lookup.cuda(self.device))
197+
198+
Next, lets look at the setup on the Trainer. The trainer first creates the
199+
``HybridModel`` described above using an RRef to the embedding table on the
200+
parameter server and its own rank.
201+
202+
Now, we need to retrieve a list of RRefs to all the parameters that we would
203+
like to optimize with `DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__.
204+
To retrieve the parameters for the embedding table from the parameter server,
205+
we define a simple helper function ``_retrieve_embedding_parameters``, which
206+
basically walks through all the parameters for the embedding table and returns
207+
a list of RRefs. The trainer calls this method on the parameter server via RPC
208+
to receive a list of RRefs to the desired parameters. Since the
209+
DistributedOptimizer always takes a list of RRefs to parameters that need to
210+
be optimized, we need to create RRefs even for the local parameters for our
211+
FC layers. This is done by walking ``model.parameters()``, creating an RRef for
212+
each parameter and appending it to a list. Note that ``model.parameters()`` only
213+
returns local parameters and doesn't include ``emb_rref``.
214+
215+
Finally, we create our DistributedOptimizer using all the RRefs and define a
216+
CrossEntropyLoss function.
217+
218+
.. code:: python
219+
220+
def _retrieve_embedding_parameters(emb_rref):
221+
return [RRef(p) for p in emb_rref.local_value().parameters()]
222+
223+
224+
def _run_trainer(emb_rref, rank):
225+
r"""
226+
Each trainer runs a forward pass which involves an embedding lookup on the
227+
parameter server and running nn.Linear locally. During the backward pass,
228+
DDP is responsible for aggregating the gradients for the dense part
229+
(nn.Linear) and distributed autograd ensures gradients updates are
230+
propagated to the parameter server.
231+
"""
232+
233+
# Setup the model.
234+
model = HybridModel(emb_rref, rank)
235+
236+
# Retrieve all model parameters as rrefs for DistributedOptimizer.
237+
238+
# Retrieve parameters for embedding table.
239+
model_parameter_rrefs = rpc.rpc_sync(
240+
"ps", _retrieve_embedding_parameters, args=(emb_rref,))
241+
242+
# model.parameters() only includes local parameters.
243+
for param in model.parameters():
244+
model_parameter_rrefs.append(RRef(param))
245+
246+
# Setup distributed optimizer
247+
opt = DistributedOptimizer(
248+
optim.SGD,
249+
model_parameter_rrefs,
250+
lr=0.05,
251+
)
252+
253+
criterion = torch.nn.CrossEntropyLoss()
254+
255+
Now we're ready to introduce the main training loop that is run on each trainer.
256+
``get_next_batch`` is just a helper function to generate random inputs and
257+
targets for training. We run the training loop for multiple epochs and for each
258+
batch:
259+
260+
1) Setup a `Distributed Autograd Context <https://pytorch.org/docs/master/rpc.html#torch.distributed.autograd.context>`__
261+
for Distributed Autograd.
262+
2) Run the forward pass of the model and retrieve its output.
263+
3) Compute the loss based on our outputs and targets using the loss function.
264+
4) Use Distributed Autograd to execute a distributed backward pass using the loss.
265+
5) Finally, run a Distributed Optimizer step to optimize all the parameters.
266+
267+
.. code:: python
268+
269+
# def _run_trainer(emb_rref, rank): continued...
270+
271+
def get_next_batch(rank):
272+
for _ in range(10):
273+
num_indices = random.randint(20, 50)
274+
indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)
275+
276+
# Generate offsets.
277+
offsets = []
278+
start = 0
279+
batch_size = 0
280+
while start < num_indices:
281+
offsets.append(start)
282+
start += random.randint(1, 10)
283+
batch_size += 1
284+
285+
offsets_tensor = torch.LongTensor(offsets)
286+
target = torch.LongTensor(batch_size).random_(8).cuda(rank)
287+
yield indices, offsets_tensor, target
288+
289+
# Train for 100 epochs
290+
for epoch in range(100):
291+
# create distributed autograd context
292+
for indices, offsets, target in get_next_batch(rank):
293+
with dist_autograd.context() as context_id:
294+
output = model(indices, offsets)
295+
loss = criterion(output, target)
296+
297+
# Run distributed backward pass
298+
dist_autograd.backward(context_id, [loss])
299+
300+
# Tun distributed optimizer
301+
opt.step(context_id)
302+
303+
# Not necessary to zero grads as each iteration creates a different
304+
# distributed autograd context which hosts different grads
305+
print("Training done for epoch {}".format(epoch))
306+
307+
|
308+
Source code for the entire example can be found `here <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__.

0 commit comments

Comments
 (0)