Commit 46e4726

Add example on how to run DDP with PyTorch elastic (#1948)
* Add example on how to run DDP with PyTorch elastic
  Add one example of how to initialize DDP on 2 hosts with 8 processes on each.
* Fix typo
* Address the comment from the reviewer
* Address comments from reviewers.
1 parent 8a33b50 commit 46e4726

File tree

1 file changed: +77, -0 lines changed

intermediate_source/ddp_tutorial.rst

Lines changed: 77 additions & 0 deletions
@@ -294,3 +294,80 @@ either the application or the model ``forward()`` method.
    run_demo(demo_basic, world_size)
    run_demo(demo_checkpoint, world_size)
    run_demo(demo_model_parallel, world_size)

Initialize DDP with torch.distributed.run/torchrun
---------------------------------------------------

We can leverage PyTorch Elastic to simplify the DDP code and initialize the job more easily.
Let's still use the ``ToyModel`` example and create a file named ``elastic_ddp.py``.

.. code:: python

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.optim as optim

    from torch.nn.parallel import DistributedDataParallel as DDP

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def demo_basic():
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")

        # create model and move it to GPU with id rank
        device_id = rank % torch.cuda.device_count()
        model = ToyModel().to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device_id)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    if __name__ == "__main__":
        demo_basic()
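
If you want to quickly verify ``elastic_ddp.py`` on a single machine before launching across nodes,
the `torchrun quickstart <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ describes a
standalone launch; assuming 8 local GPUs, it would look along these lines:

.. code:: bash

    # single-node sanity check; adjust --nproc_per_node to your local GPU count
    torchrun --standalone --nnodes=1 --nproc_per_node=8 elastic_ddp.py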

One can then run a `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command
on all nodes to initialize the DDP job created above:

.. code:: bash

    torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py

We are running the DDP script on two hosts, with 8 processes on each host; that is, we
are running it on 16 GPUs. Note that ``$MASTER_ADDR`` must be the same across all nodes.

Here, ``torchrun`` will launch 8 processes and invoke ``elastic_ddp.py``
on each process of the node it is launched on, but the user also needs to apply cluster
management tools like SLURM to actually run this command on 2 nodes.

For example, on a SLURM-enabled cluster, we can write a script that runs the command above
and sets ``MASTER_ADDR`` as:

.. code:: bash

    export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
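
Putting the pieces together, such a ``torchrun_script.sh`` might look roughly like the sketch
below; the file name is the one used with ``srun`` in the next sentence, and the body simply
combines the ``MASTER_ADDR`` export with the ``torchrun`` command shown earlier:

.. code:: bash

    #!/bin/bash
    # Resolve the first node in the SLURM allocation and use it as the rendezvous host.
    export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)

    # Launch 8 workers on this node; all nodes rendezvous at $MASTER_ADDR:29400.
    torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d \
        --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py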

Then we can just run this script using the SLURM command: ``srun --nodes=2 ./torchrun_script.sh``.
Of course, this is just an example; you can choose your own cluster scheduling tools
to initiate the ``torchrun`` job.

For more information about Elastic run, please refer to the
`quick start document <https://pytorch.org/docs/stable/elastic/quickstart.html>`__.
