
Commit 955a744

Create main.py
1 parent 4f1c1c5 commit 955a744

File tree

1 file changed: +185 -0 lines changed

advanced_source/rpc_ddp/main.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import random

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.nn.parallel import DistributedDataParallel as DDP

NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16

# BEGIN hybrid_model
class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.
    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
    This remote module can get a Remote Reference to the embedding table on the parameter server.
    """

    def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
        self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))
# END hybrid_model
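# Note: with mode="sum", the remote nn.EmbeddingBag(NUM_EMBEDDINGS, EMBEDDING_DIM)
# returns a (batch_size, 16) tensor, which matches the in_features of the
# DDP-wrapped nn.Linear(16, 8) above; its 8 outputs line up with the targets
# drawn from range(8) in the trainer below.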

# BEGIN setup_trainer
def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradient updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()
# END setup_trainer
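    # DistributedOptimizer creates an optim.SGD instance on every worker that
    # owns one of the given parameter RRefs (the parameter server for the
    # embedding table, this trainer for the DDP-wrapped nn.Linear), and
    # opt.step(context_id) applies the gradients recorded in that distributed
    # autograd context.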

# BEGIN run_trainer
    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target
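    # For nn.EmbeddingBag, `offsets` marks where each bag starts in the flat
    # `indices` tensor, e.g. indices of length 5 with offsets [0, 2] form two
    # bags, indices[0:2] and indices[2:5], so the batch size equals len(offsets).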

    # Train for 100 epochs
    for epoch in range(100):
        # Create a distributed autograd context for each iteration.
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))
# END run_trainer
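# Gradients produced by dist_autograd.backward() are accumulated inside the
# autograd context (see dist_autograd.get_gradients(context_id)) rather than in
# param.grad, which is why the training loop above never has to zero gradients.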

# BEGIN run_worker
def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is the master, rank 3 is the parameter server, and ranks 0 and 1 are the trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )
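        # RemoteModule instantiates the nn.EmbeddingBag on the "ps" worker
        # (CPU by default); forward() calls are forwarded to it over RPC, and
        # remote_parameters() hands back RRefs that the trainers feed to the
        # DistributedOptimizer.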

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )
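        # Note: this gloo group covers only the two trainers (world_size=2) and
        # is what DDP uses to aggregate gradients for the dense nn.Linear; it is
        # separate from the four-process RPC group set up via init_rpc.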

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # Trainers just wait for RPCs from the master.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The parameter server does nothing.
        pass

    # Block until all RPCs finish.
    rpc.shutdown()

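# NOTE: each trainer moves its dense submodule and its targets to CUDA device
# `rank` (0 or 1), so running this example as written assumes a machine with at
# least two GPUs.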
if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
# END run_worker

0 commit comments
