Distributed Pipeline Parallelism Using RPC
==========================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Prerequisites:

-  `Single-Machine Model Parallel Best Practices <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
-  `Getting Started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
-  RRef helper functions: `RRef.rpc_sync() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__,
   `RRef.rpc_async() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_async>`__, and
   `RRef.remote() <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.remote>`__

This tutorial uses a ResNet50 model to demonstrate how to implement distributed
pipeline parallelism with `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__
APIs. It can be viewed as the distributed counterpart of the multi-GPU
pipeline parallelism discussed in
`Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_.


Basics
------


The previous tutorial, `Getting Started with Distributed RPC Framework <rpc_tutorial.html>`_,
shows how to use `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`_
to implement distributed model parallelism for an RNN model. That tutorial uses
one GPU to host the ``EmbeddingTable``, and the provided code works fine.
However, when a model spans multiple GPUs, some extra steps are needed to
increase the amortized utilization of all GPUs. Pipeline parallelism is one
paradigm that can help in this case.

In this tutorial, we use ``ResNet50`` as the example model, which is also used
by the `Single-Machine Model Parallel Best Practices <model_parallel_tutorial.html>`_
tutorial. Similarly, the ``ResNet50`` model is divided into two shards, and the
input batch is partitioned into multiple splits that are fed into the two model
shards in a pipelined fashion. The difference is that, instead of parallelizing
the execution using CUDA streams, this tutorial invokes asynchronous RPCs. So,
the solution presented in this tutorial also works across machine boundaries,
as the sketch below illustrates.
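
The core idiom combines two ``RRef`` helpers: ``RRef.remote()`` launches a
shard's ``forward`` on its owner and immediately returns an ``RRef`` of the
result, while ``RRef.rpc_async()`` chains the next shard on that intermediate
``RRef`` and returns a ``Future``. The snippet below is only a minimal sketch
of this idiom with a toy ``Stage`` module; the names ``Stage``, ``pipeline``,
``worker1``, and ``worker2`` are made up for illustration, and it assumes
``rpc.init_rpc`` has already been called on every participant (as done in
Step 4).

.. code:: python

    import torch
    import torch.nn as nn
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef


    class Stage(nn.Module):
        # toy stand-in for one model shard
        def forward(self, x_rref):
            # fetch the input referenced by the RRef and do some work
            return x_rref.to_here() * 2


    def pipeline(xs, s1_rref, s2_rref, num_split=4):
        futs = []
        for x in xs.chunk(num_split, dim=0):
            # stage 1: returns immediately with an RRef of the intermediate output
            y_rref = s1_rref.remote().forward(RRef(x))
            # stage 2: chained on the intermediate RRef, returns a Future
            futs.append(s2_rref.rpc_async().forward(y_rref))
        # wait for all micro-batches and re-assemble the full batch
        return torch.cat(torch.futures.wait_all(futs))


    # on the master, after rpc.init_rpc(...):
    # s1_rref = rpc.remote("worker1", Stage)
    # s2_rref = rpc.remote("worker2", Stage)
    # out = pipeline(torch.randn(8, 16), s1_rref, s2_rref)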
The remainder of this tutorial presents the implementation in four steps.



Step 1: Partition ResNet50 Model
--------------------------------

This is the preparation step which implements ``ResNet50`` in two model shards.
The code below is borrowed from the
`ResNet implementation in torchvision <https://github.com/pytorch/vision/blob/7c077f6a986f05383bcb86b535aedb5a63dd5c4b/torchvision/models/resnet.py#L124>`_.
The ``ResNetBase`` module contains the common building blocks and attributes for
the two ResNet shards.

.. code:: python

    import threading

    import torch
    import torch.nn as nn
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    from torchvision.models.resnet import Bottleneck

    num_classes = 1000


    def conv1x1(in_planes, out_planes, stride=1):
        # 1x1 convolution used in the downsampling path of Bottleneck blocks
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    class ResNetBase(nn.Module):
        def __init__(self, block, inplanes, num_classes=1000,
                     groups=1, width_per_group=64, norm_layer=None):
            super(ResNetBase, self).__init__()

            self._lock = threading.Lock()
            self._block = block
            self._norm_layer = nn.BatchNorm2d
            self.inplanes = inplanes
            self.dilation = 1
            self.groups = groups
            self.base_width = width_per_group

        def _make_layer(self, planes, blocks, stride=1):
            norm_layer = self._norm_layer
            downsample = None
            previous_dilation = self.dilation
            if stride != 1 or self.inplanes != planes * self._block.expansion:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * self._block.expansion, stride),
                    norm_layer(planes * self._block.expansion),
                )

            layers = []
            layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                      self.base_width, previous_dilation, norm_layer))
            self.inplanes = planes * self._block.expansion
            for _ in range(1, blocks):
                layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                          base_width=self.base_width, dilation=self.dilation,
                                          norm_layer=norm_layer))

            return nn.Sequential(*layers)

        def parameter_rrefs(self):
            # wrap each local parameter in an RRef so that the master can
            # collect them for the distributed optimizer
            return [RRef(p) for p in self.parameters()]

Now, we are ready to define the two model shards. For the constructor, we
simply split all ResNet50 layers into two parts and move each part onto the
provided device. The ``forward`` functions of both shards take an ``RRef`` of
the input data, fetch the data locally, and then move it to the expected device.
After applying all layers to the input, each shard moves the output to CPU and
returns it. This is because the RPC API requires tensors to reside on CPU, which
avoids invalid device errors when the numbers of devices on the caller and the
callee do not match.


.. code:: python

    class ResNetShard1(ResNetBase):
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard1, self).__init__(
                Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                self._norm_layer(self.inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                self._make_layer(64, 3),
                self._make_layer(128, 4, stride=2)
            ).to(self.device)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x_rref):
            # fetch the input referenced by the RRef and move it to this shard's device
            x = x_rref.to_here().to(self.device)
            # the lock serializes forward passes issued by concurrent RPC threads
            with self._lock:
                out = self.seq(x)
            # return on CPU, as required by the RPC layer
            return out.cpu()


    class ResNetShard2(ResNetBase):
        def __init__(self, device, *args, **kwargs):
            super(ResNetShard2, self).__init__(
                Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

            self.device = device
            self.seq = nn.Sequential(
                self._make_layer(256, 6, stride=2),
                self._make_layer(512, 3, stride=2),
                nn.AdaptiveAvgPool2d((1, 1)),
            ).to(self.device)

            self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

        def forward(self, x_rref):
            x = x_rref.to_here().to(self.device)
            with self._lock:
                out = self.fc(torch.flatten(self.seq(x), 1))
            return out.cpu()

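Before stitching the shards together, it can be handy to smoke-test one shard
in a single process. The sketch below is hypothetical and not part of the
tutorial's training script: it assumes one visible GPU and sets the rendezvous
environment variables so that ``rpc.init_rpc`` can be called with
``world_size=1``, which is enough to create and consume local ``RRef`` objects.

.. code:: python

    import os
    import torch
    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # a single-process RPC "group" is enough for a local sanity check
    rpc.init_rpc("worker0", rank=0, world_size=1)

    shard1 = ResNetShard1("cuda:0")
    out = shard1(RRef(torch.randn(2, 3, 128, 128)))
    print(out.shape)  # e.g. torch.Size([2, 512, 16, 16]) for this input size
    rpc.shutdown()
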
Step 2: Stitch ResNet50 Model Shards Into One Module
----------------------------------------------------


Then, we create a ``DistResNet50`` module to assemble the two shards and
implement the pipeline parallel logic. In the constructor, we use two
``rpc.remote`` calls to put the two shards on two different RPC workers
respectively, and hold on to the ``RRef`` handles of the two model parts so
that they can be referenced in the forward pass. The ``forward`` function
splits the input batch into multiple micro-batches and feeds these
micro-batches to the two model parts in a pipelined fashion. It first uses the
``RRef.remote()`` helper to apply the first shard to a micro-batch, and then
forwards the returned intermediate output ``RRef`` to the second model shard
with ``RRef.rpc_async()``. After that, it collects the ``Future`` objects of
all micro-batch outputs and waits for all of them after the loop. Note that
both ``remote()`` and ``rpc_async()`` return immediately and run
asynchronously. Therefore, the entire loop is non-blocking and will launch
multiple RPCs concurrently. The execution order of one micro-batch on the two
model parts is preserved by the intermediate output ``y_rref``. The execution
order across micro-batches does not matter. In the end, the ``forward``
function concatenates the outputs of all micro-batches into one single output
tensor and returns it. The ``parameter_rrefs`` function is a helper that
simplifies distributed optimizer construction, which will be used later.


.. code:: python

    class DistResNet50(nn.Module):
        """
        Assemble the two shards as an nn.Module and define the pipelining logic.
        """
        def __init__(self, num_split, workers, *args, **kwargs):
            super(DistResNet50, self).__init__()

            # number of micro-batches each input batch is divided into
            self.num_split = num_split

            # Put the first part of the ResNet50 on workers[0]
            self.p1_rref = rpc.remote(
                workers[0],
                ResNetShard1,
                args=("cuda:0",) + args,
                kwargs=kwargs
            )

            # Put the second part of the ResNet50 on workers[1]
            self.p2_rref = rpc.remote(
                workers[1],
                ResNetShard2,
                args=("cuda:1",) + args,
                kwargs=kwargs
            )

        def forward(self, xs):
            out_futures = []
            for x in xs.chunk(self.num_split, dim=0):
                x_rref = RRef(x)
                # run shard1 on workers[0]; returns an RRef of the intermediate output
                y_rref = self.p1_rref.remote().forward(x_rref)
                # chain shard2 on workers[1]; returns a Future of the final output
                z_fut = self.p2_rref.rpc_async().forward(y_rref)
                out_futures.append(z_fut)

            # collect and concatenate all micro-batch outputs
            return torch.cat(torch.futures.wait_all(out_futures))

        def parameter_rrefs(self):
            remote_params = []
            remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
            remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
            return remote_params


Step 3: Define The Training Loop
--------------------------------


After defining the model, let us implement the training loop. We use a
dedicated "master" worker to prepare random inputs and labels and to control
the distributed backward pass and distributed optimizer step. It first creates
an instance of the ``DistResNet50`` module, specifying the number of
micro-batches for each batch and the names of the two RPC workers (i.e.,
"worker1" and "worker2"). Then it defines the loss function and creates a
``DistributedOptimizer``, using the ``parameter_rrefs()`` helper to acquire a
list of parameter ``RRef`` handles. The main training loop is very similar to
regular local training, except that it uses ``dist_autograd`` to launch the
backward pass and provides the ``context_id`` to both the backward pass and the
optimizer ``step()``.

.. code:: python

    import torch.distributed.autograd as dist_autograd
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer

    num_batches = 3
    batch_size = 120
    image_w = 128
    image_h = 128


    def run_master(num_split):
        # put the two model parts on worker1 and worker2 respectively
        model = DistResNet50(num_split, ["worker1", "worker2"])
        loss_fn = nn.MSELoss()
        opt = DistributedOptimizer(
            optim.SGD,
            model.parameter_rrefs(),
            lr=0.05,
        )

        # each row of the labels will be a random one-hot vector
        one_hot_indices = torch.LongTensor(batch_size) \
                               .random_(0, num_classes) \
                               .view(batch_size, 1)

        for i in range(num_batches):
            print(f"Processing batch {i}")
            # generate random inputs and labels
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                          .scatter_(1, one_hot_indices, 1)

            # run forward, distributed backward, and the distributed optimizer
            # step inside one distributed autograd context
            with dist_autograd.context() as context_id:
                outputs = model(inputs)
                dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
                opt.step(context_id)

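For comparison, a regular single-process training loop over the same random
data would drive autograd and the optimizer directly, with no distributed
context or ``context_id``. The sketch below is illustrative only and is not
part of the distributed script; it reuses the constants and
``one_hot_indices`` defined above and uses ``torchvision.models.resnet50`` as
a stand-in local model on a single GPU.

.. code:: python

    import torch.optim as optim
    from torchvision.models import resnet50

    local_model = resnet50(num_classes=num_classes).to("cuda:0")
    loss_fn = nn.MSELoss()
    opt = optim.SGD(local_model.parameters(), lr=0.05)

    for i in range(num_batches):
        inputs = torch.randn(batch_size, 3, image_w, image_h).to("cuda:0")
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1).to("cuda:0")

        # local autograd and optimizer: no distributed context involved
        opt.zero_grad()
        loss_fn(local_model(inputs), labels).backward()
        opt.step()
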
Step 4: Launch RPC Processes
----------------------------


Finally, the code below shows the target function for all processes. The main
logic is defined in ``run_master``. The workers passively wait for commands
from the master, and hence simply run ``init_rpc`` and ``shutdown``, where
``shutdown`` by default blocks until all RPC participants finish.

.. code:: python

    import os
    import time

    import torch.multiprocessing as mp


    def run_worker(rank, world_size, num_split):
        # rendezvous information used by init_rpc
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)

        if rank == 0:
            rpc.init_rpc(
                "master",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            run_master(num_split)
        else:
            rpc.init_rpc(
                f"worker{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )
            # worker processes passively wait for commands from the master

        # block until all rpcs finish
        rpc.shutdown()


    if __name__ == "__main__":
        world_size = 3
        for num_split in [1, 2, 4, 8]:
            tik = time.time()
            mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
            tok = time.time()
            print(f"number of splits = {num_split}, execution time = {tok - tik}")

The output below shows the speedup attained by increasing the number of splits
in each batch.

::

    $ python main.py
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 1, execution time = 16.45062756538391
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 2, execution time = 12.329529762268066
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 4, execution time = 10.164430618286133
    Processing batch 0
    Processing batch 1
    Processing batch 2
    number of splits = 8, execution time = 9.076049566268921
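
Because the pipeline is driven by RPC rather than CUDA streams, the same model
code can also be launched across physical machines. The sketch below is a
hypothetical launcher and not part of the tutorial script: each machine starts
its own process instead of using ``mp.spawn``, ``MASTER_ADDR`` points at the
host running rank 0, and ``run_master`` from Step 3 is assumed to be importable
in this module. On single-GPU hosts you would also change the hard-coded
``"cuda:1"`` in ``DistResNet50`` to the locally available device.

.. code:: python

    # hypothetical launch.py, run once per machine, e.g.:
    #   on the master host:  python launch.py --rank 0
    #   on GPU host 1:       python launch.py --rank 1 --master-addr <master-host>
    #   on GPU host 2:       python launch.py --rank 2 --master-addr <master-host>
    import argparse
    import os

    import torch.distributed.rpc as rpc

    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", type=int, required=True)
    parser.add_argument("--world-size", type=int, default=3)
    parser.add_argument("--master-addr", type=str, default="localhost")
    parser.add_argument("--master-port", type=str, default="29500")
    args = parser.parse_args()

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)

    # rank 0 is the master; ranks 1 and 2 become "worker1" and "worker2"
    name = "master" if args.rank == 0 else f"worker{args.rank}"
    rpc.init_rpc(name, rank=args.rank, world_size=args.world_size,
                 rpc_backend_options=options)

    if args.rank == 0:
        run_master(num_split=8)  # run_master as defined in Step 3

    rpc.shutdown()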