Distributed Pipeline Parallelism Using RPC
==========================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Prerequisites:

- `Single-Machine Model Parallel Best Practices <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
- `Getting started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
- RRef helper functions:
  `RRef.rpc_sync() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__,
  `RRef.rpc_async() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_async>`__, and
  `RRef.remote() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.remote>`__

This tutorial uses a ResNet50 model to demonstrate how to implement distributed
pipeline parallelism with the `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__
APIs. It can be viewed as the distributed counterpart of the multi-GPU
pipeline parallelism discussed in
`Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_.


Basics
------

The previous tutorial, `Getting Started with Distributed RPC Framework <rpc_tutorial.html>`_,
shows how to use `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`_
to implement distributed model parallelism for an RNN model. That tutorial uses
one GPU to host the ``EmbeddingTable``, and the provided code works fine.
However, if a model lives on multiple GPUs, some extra steps are needed to
increase the amortized utilization of all GPUs. Pipeline parallelism is one
technique that can help in this case.

In this tutorial, we use ``ResNet50`` as the example model, which is also used
by the `Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_
tutorial. Similarly, the ``ResNet50`` model is divided into two shards, and the
input batch is partitioned into multiple splits that are fed into the two model
shards in a pipelined fashion. The difference is that, instead of parallelizing
the execution using CUDA streams, this tutorial invokes asynchronous RPCs. So
the solution presented in this tutorial also works across machine boundaries.
The remainder of this tutorial presents the implementation in four steps.

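
Before diving in, the minimal single-process sketch below (not part of the
original tutorial) recaps the three ``RRef`` helper functions linked in the
prerequisites, since the pipeline relies on them. With ``world_size=1`` the RPC
framework runs entirely inside one process, so the snippet can be executed
standalone.

.. code:: python

    import os

    import torch
    import torch.nn as nn
    import torch.distributed.rpc as rpc

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    rpc.init_rpc("worker0", rank=0, world_size=1)

    # create a module owned by "worker0" and keep an RRef to it; with more
    # workers, the same call could place it on a different process
    layer_rref = rpc.remote("worker0", nn.Linear, args=(4, 2))

    x = torch.randn(3, 4)
    out1 = layer_rref.rpc_sync().forward(x)      # blocking: returns the result
    fut = layer_rref.rpc_async().forward(x)      # non-blocking: returns a Future
    out2 = fut.wait()
    out_rref = layer_rref.remote().forward(x)    # non-blocking: returns another RRef
    out3 = out_rref.to_here()                    # fetch the remote value when needed

    rpc.shutdown()
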
Step 1: Partition ResNet50 Model
--------------------------------

This is the preparation step which implements ``ResNet50`` in two model shards.
The code below is borrowed from the
`ResNet implementation in torchvision <https://github.com/pytorch/vision/blob/7c077f6a986f05383bcb86b535aedb5a63dd5c4b/torchvision/models/resnet.py#L124>`_.
The ``ResNetBase`` module contains the common building blocks and attributes for
the two ResNet shards.

.. code:: python

    import threading

    import torch
    import torch.nn as nn
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    from torchvision.models.resnet import Bottleneck

    num_classes = 1000


    def conv1x1(in_planes, out_planes, stride=1):
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    class ResNetBase(nn.Module):
        def __init__(self, block, inplanes, num_classes=1000,
                     groups=1, width_per_group=64, norm_layer=None):
            super(ResNetBase, self).__init__()

            # serialize forward passes, since multiple RPC threads may call
            # into the same shard concurrently
            self._lock = threading.Lock()
            self._block = block
            self._norm_layer = nn.BatchNorm2d
            self.inplanes = inplanes
            self.dilation = 1
            self.groups = groups
            self.base_width = width_per_group

        def _make_layer(self, planes, blocks, stride=1):
            norm_layer = self._norm_layer
            downsample = None
            previous_dilation = self.dilation
            if stride != 1 or self.inplanes != planes * self._block.expansion:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * self._block.expansion, stride),
                    norm_layer(planes * self._block.expansion),
                )

            layers = []
            layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                      self.base_width, previous_dilation, norm_layer))
            self.inplanes = planes * self._block.expansion
            for _ in range(1, blocks):
                layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                          base_width=self.base_width, dilation=self.dilation,
                                          norm_layer=norm_layer))

            return nn.Sequential(*layers)

        def parameter_rrefs(self):
            # wrap each local parameter in an RRef so that the caller can build
            # a DistributedOptimizer over parameters living on remote workers
            return [RRef(p) for p in self.parameters()]

Now, we are ready to define the two model shards. For the constructor, we
simply split all ResNet50 layers into two parts and move each part to the
provided device. The ``forward`` functions of both shards take an ``RRef`` of
the input data, fetch the data locally, and then move it to the expected
device. After applying all layers to the input, each shard moves the output to
CPU and returns it. This is because the RPC API requires tensors to reside on
CPU to avoid invalid device errors when the numbers of devices on the caller
and the callee do not match.

.. code:: python

    class ResNetShard1(ResNetBase):
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard1, self).__init__(
                Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                self._norm_layer(self.inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                self._make_layer(64, 3),
                self._make_layer(128, 4, stride=2)
            ).to(self.device)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x_rref):
            # fetch the input from its owner and move it to this shard's device
            x = x_rref.to_here().to(self.device)
            with self._lock:
                out = self.seq(x)
            # return on CPU, as required by the RPC layer
            return out.cpu()


    class ResNetShard2(ResNetBase):
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard2, self).__init__(
                Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                self._make_layer(256, 6, stride=2),
                self._make_layer(512, 3, stride=2),
                nn.AdaptiveAvgPool2d((1, 1)),
            ).to(self.device)

            self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

        def forward(self, x_rref):
            x = x_rref.to_here().to(self.device)
            with self._lock:
                out = self.fc(torch.flatten(self.seq(x), 1))
            return out.cpu()

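
As a quick single-process sanity check (not part of the original tutorial), the
sketch below instantiates ``ResNetShard1`` on CPU and pushes one random
micro-batch through it, assuming the definitions above are available in the
current module. With ``world_size=1`` the RPC framework is initialized inside
this process only, which is enough to construct the input ``RRef``.

.. code:: python

    import os

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    rpc.init_rpc("worker0", rank=0, world_size=1)   # no remote peers needed

    shard1 = ResNetShard1("cpu")                    # device string instead of "cuda:0"
    x_rref = RRef(torch.randn(4, 3, 128, 128))      # wrap the input like the pipeline does
    out = shard1.forward(x_rref)
    print(out.shape)                                # torch.Size([4, 512, 16, 16])

    rpc.shutdown()
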
Step 2: Stitch ResNet50 Model Shards Into One Module
----------------------------------------------------

Then, we create a ``DistResNet50`` module to assemble the two shards and
implement the pipeline parallel logic. In the constructor, we use two
``rpc.remote`` calls to put the two shards on two different RPC workers
respectively, and hold on to the ``RRef`` of each model part so that both can
be referenced in the forward pass. The ``forward`` function splits the input
batch into multiple micro-batches and feeds them to the two model parts in a
pipelined fashion. It first uses an ``rpc.remote`` call to apply the first
shard to a micro-batch, and then forwards the returned intermediate output
``RRef`` to the second model shard. After that, it collects the ``Future`` of
each micro-batch output and waits for all of them after the loop. Note that
both ``remote()`` and ``rpc_async()`` return immediately and run
asynchronously. Therefore, the entire loop is non-blocking and launches
multiple RPCs concurrently. The execution order of one micro-batch on the two
model parts is preserved by the intermediate output ``y_rref``. The execution
order across micro-batches does not matter. In the end, the ``forward``
function concatenates the outputs of all micro-batches into one single output
tensor and returns it. The ``parameter_rrefs`` function is a helper that
simplifies distributed optimizer construction, and it will be used later.

.. code:: python

    class DistResNet50(nn.Module):
        """
        Assemble two parts as an nn.Module and define pipelining logic
        """
        def __init__(self, num_split, workers, *args, **kwargs):
            super(DistResNet50, self).__init__()

            # number of micro-batches each input batch is split into
            self.num_split = num_split

            # Put the first part of the ResNet50 on workers[0]
            self.p1_rref = rpc.remote(
                workers[0],
                ResNetShard1,
                args = ("cuda:0",) + args,
                kwargs = kwargs
            )

            # Put the second part of the ResNet50 on workers[1]
            self.p2_rref = rpc.remote(
                workers[1],
                ResNetShard2,
                args = ("cuda:1",) + args,
                kwargs = kwargs
            )

        def forward(self, xs):
            out_futures = []
            # divide the input batch into num_split micro-batches
            split_size = xs.size(0) // self.num_split
            for x in iter(xs.split(split_size, dim=0)):
                x_rref = RRef(x)
                # stage 1 runs remotely; its output stays on workers[0] as y_rref
                y_rref = self.p1_rref.remote().forward(x_rref)
                # stage 2 consumes y_rref and returns a Future of the final output
                z_fut = self.p2_rref.rpc_async().forward(y_rref)
                out_futures.append(z_fut)

            # wait for all micro-batches and concatenate them into one tensor
            return torch.cat(torch.futures.wait_all(out_futures))

        def parameter_rrefs(self):
            remote_params = []
            remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
            remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
            return remote_params

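
To make the chaining pattern concrete, here is a toy single-process sketch
(not part of the tutorial) with two plain functions standing in for the two
shards: stage 1's output stays with its owner as an ``RRef`` and is fetched by
stage 2 via ``to_here()``, while only the final result comes back to the caller
through the ``Future``.

.. code:: python

    import os

    import torch
    import torch.distributed.rpc as rpc


    def stage1(x):
        return x * 2


    def stage2(h_rref):
        # to_here() pulls the intermediate result from its owner,
        # not from the caller
        return h_rref.to_here() + 1


    if __name__=="__main__":
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        # one worker is enough to show the chaining; in the tutorial the two
        # stages live on "worker1" and "worker2"
        rpc.init_rpc("worker0", rank=0, world_size=1)

        h_rref = rpc.remote("worker0", stage1, args=(torch.ones(2),))   # like p1_rref.remote().forward(...)
        fut = rpc.rpc_async("worker0", stage2, args=(h_rref,))          # like p2_rref.rpc_async().forward(...)
        print(fut.wait())                                               # tensor([3., 3.])

        rpc.shutdown()
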
Step 3: Define The Training Loop
--------------------------------

After defining the model, let us implement the training loop. We use a
dedicated "master" worker to prepare random inputs and labels and to control
the distributed backward pass and distributed optimizer step. It first creates
an instance of the ``DistResNet50`` module, specifying the number of
micro-batches for each batch and providing the names of the two RPC workers
(i.e., "worker1" and "worker2"). Then it defines the loss function and creates
a ``DistributedOptimizer`` using the ``parameter_rrefs()`` helper to acquire a
list of parameter ``RRefs``. The main training loop is then very similar to
regular local training, except that it uses ``dist_autograd`` to launch the
backward pass and provides the ``context_id`` for both the backward pass and
the optimizer step.

.. code:: python

    import torch.distributed.autograd as dist_autograd
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer

    num_batches = 3
    batch_size = 120
    image_w = 128
    image_h = 128


    def run_master(num_split):
        # put the two model parts on worker1 and worker2 respectively
        model = DistResNet50(num_split, ["worker1", "worker2"])
        loss_fn = nn.MSELoss()
        opt = DistributedOptimizer(
            optim.SGD,
            model.parameter_rrefs(),
            lr=0.05,
        )

        one_hot_indices = torch.LongTensor(batch_size) \
                               .random_(0, num_classes) \
                               .view(batch_size, 1)

        for i in range(num_batches):
            print(f"Processing batch {i}")
            # generate random inputs and labels
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                          .scatter_(1, one_hot_indices, 1)

            # the forward pass, distributed backward pass, and optimizer step
            # all share one distributed autograd context
            with dist_autograd.context() as context_id:
                outputs = model(inputs)
                dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
                opt.step(context_id)

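
For contrast, the snippet below (not part of the tutorial) shows the local,
single-process equivalent of the loop body above using a tiny stand-in model:
a plain ``backward()`` call and a local optimizer replace the distributed
autograd context and the ``context_id``.

.. code:: python

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # tiny local model standing in for DistResNet50
    model = nn.Linear(8, 2)
    opt = optim.SGD(model.parameters(), lr=0.05)
    loss_fn = nn.MSELoss()

    inputs, labels = torch.randn(4, 8), torch.randn(4, 2)
    outputs = model(inputs)
    loss_fn(outputs, labels).backward()   # local autograd; no context_id needed
    opt.step()                            # gradients live in p.grad locally
    opt.zero_grad()
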
Step 4: Launch RPC Processes
----------------------------

Finally, the code below shows the target function for all processes. The main
logic is defined in ``run_master``. The workers passively wait for commands
from the master, and hence simply run ``init_rpc`` and ``shutdown``, where
``shutdown`` by default blocks until all RPC participants finish.

.. code:: python

    import os
    import time

    import torch.multiprocessing as mp


    def run_worker(rank, world_size, num_split):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)

        if rank == 0:
            rpc.init_rpc(
                "master",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            run_master(num_split)
        else:
            rpc.init_rpc(
                f"worker{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            # workers passively wait for RPCs from the master

        # block until all rpcs finish
        rpc.shutdown()


    if __name__=="__main__":
        world_size = 3
        for num_split in [1, 2, 4, 8]:
            tik = time.time()
            mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
            tok = time.time()
            print(f"number of splits = {num_split}, execution time = {tok - tik}")

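
Note that ``ProcessGroupRpcBackendOptions`` belongs to the RPC backend that was
current when this tutorial was written. On newer PyTorch releases, where
TensorPipe is the default RPC agent, the equivalent line would look roughly
like the sketch below; ``num_worker_threads`` is assumed here to play the role
of ``num_send_recv_threads``, so verify the option name against your installed
version.

.. code:: python

    # assumption: a PyTorch build where TensorPipe is the default RPC backend
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=128)
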
The output below shows the speedup attained by increasing the number of splits
in each batch.

::

    $ python main.py
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 1, execution time = 16.45062756538391
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 2, execution time = 12.329529762268066
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 4, execution time = 10.164430618286133
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 8, execution time = 9.076049566268921