Combining Distributed DataParallel with Distributed RPC Framework
=================================================================
**Author**: `Pritam Damania <https://github.com/pritamdamania87>`_


This tutorial uses a simple example to demonstrate how you can combine
`DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
with the `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
to train a simple model with a mix of distributed data parallelism and
distributed model parallelism. Source code of the example can be found `here <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__.

Previous tutorials,
`Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
and `Getting Started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__,
described how to perform distributed data parallel and distributed model
parallel training respectively. However, there are several training paradigms
where you might want to combine these two techniques. For example:

1) If we have a model with a sparse part (large embedding table) and a dense
   part (FC layers), we might want to put the embedding table on a parameter
   server and replicate the FC layer across multiple trainers using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
   The `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
   can be used to perform embedding lookups on the parameter server.
2) Enable hybrid parallelism as described in the `PipeDream <https://arxiv.org/abs/1806.03377>`__ paper.
   We can use the `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
   to pipeline stages of the model across multiple workers and replicate each
   stage (if needed) using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.

|
In this tutorial we will cover case 1 mentioned above. We have a total of 4
workers in our setup as follows:


1) 1 Master, which is responsible for creating an embedding table
   (nn.EmbeddingBag) on the parameter server. The master also drives the
   training loop on the two trainers.
2) 1 Parameter Server, which holds the embedding table in memory and
   responds to RPCs from the Master and Trainers.
3) 2 Trainers, which store an FC layer (nn.Linear) that is replicated amongst
   themselves using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
   The trainers are also responsible for executing the forward pass, the
   backward pass and the optimizer step.

|
The entire training process is executed as follows:

1) The master creates an embedding table on the Parameter Server and holds an
   `RRef <https://pytorch.org/docs/master/rpc.html#rref>`__ to it.
2) The master then kicks off the training loop on the trainers and passes the
   embedding table RRef to the trainers.
3) The trainers create a ``HybridModel``, which first performs an embedding lookup
   using the embedding table RRef provided by the master and then executes the
   FC layer which is wrapped inside DDP.
4) The trainer executes the forward pass of the model and uses the loss to
   execute the backward pass using `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__.
5) As part of the backward pass, the gradients for the FC layer are computed
   first and synced to all trainers via allreduce in DDP.
6) Next, Distributed Autograd propagates the gradients to the parameter server,
   where the gradients for the embedding table are updated.
7) Finally, the `Distributed Optimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__ is used to update all the parameters.


|
**NOTE**: You should always use `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__ for the backward pass if you're combining DDP and RPC.

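For example, instead of calling ``loss.backward()`` directly, each iteration wraps
the forward and backward pass in a distributed autograd context, along the lines
of the following sketch (``model``, ``criterion`` and the inputs refer to objects
defined later in this tutorial):

.. code:: python

    import torch.distributed.autograd as dist_autograd

    with dist_autograd.context() as context_id:
        output = model(indices, offsets)  # forward pass, may issue RPCs
        loss = criterion(output, target)
        # dist_autograd.backward() replaces loss.backward(); it propagates
        # gradients across RPC boundaries (e.g. to the parameter server).
        dist_autograd.backward(context_id, [loss])
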

Now, let's go through each part in detail. First, we need to set up all of our
workers before we can perform any training. We create 4 processes such that
ranks 0 and 1 are our trainers, rank 2 is the master and rank 3 is the
parameter server.

We initialize the RPC framework on all 4 workers using the TCP init_method.
Once RPC initialization is done, the master creates an `EmbeddingBag <https://pytorch.org/docs/master/generated/torch.nn.EmbeddingBag.html>`__
on the Parameter Server using `rpc.remote <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.remote>`__.
The master then loops through each trainer and kicks off the training loop by
calling ``_run_trainer`` on each trainer using `rpc_async <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.rpc_async>`__.
Finally, the master waits for all training to finish before exiting.

The trainers first initialize a ``ProcessGroup`` for DDP with world_size=2
(for two trainers) using `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
Next, they initialize the RPC framework using the TCP init_method. Note that
the ports are different in RPC initialization and ProcessGroup initialization.
This is to avoid port conflicts between the initialization of the two frameworks.
Once the initialization is done, the trainers just wait for the ``_run_trainer``
RPC from the master.

The parameter server just initializes the RPC framework and waits for RPCs from
the trainers and master.

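The snippets below are excerpts from the full example. For reference, they assume
roughly the following imports and constants (a minimal sketch, assuming a PyTorch
release that provides ``ProcessGroupRpcBackendOptions``; the value of
``NUM_EMBEDDINGS`` is illustrative, while ``EMBEDDING_DIM`` matches the input
size of the FC layer used later):

.. code:: python

    import random

    import torch
    import torch.distributed as dist
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer
    from torch.distributed.rpc import ProcessGroupRpcBackendOptions, RRef
    from torch.nn.parallel import DistributedDataParallel as DDP

    NUM_EMBEDDINGS = 100  # illustrative size of the embedding table
    EMBEDDING_DIM = 16    # matches the in_features of the FC layer below
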
.. code:: python

    def run_worker(rank, world_size):
        r"""
        A wrapper function that initializes RPC, calls the function, and shuts down
        RPC.
        """

        # We need to use different port numbers in TCP init_method for init_rpc and
        # init_process_group to avoid port conflicts.
        rpc_backend_options = ProcessGroupRpcBackendOptions()
        rpc_backend_options.init_method = 'tcp://localhost:29501'

        # Rank 2 is the master, 3 is the parameter server and 0 and 1 are trainers.
        if rank == 2:
            rpc.init_rpc(
                "master",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc_backend_options)

            # Build the embedding table on the parameter server.
            emb_rref = rpc.remote(
                "ps",
                torch.nn.EmbeddingBag,
                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
                kwargs={"mode": "sum"})

            # Run the training loop on the trainers.
            futs = []
            for trainer_rank in [0, 1]:
                trainer_name = "trainer{}".format(trainer_rank)
                fut = rpc.rpc_async(
                    trainer_name, _run_trainer, args=(emb_rref, trainer_rank))
                futs.append(fut)

            # Wait for all training to finish.
            for fut in futs:
                fut.wait()
        elif rank <= 1:
            # Initialize process group for Distributed DataParallel on trainers.
            dist.init_process_group(
                backend="gloo", rank=rank, world_size=2,
                init_method='tcp://localhost:29500')

            # Initialize RPC.
            trainer_name = "trainer{}".format(rank)
            rpc.init_rpc(
                trainer_name,
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc_backend_options)

            # The trainer just waits for RPCs from the master.
        else:
            rpc.init_rpc(
                "ps",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc_backend_options)
            # The parameter server does nothing besides responding to RPCs.
            pass

        # Block until all RPCs finish.
        rpc.shutdown()


    if __name__ == "__main__":
        # 2 trainers, 1 parameter server, 1 master.
        world_size = 4
        mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

|
Before we discuss details of the Trainer, let's introduce the ``HybridModel`` that
the trainer uses. As described below, the ``HybridModel`` is initialized using an
RRef to the embedding table (emb_rref) on the parameter server and the ``device``
to use for DDP. The initialization of the model wraps an
`nn.Linear <https://pytorch.org/docs/master/generated/torch.nn.Linear.html>`__
layer inside DDP to replicate and synchronize this layer across all trainers.

The forward method of the model is pretty straightforward. It performs an
embedding lookup on the parameter server using an
`RRef helper <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__
and passes its output on to the FC layer.


.. code:: python

    class HybridModel(torch.nn.Module):
        r"""
        The model consists of a sparse part and a dense part. The dense part is an
        nn.Linear module that is replicated across all trainers using
        DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
        stored on the parameter server.

        The model holds a Remote Reference to the embedding table on the parameter
        server.
        """

        def __init__(self, emb_rref, device):
            super(HybridModel, self).__init__()
            self.emb_rref = emb_rref
            self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
            self.device = device

        def forward(self, indices, offsets):
            emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
            return self.fc(emb_lookup.cuda(self.device))

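The ``emb_rref.rpc_sync().forward(indices, offsets)`` call above uses the RRef
helper to run the lookup on the owner of the RRef (the parameter server) and
blocks until the result is returned. Conceptually, it behaves roughly like
issuing the RPC by hand with a small helper, as in this sketch (``_call_method``
is a hypothetical helper and not part of the tutorial's code):

.. code:: python

    def _call_method(method, rref, *args, **kwargs):
        # Runs on the RRef's owner and invokes the method on the local value.
        return method(rref.local_value(), *args, **kwargs)

    # Roughly what emb_rref.rpc_sync().forward(indices, offsets) does:
    emb_lookup = rpc.rpc_sync(
        emb_rref.owner(),
        _call_method,
        args=(torch.nn.EmbeddingBag.forward, emb_rref, indices, offsets))
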
|
Next, let's look at the setup on the Trainer. The trainer first creates the
``HybridModel`` described above using an RRef to the embedding table on the
parameter server and its own rank.

Now, we need to retrieve a list of RRefs to all the parameters that we would
like to optimize with `DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__.
To retrieve the parameters for the embedding table from the parameter server,
we define a simple helper function ``_retrieve_embedding_parameters``, which
walks through all the parameters of the embedding table and returns
a list of RRefs. The trainer calls this method on the parameter server via RPC
to receive a list of RRefs to the desired parameters. Since the
DistributedOptimizer always takes a list of RRefs to the parameters that need to
be optimized, we need to create RRefs even for the local parameters of our
FC layers. This is done by walking ``model.parameters()``, creating an RRef for
each parameter and appending it to a list. Note that ``model.parameters()`` only
returns local parameters and doesn't include ``emb_rref``.

Finally, we create our DistributedOptimizer using all the RRefs and define a
CrossEntropyLoss function.

.. code:: python

    def _retrieve_embedding_parameters(emb_rref):
        return [RRef(p) for p in emb_rref.local_value().parameters()]


    def _run_trainer(emb_rref, rank):
        r"""
        Each trainer runs a forward pass which involves an embedding lookup on the
        parameter server and running nn.Linear locally. During the backward pass,
        DDP is responsible for aggregating the gradients for the dense part
        (nn.Linear) and distributed autograd ensures that gradient updates are
        propagated to the parameter server.
        """

        # Setup the model.
        model = HybridModel(emb_rref, rank)

        # Retrieve all model parameters as RRefs for DistributedOptimizer.

        # Retrieve parameters for the embedding table.
        model_parameter_rrefs = rpc.rpc_sync(
            "ps", _retrieve_embedding_parameters, args=(emb_rref,))

        # model.parameters() only includes local parameters.
        for param in model.parameters():
            model_parameter_rrefs.append(RRef(param))

        # Setup distributed optimizer.
        opt = DistributedOptimizer(
            optim.SGD,
            model_parameter_rrefs,
            lr=0.05,
        )

        criterion = torch.nn.CrossEntropyLoss()

|
Now we're ready to introduce the main training loop that is run on each trainer.
``get_next_batch`` is just a helper function to generate random inputs and
targets for training. We run the training loop for multiple epochs and for each
batch:

1) Set up a `Distributed Autograd Context <https://pytorch.org/docs/master/rpc.html#torch.distributed.autograd.context>`__
   for the iteration.
2) Run the forward pass of the model and retrieve its output.
3) Compute the loss based on our outputs and targets using the loss function.
4) Use Distributed Autograd to execute a distributed backward pass using the loss.
5) Finally, run a Distributed Optimizer step to optimize all the parameters.

.. code:: python

    # def _run_trainer(emb_rref, rank): continued...

    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs.
    for epoch in range(100):
        for indices, offsets, target in get_next_batch(rank):
            # Create a distributed autograd context for this iteration.
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass.
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer.
                opt.step(context_id)

                # Not necessary to zero grads since each iteration creates a
                # different distributed autograd context which hosts different grads.
        print("Training done for epoch {}".format(epoch))

|
|
Source code for the entire example can be found `here <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__.