From f3fd2296d67b0e08abd042e8b32ea470a608b276 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Fri, 10 Apr 2020 14:36:35 -0700
Subject: [PATCH] Adding a distributed pipeline parallelism example for RPC

---
 distributed/rpc/pipeline/README.md        |   9 +
 distributed/rpc/pipeline/main.py          | 288 ++++++++++++++++++++++
 distributed/rpc/pipeline/requirements.txt |   2 +
 3 files changed, 299 insertions(+)
 create mode 100644 distributed/rpc/pipeline/README.md
 create mode 100644 distributed/rpc/pipeline/main.py
 create mode 100644 distributed/rpc/pipeline/requirements.txt

diff --git a/distributed/rpc/pipeline/README.md b/distributed/rpc/pipeline/README.md
new file mode 100644
index 0000000000..f495917a85
--- /dev/null
+++ b/distributed/rpc/pipeline/README.md
@@ -0,0 +1,9 @@
+Distributed Pipeline Parallel Example
+
+This example shows how to distribute a ResNet50 model on two RPC workers and
+then implement distributed pipeline parallelism using RPC.
+
+```
+pip install -r requirements.txt
+python main.py
+```
diff --git a/distributed/rpc/pipeline/main.py b/distributed/rpc/pipeline/main.py
new file mode 100644
index 0000000000..2fcd8220a0
--- /dev/null
+++ b/distributed/rpc/pipeline/main.py
@@ -0,0 +1,288 @@
+import os
+import threading
+import time
+from functools import wraps
+
+import torch
+import torch.nn as nn
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
+import torch.optim as optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.distributed.rpc import RRef
+
+from torchvision.models.resnet import Bottleneck
+
+
+#########################################################
+#                   helper functions                    #
+#########################################################
+
+
+def _call_method(method, rref, *args, **kwargs):
+    r"""
+    a helper function to call a method on the given RRef
+    """
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_on_rref(method, rref, *args, **kwargs):
+    r"""
+    a helper function to run method on the owner of rref and return an RRef
+    of the result.
+    """
+    return rpc.remote(
+        rref.owner(),
+        _call_method,
+        args=[method, rref] + list(args),
+        kwargs=kwargs
+    )
+
+
+def _async_on_rref(method, rref, *args, **kwargs):
+    r"""
+    a helper function to run method on the owner of rref and fetch back the
+    result using RPC
+    """
+    return rpc.rpc_async(
+        rref.owner(),
+        _call_method,
+        args=[method, rref] + list(args),
+        kwargs=kwargs
+    )
+
+
+def _parameter_rrefs(module):
+    r"""
+    Create one RRef for each parameter in the given local module, and return a
+    list of RRefs.
+    """
+    param_rrefs = []
+    for param in module.parameters():
+        param_rrefs.append(RRef(param))
+    return param_rrefs
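+
+
+# Illustrative usage of the helpers above (a sketch; `lin_rref` and `x` are
+# hypothetical: an RRef owning an nn.Linear on another worker, and an input
+# tensor):
+#
+#     out_rref = _remote_on_rref(nn.Linear.forward, lin_rref, x)  # RRef to result
+#     out_fut = _async_on_rref(nn.Linear.forward, lin_rref, x)    # Future
+#
+# _remote_on_rref keeps the result on the owning worker so a later pipeline
+# stage can fetch it directly; _async_on_rref ships the result back to the
+# caller, which is how the final outputs are collected in DistResNet50 below.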
+ """ + param_rrefs = [] + for param in module.parameters(): + param_rrefs.append(RRef(param)) + return param_rrefs + + +######################################################### +# Define Model Parallel ResNet50 # +######################################################### + + +num_classes = 1000 + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class ResNetBase(nn.Module): + def __init__(self, block, inplanes, num_classes=1000, + groups=1, width_per_group=64, norm_layer=None): + super(ResNetBase, self).__init__() + + self._lock = threading.Lock() + self._block = block + self._norm_layer = nn.BatchNorm2d + self.inplanes = inplanes + self.dilation = 1 + self.groups = groups + self.base_width = width_per_group + + def _make_layer(self, planes, blocks, stride=1): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if stride != 1 or self.inplanes != planes * self._block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * self._block.expansion, stride), + norm_layer(planes * self._block.expansion), + ) + + layers = [] + layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * self._block.expansion + for _ in range(1, blocks): + layers.append(self._block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + +class ResNetPart1(ResNetBase): + """ + The first part of ResNet. + """ + def __init__(self, device, *args, **kwargs): + super(ResNetPart1, self).__init__( + Bottleneck, 64, num_classes=num_classes, *args, **kwargs) + + self.device = device + self.seq = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + self._norm_layer(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + self._make_layer(64, 3), + self._make_layer(128, 4, stride=2) + ).to(self.device) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x_rref): + x = x_rref.to_here().to(self.device) + with self._lock: + out = self.seq(x) + return out.cpu() + + +class ResNetPart2(ResNetBase): + """ + The second part of ResNet. 
+ """ + def __init__(self, device, *args, **kwargs): + super(ResNetPart2, self).__init__( + Bottleneck, 512, num_classes=num_classes, *args, **kwargs) + + self.device = device + self.seq = nn.Sequential( + self._make_layer(256, 6, stride=2), + self._make_layer(512, 3, stride=2), + nn.AdaptiveAvgPool2d((1, 1)), + ).to(self.device) + + self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device) + + def forward(self, x_rref): + x = x_rref.to_here().to(self.device) + with self._lock: + out = self.fc(torch.flatten(self.seq(x), 1)) + return out.cpu() + + +class DistResNet50(nn.Module): + """ + Assemble two parts as an nn.Module and define pipelining logic + """ + def __init__(self, split_size, workers, *args, **kwargs): + super(DistResNet50, self).__init__() + + self.split_size = split_size + + # Put the first part of the ResNet50 on workers[0] + self.p1_rref = rpc.remote( + workers[0], + ResNetPart1, + args = ("cuda:0",) + args, + kwargs = kwargs + ) + + # Put the second part of the ResNet50 on workers[1] + self.p2_rref = rpc.remote( + workers[1], + ResNetPart2, + args = ("cuda:1",) + args, + kwargs = kwargs + ) + + def forward(self, xs): + # Split the input batch xs into micro-batches, and collect async RPC + # futures into a list + out_futures = [] + for x in iter(xs.split(self.split_size, dim=0)): + x_rref = RRef(x) + y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref) + z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref) + out_futures.append(z_fut) + + # wait for all RPC to finish + outs = [fut.wait() for fut in out_futures] + # cat all tensors into one tensor. + out = torch.cat(outs) + return out + + def parameter_rrefs(self): + remote_params = [] + remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here()) + remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here()) + return remote_params + + +######################################################### +# Run RPC Processes # +######################################################### + +num_batches = 3 +batch_size = 120 +image_w = 128 +image_h = 128 + + +def run_master(num_split): + # put the two model parts on worker1 and worker2 respectively + model = DistResNet50(split_size, ["worker1", "worker2"]) + loss_fn = nn.MSELoss() + opt = DistributedOptimizer( + optim.SGD, + model.parameter_rrefs(), + lr=0.05, + ) + + one_hot_indices = torch.LongTensor(batch_size) \ + .random_(0, num_classes) \ + .view(batch_size, 1) + + for i in range(num_batches): + print(f"Processing batch {i}") + # generate random inputs and labels + inputs = torch.randn(batch_size, 3, image_w, image_h) + labels = torch.zeros(batch_size, num_classes) \ + .scatter_(1, one_hot_indices, 1) + + with dist_autograd.context() as context_id: + outputs = model(inputs) + dist_autograd.backward(context_id, [loss_fn(outputs, labels)]) + opt.step(context_id) + + +def run_worker(rank, world_size, num_split): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29500' + options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256) + + if rank == 0: + rpc.init_rpc( + "master", + rank=rank, + world_size=world_size, + rpc_backend_options=options + ) + run_master(num_split) + else: + rpc.init_rpc( + f"worker{rank}", + rank=rank, + world_size=world_size, + rpc_backend_options=options + ) + pass + + # block until all rpcs finish + rpc.shutdown() + + +if __name__=="__main__": + world_size = 3 + for num_split in [1, 2, 4, 8]: + tik = time.time() + mp.spawn(run_worker, 
+
+
+#########################################################
+#                   Run RPC Processes                   #
+#########################################################
+
+num_batches = 3
+batch_size = 120
+image_w = 128
+image_h = 128
+
+
+def run_master(num_split):
+    # put the two model parts on worker1 and worker2 respectively
+    model = DistResNet50(num_split, ["worker1", "worker2"])
+    loss_fn = nn.MSELoss()
+    opt = DistributedOptimizer(
+        optim.SGD,
+        model.parameter_rrefs(),
+        lr=0.05,
+    )
+
+    one_hot_indices = torch.LongTensor(batch_size) \
+                           .random_(0, num_classes) \
+                           .view(batch_size, 1)
+
+    for i in range(num_batches):
+        print(f"Processing batch {i}")
+        # generate random inputs and labels
+        inputs = torch.randn(batch_size, 3, image_w, image_h)
+        labels = torch.zeros(batch_size, num_classes) \
+                      .scatter_(1, one_hot_indices, 1)
+
+        # run forward, distributed backward, and the optimizer step inside one
+        # distributed autograd context
+        with dist_autograd.context() as context_id:
+            outputs = model(inputs)
+            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
+            opt.step(context_id)
+
+
+def run_worker(rank, world_size, num_split):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '29500'
+    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)
+
+    if rank == 0:
+        rpc.init_rpc(
+            "master",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=options
+        )
+        run_master(num_split)
+    else:
+        rpc.init_rpc(
+            f"worker{rank}",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=options
+        )
+        # workers passively wait for commands from the master
+        pass
+
+    # block until all rpcs finish
+    rpc.shutdown()
+
+
+if __name__ == "__main__":
+    world_size = 3
+    for num_split in [1, 2, 4, 8]:
+        tik = time.time()
+        mp.spawn(run_worker,
+                 args=(world_size, num_split), nprocs=world_size, join=True)
+        tok = time.time()
+        print(f"number of splits = {num_split}, execution time = {tok - tik}")
\ No newline at end of file
diff --git a/distributed/rpc/pipeline/requirements.txt b/distributed/rpc/pipeline/requirements.txt
new file mode 100644
index 0000000000..6aea1bd158
--- /dev/null
+++ b/distributed/rpc/pipeline/requirements.txt
@@ -0,0 +1,2 @@
+torch==1.5.0
+torchvision==0.6.0
\ No newline at end of file