9 changes: 9 additions & 0 deletions distributed/rpc/pipeline/README.md
@@ -0,0 +1,9 @@
Distributed Pipeline Parallel Example

This example shows how to split a ResNet50 model across two RPC workers and then
implement distributed pipeline parallelism using RPC.
Member:
Should we include a quick description of the pipelining strategy (pipelining micro-batches within a batch and then synchronously running the optimizer step)? Since this is like GPipe, should we also link the paper here?


```
pip install -r requirements.txt
python main.py
```
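
Each input batch is further divided into micro-batches. For every micro-batch, the
forward pass of the first part (`ResNetPart1`, on `worker1`) is launched with
`rpc.remote` and the forward pass of the second part (`ResNetPart2`, on `worker2`)
is chained with `rpc.rpc_async`, so the two parts can work on different
micro-batches at the same time; the optimizer step then runs synchronously once the
whole batch is finished (similar in spirit to GPipe). A simplified excerpt of the
pipelining loop in `main.py`:

```
# simplified excerpt of DistResNet50.forward in main.py
out_futures = []
for x in xs.split(self.split_size, dim=0):
    x_rref = RRef(x)
    # run part 1's forward on worker1; rpc.remote returns an RRef immediately
    y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
    # chain part 2's forward on worker2; rpc_async returns a Future
    out_futures.append(_async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref))
# wait for all micro-batches and reassemble the full output batch
out = torch.cat([fut.wait() for fut in out_futures])
```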
288 changes: 288 additions & 0 deletions distributed/rpc/pipeline/main.py
@@ -0,0 +1,288 @@
import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
# helper functions #
#########################################################


def _call_method(method, rref, *args, **kwargs):
r"""
a helper function to call a method on the given RRef
"""
return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
r"""
a helper function to run method on the owner of rref and return an RRef
of the result.
"""
return rpc.remote(
rref.owner(),
_call_method,
args=[method, rref] + list(args),
kwargs=kwargs
)


def _async_on_rref(method, rref, *args, **kwargs):
r"""
a helper function to run method on the owner of rref and fetch back the
result using RPC
"""
return rpc.rpc_async(
rref.owner(),
_call_method,
args=[method, rref] + list(args),
kwargs=kwargs
)


def _parameter_rrefs(module):
r"""
Create one RRef for each parameter in the given local module, and return a
list of RRefs.
"""
param_rrefs = []
for param in module.parameters():
param_rrefs.append(RRef(param))
return param_rrefs


#########################################################
# Define Model Parallel ResNet50 #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
def __init__(self, block, inplanes, num_classes=1000,
groups=1, width_per_group=64, norm_layer=None):
super(ResNetBase, self).__init__()

self._lock = threading.Lock()
self._block = block
self._norm_layer = nn.BatchNorm2d
self.inplanes = inplanes
self.dilation = 1
self.groups = groups
self.base_width = width_per_group

def _make_layer(self, planes, blocks, stride=1):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if stride != 1 or self.inplanes != planes * self._block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * self._block.expansion, stride),
norm_layer(planes * self._block.expansion),
)

layers = []
layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * self._block.expansion
for _ in range(1, blocks):
layers.append(self._block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))

return nn.Sequential(*layers)


Member:
Maybe we should include some comments about what we're doing here at a high level (defining resnet with 2 partitions so we can place them on separate machines.) Also, should we call these Partitions or Shards instead of parts?
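
# ResNet50 is split into two parts so that each part can be placed on a
# different worker (and a different GPU). ResNetPart1 holds the stem
# (conv/bn/relu/maxpool) plus the first two residual stages; ResNetPart2 holds
# the last two residual stages, the average pool, and the fully connected head.
# Each part serializes its forward pass with a lock, since the owning worker
# may receive RPCs for several micro-batches concurrently.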

class ResNetPart1(ResNetBase):
"""
The first part of ResNet.
"""
def __init__(self, device, *args, **kwargs):
super(ResNetPart1, self).__init__(
Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

self.device = device
self.seq = nn.Sequential(
nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
self._norm_layer(self.inplanes),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
self._make_layer(64, 3),
self._make_layer(128, 4, stride=2)
).to(self.device)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def forward(self, x_rref):
x = x_rref.to_here().to(self.device)
with self._lock:
out = self.seq(x)
return out.cpu()


class ResNetPart2(ResNetBase):
"""
The second part of ResNet.
"""
def __init__(self, device, *args, **kwargs):
super(ResNetPart2, self).__init__(
Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

self.device = device
self.seq = nn.Sequential(
self._make_layer(256, 6, stride=2),
self._make_layer(512, 3, stride=2),
nn.AdaptiveAvgPool2d((1, 1)),
).to(self.device)

self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

def forward(self, x_rref):
x = x_rref.to_here().to(self.device)
with self._lock:
out = self.fc(torch.flatten(self.seq(x), 1))
return out.cpu()


class DistResNet50(nn.Module):
"""
Assemble two parts as an nn.Module and define pipelining logic
"""
def __init__(self, split_size, workers, *args, **kwargs):
super(DistResNet50, self).__init__()

self.split_size = split_size

# Put the first part of the ResNet50 on workers[0]
self.p1_rref = rpc.remote(
workers[0],
ResNetPart1,
args = ("cuda:0",) + args,
kwargs = kwargs
)

# Put the second part of the ResNet50 on workers[1]
self.p2_rref = rpc.remote(
workers[1],
ResNetPart2,
args = ("cuda:1",) + args,
kwargs = kwargs
)

def forward(self, xs):
# Split the input batch xs into micro-batches, and collect async RPC
# futures into a list
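        # rpc.remote launches part 1's forward on worker1 and returns an RRef
        # right away, and rpc_async chains part 2's forward on worker2 and
        # returns a Future. Neither call blocks, so the next micro-batch can be
        # dispatched while earlier ones are still in flight, pipelining the two
        # parts.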
out_futures = []
for x in iter(xs.split(self.split_size, dim=0)):
x_rref = RRef(x)
y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
out_futures.append(z_fut)

# wait for all RPC to finish
outs = [fut.wait() for fut in out_futures]
# cat all tensors into one tensor.
out = torch.cat(outs)
return out

def parameter_rrefs(self):
remote_params = []
remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
return remote_params


#########################################################
# Run RPC Processes #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def run_master(num_split):
# put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(num_split, ["worker1", "worker2"])
loss_fn = nn.MSELoss()
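    # DistributedOptimizer takes RRefs of parameters that live on remote
    # workers and creates one local optimizer on every worker that owns some
    # of those parameters; opt.step(context_id) later runs all of them.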
opt = DistributedOptimizer(
optim.SGD,
model.parameter_rrefs(),
lr=0.05,
)

one_hot_indices = torch.LongTensor(batch_size) \
.random_(0, num_classes) \
.view(batch_size, 1)

for i in range(num_batches):
print(f"Processing batch {i}")
# generate random inputs and labels
inputs = torch.randn(batch_size, 3, image_w, image_h)
labels = torch.zeros(batch_size, num_classes) \
.scatter_(1, one_hot_indices, 1)

with dist_autograd.context() as context_id:
Member:
Should we add a short comment about what dist_autograd/dist_optimizer is doing here?
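
            # The distributed autograd context records the RPCs made during the
            # forward pass; dist_autograd.backward() then runs the backward pass
            # across workers and accumulates gradients in this context (rather
            # than in param.grad), and opt.step(context_id) lets the
            # DistributedOptimizer apply those gradients on the workers that own
            # the parameters.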

outputs = model(inputs)
dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
opt.step(context_id)


def run_worker(rank, world_size, num_split):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
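    # MASTER_ADDR / MASTER_PORT are read by init_rpc's ProcessGroup backend to
    # rendezvous the three processes; the large number of send/recv threads
    # lets many micro-batch RPCs be processed concurrently.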
options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

if rank == 0:
rpc.init_rpc(
"master",
rank=rank,
world_size=world_size,
rpc_backend_options=options
)
run_master(num_split)
else:
rpc.init_rpc(
f"worker{rank}",
rank=rank,
world_size=world_size,
rpc_backend_options=options
)
        # worker processes passively wait for RPCs from the master

# block until all rpcs finish
rpc.shutdown()


if __name__ == "__main__":
world_size = 3
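    # 1 master + 2 workers; measure end-to-end time for different numbers of splits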
for num_split in [1, 2, 4, 8]:
tik = time.time()
mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
tok = time.time()
print(f"number of splits = {num_split}, execution time = {tok - tik}")
2 changes: 2 additions & 0 deletions distributed/rpc/pipeline/requirements.txt
@@ -0,0 +1,2 @@
torch==1.5.0
torchvision==0.6.0