# Adding distributed pipeline parallelism example #749
**README** (new file)

Distributed Pipeline Parallel Example

This example shows how to distribute a ResNet50 model on two RPC workers and then implement distributed pipeline parallelism using RPC.

```
pip install -r requirements.txt
python main.py
```
**main.py** (new file)

```python
import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
#                   helper functions                    #
#########################################################


def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and return an RRef
    of the result.
    """
    return rpc.remote(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _async_on_rref(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    return rpc.rpc_async(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs


#########################################################
#           Define Model Parallel ResNet50              #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        self._lock = threading.Lock()
        self._block = block
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if stride != 1 or self.inplanes != planes * self._block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * self._block.expansion, stride),
                norm_layer(planes * self._block.expansion),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * self._block.expansion
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)
```
**Review comment:** Maybe we should include some comments about what we're doing here at a high level (defining ResNet with 2 partitions so we can place them on separate machines). Also, should we call these Partitions or Shards instead of parts?
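For readers skimming the diff, the "place partitions on separate machines" idea the comment refers to boils down to constructing a module remotely with `rpc.remote` and holding an `RRef` to it. Below is a minimal sketch of that pattern, not part of this PR: it assumes a single-process RPC world (`world_size=1`), an arbitrary port, and a plain `nn.Linear` stand-in so it can run on one CPU machine.

```python
# Minimal sketch (not part of this PR): place a module on an RPC worker via
# rpc.remote and invoke its forward through an RRef. A single-process RPC
# world is used here purely so the snippet is runnable as-is.
import os

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc


def _call_method(method, rref, *args, **kwargs):
    # Run `method` on the object the RRef points to, on the RRef's owner.
    return method(rref.local_value(), *args, **kwargs)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port (assumption)
    rpc.init_rpc("worker0", rank=0, world_size=1)

    # Construct nn.Linear(4, 2) on the (here: local) owner; get back an RRef.
    shard_rref = rpc.remote("worker0", nn.Linear, args=(4, 2))

    # Run the shard's forward on its owner and fetch the result back.
    out = rpc.rpc_sync(
        shard_rref.owner(),
        _call_method,
        args=(nn.Linear.forward, shard_rref, torch.randn(3, 4)),
    )
    print(out.shape)  # torch.Size([3, 2])

    rpc.shutdown()
```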
```python
class ResNetPart1(ResNetBase):
    """
    The first part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart1, self).__init__(
            Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2)
        ).to(self.device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.seq(x)
        return out.cpu()


class ResNetPart2(ResNetBase):
    """
    The second part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart2, self).__init__(
            Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ).to(self.device)

        self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.fc(torch.flatten(self.seq(x), 1))
        return out.cpu()


class DistResNet50(nn.Module):
    """
    Assemble two parts as an nn.Module and define pipelining logic
    """
    def __init__(self, split_size, workers, *args, **kwargs):
        super(DistResNet50, self).__init__()

        self.split_size = split_size

        # Put the first part of the ResNet50 on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetPart1,
            args=("cuda:0",) + args,
            kwargs=kwargs
        )

        # Put the second part of the ResNet50 on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetPart2,
            args=("cuda:1",) + args,
            kwargs=kwargs
        )

    def forward(self, xs):
        # Split the input batch xs into micro-batches, and collect async RPC
        # futures into a list
        out_futures = []
        for x in iter(xs.split(self.split_size, dim=0)):
            x_rref = RRef(x)
            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
            out_futures.append(z_fut)

        # wait for all RPC to finish
        outs = [fut.wait() for fut in out_futures]
        # cat all tensors into one tensor.
        out = torch.cat(outs)
        return out

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
        return remote_params


#########################################################
#                   Run RPC Processes                   #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def run_master(num_split):
    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(num_split, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
```
**Review comment:** Should we add a short comment about what dist_autograd/dist_optimizer is doing here?
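As background for this comment: a distributed autograd context collects gradients computed across RPC boundaries, and `DistributedOptimizer` then applies them on whichever workers own the parameters. The snippet below is a minimal, self-contained sketch of that pattern, not part of this diff; it assumes a single-process RPC world, an arbitrary port, and a plain `nn.Linear` instead of the ResNet shards.

```python
# Sketch only (not part of this PR): how dist_autograd + DistributedOptimizer
# fit together. Gradients land in the distributed autograd context rather than
# in p.grad, and opt.step(context_id) applies them on each parameter's owner.
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"  # arbitrary free port (assumption)
    rpc.init_rpc("worker0", rank=0, world_size=1)

    model = nn.Linear(4, 2)
    # DistributedOptimizer takes RRefs to parameters rather than parameters.
    opt = DistributedOptimizer(
        optim.SGD, [RRef(p) for p in model.parameters()], lr=0.05)

    with dist_autograd.context() as context_id:
        loss = model(torch.randn(8, 4)).sum()
        # The backward pass records gradients in this context, not in p.grad.
        dist_autograd.backward(context_id, [loss])
        # The optimizer step runs on each parameter's owner worker.
        opt.step(context_id)

    rpc.shutdown()
```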
```python
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)


def run_worker(rank, world_size, num_split):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(num_split)
    else:
        rpc.init_rpc(
            f"worker{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 3
    for num_split in [1, 2, 4, 8]:
        tik = time.time()
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
        tok = time.time()
        print(f"number of splits = {num_split}, execution time = {tok - tik}")
```
**requirements.txt** (new file)

```
torch==1.5.0
torchvision==0.6.0
```
**Review comment:** Should we include a quick description of the pipelining strategy (pipelining micro-batches within a batch and then synchronously running the optimizer step)? Since this is like GPipe, should we also link the paper here?
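For context on the strategy this comment describes, the paper referred to is presumably "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism" (arXiv:1811.06965). Below is a minimal single-process sketch of the micro-batch pipelining idea; the two `nn.Linear` stages and the thread-pool-per-stage setup are illustrative stand-ins, not part of this PR, where the stages are instead ResNet shards owned by separate RPC workers.

```python
# Illustrative sketch (not part of this PR): GPipe-style micro-batch
# pipelining. Each stage gets its own single-worker executor, so micro-batch
# i+1 can enter stage 1 while micro-batch i is still in stage 2.
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn

stage1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
stage2 = nn.Linear(32, 10)

# One single-threaded executor per stage stands in for one worker per stage.
stage1_exec = ThreadPoolExecutor(max_workers=1)
stage2_exec = ThreadPoolExecutor(max_workers=1)


def pipelined_forward(xs, split_size):
    out_futures = []
    for x in xs.split(split_size, dim=0):
        # Kick off stage 1 for this micro-batch, then chain stage 2 behind it.
        fut1 = stage1_exec.submit(stage1, x)
        fut2 = stage2_exec.submit(lambda f=fut1: stage2(f.result()))
        out_futures.append(fut2)
    # Synchronize: wait for every micro-batch, then re-assemble the batch.
    return torch.cat([f.result() for f in out_futures])


if __name__ == "__main__":
    out = pipelined_forward(torch.randn(8, 16), split_size=2)
    print(out.shape)  # torch.Size([8, 10])
```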