|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import time |
| 4 | +from threading import Lock |
| 5 | + |
| 6 | +import torch |
| 7 | +import torch.distributed.autograd as dist_autograd |
| 8 | +import torch.distributed.rpc as rpc |
| 9 | +import torch.multiprocessing as mp |
| 10 | +import torch.nn as nn |
| 11 | +import torch.nn.functional as F |
| 12 | +from torch import optim |
| 13 | +from torch.distributed.optim import DistributedOptimizer |
| 14 | +from torchvision import datasets, transforms |
| 15 | + |
| 16 | +# --------- MNIST Network to train, from pytorch/examples ----- |
| 17 | + |
| 18 | + |
| 19 | +class Net(nn.Module): |
| 20 | + def __init__(self, num_gpus=0): |
| 21 | + super(Net, self).__init__() |
| 22 | + print(f"Using {num_gpus} GPUs to train") |
| 23 | + self.num_gpus = num_gpus |
| 24 | + device = torch.device( |
| 25 | + "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu") |
| 26 | + print(f"Putting first 2 convs on {str(device)}") |
| 27 | + # Put conv layers on the first cuda device |
| 28 | + self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device) |
| 29 | + self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device) |
| 30 | + # Put rest of the network on the 2nd cuda device, if there is one |
| 31 | + if "cuda" in str(device) and num_gpus > 1: |
| 32 | + device = torch.device("cuda:1") |
| 33 | + |
| 34 | + print(f"Putting rest of layers on {str(device)}") |
| 35 | + self.dropout1 = nn.Dropout2d(0.25).to(device) |
| 36 | + self.dropout2 = nn.Dropout2d(0.5).to(device) |
| 37 | + self.fc1 = nn.Linear(9216, 128).to(device) |
| 38 | + self.fc2 = nn.Linear(128, 10).to(device) |
| 39 | + |
| 40 | + def forward(self, x): |
| 41 | + x = self.conv1(x) |
| 42 | + x = F.relu(x) |
| 43 | + x = self.conv2(x) |
| 44 | + x = F.max_pool2d(x, 2) |
| 45 | + |
| 46 | + x = self.dropout1(x) |
| 47 | + x = torch.flatten(x, 1) |
| 48 | + # Move tensor to next device if necessary |
| 49 | + next_device = next(self.fc1.parameters()).device |
| 50 | + x = x.to(next_device) |
| 51 | + |
| 52 | + x = self.fc1(x) |
| 53 | + x = F.relu(x) |
| 54 | + x = self.dropout2(x) |
| 55 | + x = self.fc2(x) |
| 56 | + output = F.log_softmax(x, dim=1) |
| 57 | + return output |
| 58 | + |
| 59 | + |
| 60 | +# --------- Helper Methods -------------------- |
| 61 | + |
| 62 | +# On the local node, call a method with first arg as the value held by the |
| 63 | +# RRef. Other args are passed in as arguments to the function called. |
| 64 | +# Useful for calling instance methods. |
| 65 | +def call_method(method, rref, *args, **kwargs): |
| 66 | + return method(rref.local_value(), *args, **kwargs) |
| 67 | + |
| 68 | +# Given an RRef, return the result of calling the passed in method on the value |
| 69 | +# held by the RRef. This call is done on the remote node that owns |
| 70 | +# the RRef. args and kwargs are passed into the method. |
| 71 | +# Example: If the value held by the RRef is of type Foo, then |
| 72 | +# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling |
| 73 | +# <foo_instance>.bar(arg1, arg2) on the remote node and getting the result |
| 74 | +# back. |
| 75 | + |
| 76 | + |
| 77 | +def remote_method(method, rref, *args, **kwargs): |
| 78 | + args = [method, rref] + list(args) |
| 79 | + return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs) |
| 80 | + |
| 81 | + |
| 82 | +# --------- Parameter Server -------------------- |
| 83 | +class ParameterServer(nn.Module): |
| 84 | + def __init__(self, num_gpus=0): |
| 85 | + super().__init__() |
| 86 | + model = Net(num_gpus=num_gpus) |
| 87 | + self.model = model |
| 88 | + self.input_device = torch.device( |
| 89 | + "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu") |
| 90 | + |
| 91 | + def forward(self, inp): |
| 92 | + inp = inp.to(self.input_device) |
| 93 | + out = self.model(inp) |
| 94 | + # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors. |
| 95 | + # Tensors must be moved in and out of GPU memory due to this. |
| 96 | + out = out.to("cpu") |
| 97 | + return out |
| 98 | + |
| 99 | + # Use dist autograd to retrieve gradients accumulated for this model. |
| 100 | + # Primarily used for verification. |
| 101 | + def get_dist_gradients(self, cid): |
| 102 | + grads = dist_autograd.get_gradients(cid) |
| 103 | + # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors. |
| 104 | + # Tensors must be moved in and out of GPU memory due to this. |
| 105 | + cpu_grads = {} |
| 106 | + for k, v in grads.items(): |
| 107 | + k_cpu, v_cpu = k.to("cpu"), v.to("cpu") |
| 108 | + cpu_grads[k_cpu] = v_cpu |
| 109 | + return cpu_grads |
| 110 | + |
| 111 | + # Wrap local parameters in a RRef. Needed for building the |
| 112 | + # DistributedOptimizer which optimizes paramters remotely. |
| 113 | + def get_param_rrefs(self): |
| 114 | + param_rrefs = [rpc.RRef(param) for param in self.model.parameters()] |
| 115 | + return param_rrefs |
| 116 | + |
| 117 | + |
| 118 | +param_server = None |
| 119 | +global_lock = Lock() |
| 120 | + |
| 121 | + |
| 122 | +def get_parameter_server(num_gpus=0): |
| 123 | + global param_server |
| 124 | + # Ensure that we get only one handle to the ParameterServer. |
| 125 | + with global_lock: |
| 126 | + if not param_server: |
| 127 | + # construct it once |
| 128 | + param_server = ParameterServer(num_gpus=num_gpus) |
| 129 | + return param_server |
| 130 | + |
| 131 | + |
| 132 | +def run_parameter_server(rank, world_size): |
| 133 | + # The parameter server just acts as a host for the model and responds to |
| 134 | + # requests from trainers, hence it does not need to run a loop. |
| 135 | + # rpc.shutdown() will wait for all workers to complete by default, which |
| 136 | + # in this case means that the parameter server will wait for all trainers |
| 137 | + # to complete, and then exit. |
| 138 | + print("PS master initializing RPC") |
| 139 | + rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size) |
| 140 | + print("RPC initialized! Running parameter server...") |
| 141 | + rpc.shutdown() |
| 142 | + print("RPC shutdown on parameter server.") |
| 143 | + |
| 144 | + |
| 145 | +# --------- Trainers -------------------- |
| 146 | + |
| 147 | +# nn.Module corresponding to the network trained by this trainer. The |
| 148 | +# forward() method simply invokes the network on the given parameter |
| 149 | +# server. |
| 150 | +class TrainerNet(nn.Module): |
| 151 | + def __init__(self, num_gpus=0): |
| 152 | + super().__init__() |
| 153 | + self.num_gpus = num_gpus |
| 154 | + self.param_server_rref = rpc.remote( |
| 155 | + "parameter_server", get_parameter_server, args=(num_gpus,)) |
| 156 | + |
| 157 | + def get_global_param_rrefs(self): |
| 158 | + remote_params = remote_method( |
| 159 | + ParameterServer.get_param_rrefs, |
| 160 | + self.param_server_rref) |
| 161 | + return remote_params |
| 162 | + |
| 163 | + def forward(self, x, cid): |
| 164 | + model_output = remote_method( |
| 165 | + ParameterServer.forward, self.param_server_rref, x) |
| 166 | + return model_output |
| 167 | + |
| 168 | + |
| 169 | +def run_training_loop(rank, num_gpus, train_loader, test_loader): |
| 170 | + # Runs the typical nueral network forward + backward + optimizer step, but |
| 171 | + # in a distributed fashion. |
| 172 | + net = TrainerNet(num_gpus=num_gpus) |
| 173 | + # Build DistributedOptmizer. |
| 174 | + param_rrefs = net.get_global_param_rrefs() |
| 175 | + opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03) |
| 176 | + for i, (data, target) in enumerate(train_loader): |
| 177 | + with dist_autograd.context() as cid: |
| 178 | + model_output = net(data, cid) |
| 179 | + target = target.to(model_output.device) |
| 180 | + loss = F.nll_loss(model_output, target) |
| 181 | + if i % 5 == 0: |
| 182 | + print(f"Rank {rank} training batch {i} loss {loss.item()}") |
| 183 | + dist_autograd.backward(cid, [loss]) |
| 184 | + # Ensure that dist autograd ran successfully and gradients were |
| 185 | + # returned. |
| 186 | + assert remote_method( |
| 187 | + ParameterServer.get_dist_gradients, |
| 188 | + net.param_server_rref, |
| 189 | + cid) != {} |
| 190 | + opt.step(cid) |
| 191 | + |
| 192 | + print("Training complete!") |
| 193 | + print("Getting accuracy....") |
| 194 | + get_accuracy(test_loader, net) |
| 195 | + |
| 196 | + |
| 197 | +def get_accuracy(test_loader, model): |
| 198 | + model.eval() |
| 199 | + correct_sum = 0 |
| 200 | + # Use GPU to evaluate if possible |
| 201 | + device = torch.device("cuda:0" if model.num_gpus > 0 |
| 202 | + and torch.cuda.is_available() else "cpu") |
| 203 | + with torch.no_grad(): |
| 204 | + for i, (data, target) in enumerate(test_loader): |
| 205 | + out = model(data, -1) |
| 206 | + pred = out.argmax(dim=1, keepdim=True) |
| 207 | + pred, target = pred.to(device), target.to(device) |
| 208 | + correct = pred.eq(target.view_as(pred)).sum().item() |
| 209 | + correct_sum += correct |
| 210 | + |
| 211 | + print(f"Accuracy {correct_sum / len(test_loader.dataset)}") |
| 212 | + |
| 213 | + |
| 214 | +# Main loop for trainers. |
| 215 | +def run_worker(rank, world_size, num_gpus, train_loader, test_loader): |
| 216 | + print(f"Worker rank {rank} initializing RPC") |
| 217 | + rpc.init_rpc( |
| 218 | + name=f"trainer_{rank}", |
| 219 | + rank=rank, |
| 220 | + world_size=world_size) |
| 221 | + |
| 222 | + print(f"Worker {rank} done initializing RPC") |
| 223 | + |
| 224 | + run_training_loop(rank, num_gpus, train_loader, test_loader) |
| 225 | + rpc.shutdown() |
| 226 | + |
| 227 | +# --------- Launcher -------------------- |
| 228 | + |
| 229 | + |
| 230 | +if __name__ == '__main__': |
| 231 | + parser = argparse.ArgumentParser( |
| 232 | + description="Parameter-Server RPC based training") |
| 233 | + parser.add_argument( |
| 234 | + "world_size", |
| 235 | + type=int, |
| 236 | + default=4, |
| 237 | + help="""Total number of participating processes. Should be the sum of |
| 238 | + master node and all training nodes.""") |
| 239 | + parser.add_argument( |
| 240 | + "rank", |
| 241 | + type=int, |
| 242 | + default=None, |
| 243 | + help="Global rank of this process. Pass in 0 for master.") |
| 244 | + parser.add_argument( |
| 245 | + "num_gpus", |
| 246 | + type=int, |
| 247 | + default=0, |
| 248 | + help="""Number of GPUs to use for training, Currently supports between 0 |
| 249 | + and 2 GPUs. Note that this argument will be passed to the parameter servers.""") |
| 250 | + parser.add_argument( |
| 251 | + "--master_addr", |
| 252 | + type=str, |
| 253 | + default="localhost", |
| 254 | + help="""Address of master, will default to localhost if not provided. |
| 255 | + Master must be able to accept network traffic on the address + port.""") |
| 256 | + parser.add_argument( |
| 257 | + "--master_port", |
| 258 | + type=str, |
| 259 | + default="29500", |
| 260 | + help="""Port that master is listening on, will default to 29500 if not |
| 261 | + provided. Master must be able to accept network traffic on the host and port.""") |
| 262 | + |
| 263 | + args = parser.parse_args() |
| 264 | + assert args.rank is not None, "must provide rank argument." |
| 265 | + os.environ['MASTER_ADDR'] = args.master_addr |
| 266 | + os.environ["MASTER_PORT"] = args.master_port |
| 267 | + processes = [] |
| 268 | + world_size = args.world_size |
| 269 | + if args.rank == 0: |
| 270 | + p = mp.Process(target=run_parameter_server, args=(0, world_size)) |
| 271 | + p.start() |
| 272 | + processes.append(p) |
| 273 | + else: |
| 274 | + # Get data to train on |
| 275 | + train_loader = torch.utils.data.DataLoader( |
| 276 | + datasets.MNIST('../data', train=True, download=True, |
| 277 | + transform=transforms.Compose([ |
| 278 | + transforms.ToTensor(), |
| 279 | + transforms.Normalize((0.1307,), (0.3081,)) |
| 280 | + ])), |
| 281 | + batch_size=32, shuffle=True,) |
| 282 | + test_loader = torch.utils.data.DataLoader( |
| 283 | + datasets.MNIST( |
| 284 | + '../data', |
| 285 | + train=False, |
| 286 | + transform=transforms.Compose( |
| 287 | + [ |
| 288 | + transforms.ToTensor(), |
| 289 | + transforms.Normalize( |
| 290 | + (0.1307, |
| 291 | + ), |
| 292 | + (0.3081, |
| 293 | + ))])), |
| 294 | + batch_size=32, |
| 295 | + shuffle=True, |
| 296 | + ) |
| 297 | + # start training worker on this node |
| 298 | + p = mp.Process( |
| 299 | + target=run_worker, |
| 300 | + args=( |
| 301 | + args.rank, |
| 302 | + world_size, args.num_gpus, |
| 303 | + train_loader, |
| 304 | + test_loader)) |
| 305 | + p.start() |
| 306 | + processes.append(p) |
| 307 | + |
| 308 | + for p in processes: |
| 309 | + p.join() |
0 commit comments