Skip to content

Commit 8df8e74

Browse files
Example for combining DDP + RPC (#800)
* Example for combining DDP + RPC Summary: The example includes a simple model consisting of a sparse part and a dense part. The sparse part is an nn.EmbeddingBag stored on a parameter server and the dense part is an nn.Linear module residing on the trainers. The dense part on the trainers are replicated via DistributedDataParallel. A master creates the nn.EmbeddingBag and drives the training loop on the trainers. The training loop performs an embedding lookup via the Distributed RPC Framework and then executes the local dense component. Test Plan: Reviewers: Subscribers: Tasks: Tags: * Address review comments. Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Address more comments. Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Co-authored-by: pritam <[email protected]>
1 parent e7870c1 commit 8df8e74

File tree

3 files changed

+202
-0
lines changed

3 files changed

+202
-0
lines changed

distributed/rpc/ddp_rpc/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
Distributed DataParallel + Distributed RPC Framework Example
2+
3+
The example shows how to combine Distributed DataParallel with the Distributed
4+
RPC Framework. There are two trainer nodes, one master node and one parameter
5+
server in the example.
6+
7+
The master node creates an embedding table on the parameter server and drives
8+
the training loop on the trainers. The model consists of a dense part
9+
(nn.Linear) replicated on the trainers via Distributed DataParallel and a
10+
sparse part (nn.EmbeddingBag) which resides on the parameter server. Each
11+
trainer performs an embedding lookup on the parameter server (using the
12+
Distributed RPC Framework) and then executes its local nn.Linear module.
13+
During the backward pass, the gradients for the dense part are aggregated via
14+
allreduce by DDP and the distributed backward pass updates the parameters for
15+
the embedding table on the parameter server.
16+
17+
18+
```
19+
pip install -r requirements.txt
20+
python main.py
21+
```

distributed/rpc/ddp_rpc/main.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from functools import wraps
2+
import os
3+
import random
4+
5+
import torch
6+
import torch.distributed as dist
7+
import torch.distributed.autograd as dist_autograd
8+
from torch.distributed.optim import DistributedOptimizer
9+
import torch.distributed.rpc as rpc
10+
from torch.distributed.rpc import RRef
11+
from torch.distributed.rpc import ProcessGroupRpcBackendOptions
12+
import torch.multiprocessing as mp
13+
from torch.nn.parallel import DistributedDataParallel as DDP
14+
import torch.optim as optim
15+
16+
NUM_EMBEDDINGS = 100
17+
EMBEDDING_DIM = 16
18+
19+
class HybridModel(torch.nn.Module):
    r"""
    Hybrid sparse/dense model.

    The dense component is an nn.Linear wrapped in DistributedDataParallel,
    so its gradients are allreduced across all trainers. The sparse
    component is an nn.EmbeddingBag living on the parameter server; this
    module only holds a Remote Reference (RRef) to it.
    """

    def __init__(self, emb_rref, device):
        super().__init__()
        # Remote handle to the embedding table hosted on the parameter server.
        self.emb_rref = emb_rref
        # Local dense layer, replicated across trainers via DDP on this GPU.
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        # Synchronous RPC: run the embedding lookup where the table lives,
        # then move the result to this trainer's device for the dense layer.
        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
        dense_input = emb_lookup.cuda(self.device)
        return self.fc(dense_input)
39+
40+
def _retrieve_embedding_parameters(emb_rref):
    """Return an RRef for each parameter of the embedding table.

    Runs on the parameter server, where ``emb_rref`` is local, so the
    parameters can be wrapped for use by a DistributedOptimizer.
    """
    local_params = emb_rref.local_value().parameters()
    return list(map(RRef, local_params))
42+
43+
44+
def _run_trainer(emb_rref, rank):
    r"""
    Training loop executed on each trainer.

    Forward pass: an embedding lookup on the parameter server (via RPC)
    followed by the local nn.Linear. Backward pass: DDP aggregates the
    gradients of the dense part while distributed autograd propagates
    gradients for the embedding table back to the parameter server.
    """
    # Build the hybrid model; ``rank`` doubles as this trainer's CUDA device.
    model = HybridModel(emb_rref, rank)

    # Gather RRefs to every parameter the DistributedOptimizer must update:
    # first the remote embedding-table parameters, fetched from the ps...
    model_parameter_rrefs = rpc.rpc_sync(
        "ps", _retrieve_embedding_parameters, args=(emb_rref,))
    # ...then the local dense parameters (model.parameters() is local-only).
    model_parameter_rrefs.extend(RRef(p) for p in model.parameters())

    # One optimizer spanning both local and remote parameters.
    opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch(rank):
        """Yield 10 random (indices, offsets, target) batches."""
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Random bag boundaries for EmbeddingBag: each offset starts a
            # new bag of 1-10 indices.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            # One class label in [0, 8) per bag, on this trainer's GPU.
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # 100 epochs of 10 batches each.
    for epoch in range(100):
        for indices, offsets, target in get_next_batch(rank):
            # Each iteration uses a fresh distributed autograd context, so
            # gradients never need to be zeroed between iterations.
            with dist_autograd.context() as context_id:
                loss = criterion(model(indices, offsets), target)

                # Distributed backward pass...
                dist_autograd.backward(context_id, [loss])

                # ...then the distributed optimizer step.
                opt.step(context_id)
        print("Training done for epoch {}".format(epoch))
110+
111+
112+
def run_worker(rank, world_size):
    r"""
    Per-process entry point: initializes RPC for this rank's role, runs the
    role's work, and shuts RPC down.

    Rank layout: ranks 0 and 1 are trainers, rank 2 is the master and rank 3
    is the parameter server.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = ProcessGroupRpcBackendOptions()
    rpc_backend_options.init_method = 'tcp://localhost:29501'

    if rank == 2:
        # Master: owns the training loop.
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Build the embedding table on the ps.
        emb_rref = rpc.remote(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"})

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            # BUG FIX: pass trainer_rank (0 or 1), not the master's own rank
            # (2). The trainer uses this value as its CUDA device index and
            # DDP device id, so passing the master's rank would place both
            # trainers' models on device 2.
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(emb_rref, trainer_rank))
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Trainer: initialize the process group for DistributedDataParallel
        # (world of 2 trainers) on a port distinct from the RPC one.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2,
            init_method='tcp://localhost:29500')

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Trainer just waits for RPCs from master.
    else:
        # Parameter server: only hosts the embedding table; all work is
        # driven remotely via RPC.
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

    # block until all rpcs finish
    rpc.shutdown()
175+
176+
177+
if __name__ == "__main__":
    # 4 processes total: 2 trainers (ranks 0-1), 1 master (rank 2) and
    # 1 parameter server (rank 3).
    world_size = 4
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torch>=1.6.0

0 commit comments

Comments
 (0)