
Commit 8a5b379

Author: Jessica Lin
Merge pull request #705 from rohan-varma/add_param_server
Simple example to demonstrate parameter server training pattern
2 parents 4902431 + 6d29c3b commit 8a5b379

2 files changed: +320 -0 lines changed
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
### RPC-based distributed training

This is a basic example of RPC-based training that uses several trainers to remotely train a model hosted on a parameter server.

To run the example locally, run the following command for the server and for each worker you wish to spawn, in separate terminal windows:

`python rpc_parameter_server.py [world_size] [rank] [num_gpus]`. For example, for a master node with a world size of 2, the command would be `python rpc_parameter_server.py 2 0 0`. The trainer can then be launched with the command `python rpc_parameter_server.py 2 1 0` in a separate window, which will begin training with one server and a single trainer.
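As a further illustration (not covered in the text above), a hypothetical run with one parameter server and two trainers (world size 3, CPU only) would be launched with one command per terminal window:

```sh
# Terminal 1: the parameter server is always rank 0
python rpc_parameter_server.py 3 0 0
# Terminal 2: first trainer (rank 1)
python rpc_parameter_server.py 3 1 0
# Terminal 3: second trainer (rank 2)
python rpc_parameter_server.py 3 2 0
```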
Note that for demonstration purposes, this example supports between 0 and 2 GPUs, although the pattern can be extended to make use of additional GPUs.

You can pass in the command line arguments `--master_addr=<address>` and `--master_port=<port>` to indicate the address and port that the master worker is listening on. All workers will contact the master for rendezvous during worker discovery. By default, `master_addr` will be `localhost` and `master_port` will be 29500.
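As a minimal sketch of how these flags take effect (this mirrors the launcher in `rpc_parameter_server.py` below; the rank, world size, and worker name values here are illustrative only), every process exports the rendezvous address and port as environment variables before initializing RPC:

```python
import os
import torch.distributed.rpc as rpc

# All processes (server and trainers) must agree on the rendezvous endpoint;
# torch.distributed reads it from these environment variables.
os.environ["MASTER_ADDR"] = "localhost"  # value of --master_addr
os.environ["MASTER_PORT"] = "29500"      # value of --master_port

# Rank 0 hosts the parameter server; the remaining ranks join as trainers.
rpc.init_rpc(name="parameter_server", rank=0, world_size=2)

# ... training happens over RPC ...

rpc.shutdown()  # blocks until all workers have finished
```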
Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
import argparse
import os
import time
from threading import Lock

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets, transforms

# --------- MNIST Network to train, from pytorch/examples -----

class Net(nn.Module):
    def __init__(self, num_gpus=0):
        super(Net, self).__init__()
        print(f"Using {num_gpus} GPUs to train")
        self.num_gpus = num_gpus
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
        print(f"Putting first 2 convs on {str(device)}")
        # Put conv layers on the first cuda device
        self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
        self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
        # Put rest of the network on the 2nd cuda device, if there is one
        if "cuda" in str(device) and num_gpus > 1:
            device = torch.device("cuda:1")

        print(f"Putting rest of layers on {str(device)}")
        self.dropout1 = nn.Dropout2d(0.25).to(device)
        self.dropout2 = nn.Dropout2d(0.5).to(device)
        self.fc1 = nn.Linear(9216, 128).to(device)
        self.fc2 = nn.Linear(128, 10).to(device)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)

        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        # Move tensor to next device if necessary
        next_device = next(self.fc1.parameters()).device
        x = x.to(next_device)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# --------- Helper Methods --------------------

# On the local node, call a method with first arg as the value held by the
# RRef. Other args are passed in as arguments to the function called.
# Useful for calling instance methods.
def call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)

# Given an RRef, return the result of calling the passed in method on the value
# held by the RRef. This call is done on the remote node that owns
# the RRef. args and kwargs are passed into the method.
# Example: If the value held by the RRef is of type Foo, then
# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling
# <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
# back.


def remote_method(method, rref, *args, **kwargs):
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)

# --------- Parameter Server --------------------
class ParameterServer(nn.Module):
    def __init__(self, num_gpus=0):
        super().__init__()
        model = Net(num_gpus=num_gpus)
        self.model = model
        self.input_device = torch.device(
            "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")

    def forward(self, inp):
        inp = inp.to(self.input_device)
        out = self.model(inp)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        out = out.to("cpu")
        return out

    # Use dist autograd to retrieve gradients accumulated for this model.
    # Primarily used for verification.
    def get_dist_gradients(self, cid):
        grads = dist_autograd.get_gradients(cid)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        cpu_grads = {}
        for k, v in grads.items():
            k_cpu, v_cpu = k.to("cpu"), v.to("cpu")
            cpu_grads[k_cpu] = v_cpu
        return cpu_grads

    # Wrap local parameters in a RRef. Needed for building the
    # DistributedOptimizer which optimizes parameters remotely.
    def get_param_rrefs(self):
        param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
        return param_rrefs

param_server = None
global_lock = Lock()


def get_parameter_server(num_gpus=0):
    global param_server
    # Ensure that we get only one handle to the ParameterServer.
    with global_lock:
        if not param_server:
            # construct it once
            param_server = ParameterServer(num_gpus=num_gpus)
        return param_server


def run_parameter_server(rank, world_size):
    # The parameter server just acts as a host for the model and responds to
    # requests from trainers, hence it does not need to run a loop.
    # rpc.shutdown() will wait for all workers to complete by default, which
    # in this case means that the parameter server will wait for all trainers
    # to complete, and then exit.
    print("PS master initializing RPC")
    rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size)
    print("RPC initialized! Running parameter server...")
    rpc.shutdown()
    print("RPC shutdown on parameter server.")

# --------- Trainers --------------------

# nn.Module corresponding to the network trained by this trainer. The
# forward() method simply invokes the network on the given parameter
# server.
class TrainerNet(nn.Module):
    def __init__(self, num_gpus=0):
        super().__init__()
        self.num_gpus = num_gpus
        self.param_server_rref = rpc.remote(
            "parameter_server", get_parameter_server, args=(num_gpus,))

    def get_global_param_rrefs(self):
        remote_params = remote_method(
            ParameterServer.get_param_rrefs,
            self.param_server_rref)
        return remote_params

    def forward(self, x, cid):
        model_output = remote_method(
            ParameterServer.forward, self.param_server_rref, x)
        return model_output

def run_training_loop(rank, num_gpus, train_loader, test_loader):
    # Runs the typical neural network forward + backward + optimizer step, but
    # in a distributed fashion.
    net = TrainerNet(num_gpus=num_gpus)
    # Build DistributedOptimizer.
    param_rrefs = net.get_global_param_rrefs()
    opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
    for i, (data, target) in enumerate(train_loader):
        with dist_autograd.context() as cid:
            model_output = net(data, cid)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            if i % 5 == 0:
                print(f"Rank {rank} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            assert remote_method(
                ParameterServer.get_dist_gradients,
                net.param_server_rref,
                cid) != {}
            opt.step(cid)

    print("Training complete!")
    print("Getting accuracy....")
    get_accuracy(test_loader, net)

def get_accuracy(test_loader, model):
    model.eval()
    correct_sum = 0
    # Use GPU to evaluate if possible
    device = torch.device("cuda:0" if model.num_gpus > 0
                          and torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            out = model(data, -1)
            pred = out.argmax(dim=1, keepdim=True)
            pred, target = pred.to(device), target.to(device)
            correct = pred.eq(target.view_as(pred)).sum().item()
            correct_sum += correct

    print(f"Accuracy {correct_sum / len(test_loader.dataset)}")

# Main loop for trainers.
def run_worker(rank, world_size, num_gpus, train_loader, test_loader):
    print(f"Worker rank {rank} initializing RPC")
    rpc.init_rpc(
        name=f"trainer_{rank}",
        rank=rank,
        world_size=world_size)

    print(f"Worker {rank} done initializing RPC")

    run_training_loop(rank, num_gpus, train_loader, test_loader)
    rpc.shutdown()

# --------- Launcher --------------------

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Parameter-Server RPC based training")
    parser.add_argument(
        "world_size",
        type=int,
        default=4,
        help="""Total number of participating processes. Should be the sum of
        the master node and all training nodes.""")
    parser.add_argument(
        "rank",
        type=int,
        default=None,
        help="Global rank of this process. Pass in 0 for master.")
    parser.add_argument(
        "num_gpus",
        type=int,
        default=0,
        help="""Number of GPUs to use for training. Currently supports between 0
        and 2 GPUs. Note that this argument will be passed to the parameter server.""")
    parser.add_argument(
        "--master_addr",
        type=str,
        default="localhost",
        help="""Address of master, will default to localhost if not provided.
        Master must be able to accept network traffic on the address + port.""")
    parser.add_argument(
        "--master_port",
        type=str,
        default="29500",
        help="""Port that master is listening on, will default to 29500 if not
        provided. Master must be able to accept network traffic on the host and port.""")

    args = parser.parse_args()
    assert args.rank is not None, "must provide rank argument."
    assert args.num_gpus <= 2, f"Only 0-2 GPUs currently supported (got {args.num_gpus})."
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port
    processes = []
    world_size = args.world_size
    if args.rank == 0:
        p = mp.Process(target=run_parameter_server, args=(0, world_size))
        p.start()
        processes.append(p)
    else:
        # Get data to train on
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=32, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                '../data',
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])),
            batch_size=32,
            shuffle=True)
        # start training worker on this node
        p = mp.Process(
            target=run_worker,
            args=(
                args.rank,
                world_size, args.num_gpus,
                train_loader,
                test_loader))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
