@@ -294,3 +294,80 @@ either the application or the model ``forward()`` method.
        run_demo(demo_basic, world_size)
        run_demo(demo_checkpoint, world_size)
        run_demo(demo_model_parallel, world_size)

Initialize DDP with torch.distributed.run/torchrun
----------------------------------------------------

We can leverage PyTorch Elastic to simplify the DDP code and initialize the job more easily.
Let's still use the ``ToyModel`` example and create a file named ``elastic_ddp.py``.

.. code:: python

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.optim as optim

    from torch.nn.parallel import DistributedDataParallel as DDP

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def demo_basic():
        # torchrun sets MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in the
        # environment, so no explicit arguments are needed here.
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")

        # create model and move it to GPU with id rank
        device_id = rank % torch.cuda.device_count()
        model = ToyModel().to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device_id)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        # clean up the process group after training
        dist.destroy_process_group()

    if __name__ == "__main__":
        demo_basic()

One can then run a `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command
on all nodes to initialize the DDP job created above:

.. code:: bash

    torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py

This runs the DDP script on two hosts with 8 processes per host, that is, on 16 GPUs in total.
Note that ``$MASTER_ADDR`` must be the same across all nodes.
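
This also explains why ``elastic_ddp.py`` calls ``dist.init_process_group("nccl")`` without an
explicit rank or world size: ``torchrun`` exports the rendezvous information (``MASTER_ADDR``,
``MASTER_PORT``, ``RANK``, ``WORLD_SIZE``, and ``LOCAL_RANK``) as environment variables, and the
default ``env://`` initialization reads them. As a minimal sketch that is not part of the tutorial
script above, the local device could equally be selected from ``LOCAL_RANK`` instead of computing
``rank % torch.cuda.device_count()``:

.. code:: python

    import os

    import torch
    import torch.distributed as dist

    # torchrun sets LOCAL_RANK to the process index on this node, which is a
    # convenient choice for the CUDA device id.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dist.init_process_group("nccl")
    print(f"Global rank {dist.get_rank()} is using cuda:{local_rank}.")
    dist.destroy_process_group()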

On each node, ``torchrun`` will launch 8 processes and invoke ``elastic_ddp.py`` in each of
them, but the user also needs a cluster management tool such as SLURM to actually run this
command on 2 nodes.

For example, on a SLURM-enabled cluster, we can write a script that runs the command above
and sets ``MASTER_ADDR`` as:

.. code:: bash

    export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)

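Putting the pieces together, the ``torchrun_script.sh`` referenced below might look like the
following sketch; the node and GPU counts and the rendezvous port simply mirror the command
shown earlier:

.. code:: bash

    #!/bin/bash
    # Resolve the first node in the SLURM allocation and use it as the
    # rendezvous endpoint shared by all nodes.
    export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)

    # Launch 8 workers on this node; all nodes rendezvous at $MASTER_ADDR:29400.
    torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d \
        --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py
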
Then we can just run this script using the SLURM command: ``srun --nodes=2 ./torchrun_script.sh``.
Of course, this is just an example; you can choose your own cluster scheduling tools
to initiate the torchrun job.
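
If you prefer batch submission over an interactive ``srun``, a hypothetical SLURM batch script
could wrap the same command; the directives below are standard ``sbatch`` options, and the
resource numbers simply mirror the two-node, 8-GPU-per-node example:

.. code:: bash

    #!/bin/bash
    #SBATCH --job-name=elastic_ddp
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=1       # one torchrun launcher per node
    #SBATCH --gpus-per-node=8         # or --gres=gpu:8, depending on the cluster

    # srun starts one copy of the launcher script on each allocated node.
    srun ./torchrun_script.sh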

For more information about Elastic run, please refer to the
`quick start document <https://pytorch.org/docs/stable/elastic/quickstart.html>`__.