
Commit f8465c3

Adding an Overview Page for PyTorch Distributed (#1056)
* Adding an Overview Page for PyTorch Distributed
* Let existing PT Distributed tutorials link to the overview page
* Add a link to AMP
* Address Comments
* Remove unnecessary dist.barrier()
1 parent 2f3ab79 commit f8465c3

File tree: 9 files changed (+245, -17 lines)

beginner_source/dist_overview.rst

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
PyTorch Distributed Overview
============================
**Author**: `Shen Li <https://mrshenli.github.io/>`_


This is the overview page for the ``torch.distributed`` package. As more and
more documents, examples, and tutorials are added in different locations, it
becomes unclear which document or tutorial to consult for a specific problem,
or what the best order is to read them. The goal of this page is to address
this by categorizing the documents into different topics and briefly describing
each of them. If this is your first time building distributed training
applications with PyTorch, it is recommended to use this document to navigate
to the technology that can best serve your use case.

Introduction
------------

As of PyTorch v1.6.0, features in ``torch.distributed`` can be categorized into
three main components:

* `Distributed Data-Parallel Training <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
  (DDP) is a widely adopted single-program multiple-data training paradigm. With
  DDP, the model is replicated on every process, and every model replica is fed
  a different set of input data samples. DDP takes care of gradient
  communication to keep model replicas synchronized and overlaps it with
  gradient computation to speed up training.
* `RPC-Based Distributed Training <https://pytorch.org/docs/master/rpc.html>`__
  (RPC) supports general training structures that cannot fit into data-parallel
  training, such as distributed pipeline parallelism, the parameter server
  paradigm, and combinations of DDP with other training paradigms. It helps
  manage remote object lifetimes and extends the autograd engine beyond machine
  boundaries.
* `Collective Communication <https://pytorch.org/docs/stable/distributed.html>`__
  (c10d) is the library that supports sending tensors across processes within a
  group. It offers both collective communication APIs (e.g.,
  `all_reduce <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce>`__
  and `all_gather <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather>`__)
  and P2P communication APIs (e.g.,
  `send <https://pytorch.org/docs/stable/distributed.html#torch.distributed.send>`__
  and `isend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.isend>`__).
  DDP and RPC (`ProcessGroup Backend <https://pytorch.org/docs/master/rpc.html#process-group-backend>`__)
  are built on c10d as of v1.6.0, where the former uses collective communication
  and the latter uses P2P communication. Usually, developers do not need to
  directly use this raw communication API, as the DDP and RPC features above can
  serve many distributed training scenarios. However, there are use cases where
  this API is still helpful. One example is distributed parameter averaging,
  where applications compute the average values of all model parameters after
  the backward pass instead of using DDP to communicate gradients. This
  decouples communication from computation and allows finer-grained control
  over what to communicate, but it also gives up the performance optimizations
  offered by DDP. The
  `Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
  tutorial shows examples of using c10d communication APIs, and a minimal
  parameter-averaging sketch follows this list.

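The sketch below is an illustration only, not part of any linked tutorial; the
``average_parameters`` helper is a hypothetical name, and it assumes the default
process group has already been initialized with ``init_process_group``.

.. code:: python

    import torch
    import torch.distributed as dist

    def average_parameters(model: torch.nn.Module):
        """Average model parameters across all processes after the backward pass."""
        world_size = dist.get_world_size()
        with torch.no_grad():
            for param in model.parameters():
                # Sum this parameter across all processes in-place ...
                dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
                # ... then divide by the number of processes to get the average.
                param.data /= world_size
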
Most of the existing documents are written for either DDP or RPC; the remainder
of this page elaborates on the materials for these two components.


Data Parallel Training
----------------------

PyTorch provides several options for data-parallel training. For applications
that gradually grow from simple to complex and from prototype to production, a
common development trajectory is:

1. Use single-device training if the data and model can fit on one GPU and
   training speed is not a concern.
2. Use single-machine multi-GPU
   `DataParallel <https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html>`__
   if there are multiple GPUs on the server and you would like to speed up
   training with a minimal code change.
3. Use single-machine multi-GPU
   `DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
   if you would like to further speed up training and are willing to write a
   little more code to set it up.
4. Use multi-machine `DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
   and the `launching script <https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md>`__
   if the application needs to scale across machine boundaries.
5. Use `torchelastic <https://pytorch.org/elastic>`__ to launch distributed
   training if errors (e.g., OOM) are expected or if resources can join and
   leave dynamically during training.


.. note:: Data-parallel training also works with `Automatic Mixed Precision (AMP) <https://pytorch.org/docs/master/notes/amp_examples.html#working-with-multiple-gpus>`__.

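As a rough illustration (not taken from the AMP examples linked above), a
single AMP training step with a DDP model might look like the sketch below; the
``ddp_model``, ``optimizer``, ``data_loader``, and ``device`` names are
assumptions for this example.

.. code:: python

    import torch
    from torch.cuda.amp import GradScaler, autocast

    scaler = GradScaler()

    for inputs, targets in data_loader:
        optimizer.zero_grad()
        # Run the forward pass under autocast so eligible ops use float16.
        with autocast():
            outputs = ddp_model(inputs.to(device))
            loss = torch.nn.functional.cross_entropy(outputs, targets.to(device))
        # Scale the loss to avoid gradient underflow, then backprop and step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
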
``torch.nn.DataParallel``
~~~~~~~~~~~~~~~~~~~~~~~~~

The `DataParallel <https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html>`__
package enables single-machine multi-GPU parallelism with the lowest coding
hurdle: it only requires a one-line change to the application code. The tutorial
`Optional: Data Parallelism <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__
shows an example. The caveat is that, although ``DataParallel`` is very easy to
use, it usually does not offer the best performance. This is because the
implementation of ``DataParallel`` replicates the model in every forward pass,
and its single-process multi-thread parallelism naturally suffers from GIL
contention. To get better performance, please consider using
`DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__.

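That one-line change is sketched below; the toy model and input are assumptions
for illustration, and a machine with one or more CUDA GPUs is assumed.

.. code:: python

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 10)        # any nn.Module
    model = nn.DataParallel(model)   # the one-line change: replicate across visible GPUs
    model = model.to("cuda")

    inputs = torch.randn(64, 10, device="cuda")
    outputs = model(inputs)          # the batch is scattered to GPUs, outputs gathered back
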
``torch.nn.parallel.DistributedDataParallel``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Compared to `DataParallel <https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html>`__,
`DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
requires one more step to set up, i.e., calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
DDP uses multi-process parallelism, and hence there is no GIL contention across
model replicas. Moreover, the model is broadcast at DDP construction time instead
of in every forward pass, which also helps to speed up training. DDP ships with
several performance optimization technologies. For a more in-depth explanation,
please refer to this `DDP paper <https://arxiv.org/abs/2006.15704>`__ (VLDB'20).

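For orientation, a minimal single-machine DDP script might look like the sketch
below; the ``gloo`` backend, the environment-variable rendezvous, and the toy
model are assumptions for this example, and the materials listed next cover the
details.

.. code:: python

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank, world_size):
        # The extra setup step compared to DataParallel: join a process group.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        model = nn.Linear(10, 10)
        ddp_model = DDP(model)  # parameters are broadcast from rank 0 here

        loss = ddp_model(torch.randn(20, 10)).sum()
        loss.backward()         # gradients are all-reduced across processes

        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)
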
DDP materials are listed below:

1. `DDP notes <https://pytorch.org/docs/stable/notes/ddp.html>`__
   offer a starter example and some brief descriptions of its design and
   implementation. If this is your first time using DDP, please start from this
   document.
2. `Getting Started with Distributed Data Parallel <../intermediate/ddp_tutorial.html>`__
   explains some common problems with DDP training, including unbalanced
   workload, checkpointing, and multi-device models. Note that DDP can be
   easily combined with single-machine multi-device model parallelism, which is
   described in the
   `Single-Machine Model Parallel Best Practices <../intermediate/model_parallel_tutorial.html>`__
   tutorial.
3. The `Launching and configuring distributed data parallel applications <https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md>`__
   document shows how to use the DDP launching script.
4. The `PyTorch Distributed Trainer with Amazon AWS <aws_distributed_training_tutorial.html>`__
   tutorial demonstrates how to use DDP on AWS.

TorchElastic
~~~~~~~~~~~~

With the growth of application complexity and scale, failure recovery becomes
an imperative requirement. Sometimes it is inevitable to hit errors like OOM
when using DDP, but DDP itself cannot recover from those errors, nor can a
basic ``try-except`` block handle them. This is because DDP requires all
processes to operate in a closely synchronized manner, and all ``AllReduce``
communications launched in different processes must match. If one of the
processes in the group throws an OOM exception, it is likely to lead to
desynchronization (mismatched ``AllReduce`` operations), which would then cause
a crash or hang. If you expect failures to occur during training or if
resources might leave and join dynamically, please launch distributed
data-parallel training using `torchelastic <https://pytorch.org/elastic>`__.


General Distributed Training
----------------------------

Many training paradigms do not fit into data parallelism, e.g., the parameter
server paradigm, distributed pipeline parallelism, and reinforcement learning
applications with multiple observers or agents. The
`torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__ package
aims to support general distributed training scenarios.

The `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__ package
has four main pillars (a minimal usage sketch follows the list):

* `RPC <https://pytorch.org/docs/master/rpc.html#rpc>`__ supports running
  a given function on a remote worker.
* `RRef <https://pytorch.org/docs/master/rpc.html#rref>`__ helps to manage the
  lifetime of a remote object. The reference counting protocol is presented in
  the `RRef notes <https://pytorch.org/docs/master/rpc/rref.html#remote-reference-protocol>`__.
* `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__
  extends the autograd engine beyond machine boundaries. Please refer to
  `Distributed Autograd Design <https://pytorch.org/docs/master/rpc/distributed_autograd.html#distributed-autograd-design>`__
  for more details.
* `Distributed Optimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__
  automatically reaches out to all participating workers to update parameters
  using gradients computed by the distributed autograd engine.

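The sketch below illustrates only the first two pillars (RPC and RRef); the
two-worker setup, the worker names, and the use of ``torch.add`` are
assumptions made for this example and do not come from the tutorials listed
next.

.. code:: python

    import os
    import torch
    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp

    def run(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

        if rank == 0:
            # RPC: run torch.add on worker1 and block until the result returns.
            result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
            # RRef: keep a reference to a tensor that lives on worker1.
            rref = rpc.remote("worker1", torch.zeros, args=(2, 2))
            print(result, rref.to_here())

        rpc.shutdown()  # blocks until all workers are done

    if __name__ == "__main__":
        world_size = 2
        mp.spawn(run, args=(world_size,), nprocs=world_size)
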
RPC tutorials are listed below:

1. The `Getting Started with Distributed RPC Framework <../intermediate/rpc_tutorial.html>`__
   tutorial first uses a simple Reinforcement Learning (RL) example to
   demonstrate RPC and RRef. Then, it applies basic distributed model
   parallelism to an RNN example to show how to use distributed autograd and
   the distributed optimizer.
2. The `Implementing a Parameter Server Using Distributed RPC Framework <../intermediate/rpc_param_server_tutorial.html>`__
   tutorial borrows the spirit of
   `HogWild! training <https://people.eecs.berkeley.edu/~brecht/papers/hogwildTR.pdf>`__
   and applies it to an asynchronous parameter server (PS) training application.
3. The `Distributed Pipeline Parallelism Using RPC <../intermediate/dist_pipeline_parallel_tutorial.html>`__
   tutorial extends the single-machine pipeline parallel example (presented in
   `Single-Machine Model Parallel Best Practices <../intermediate/model_parallel_tutorial.html>`__)
   to a distributed environment and shows how to implement it using RPC.
4. The `Implementing Batch RPC Processing Using Asynchronous Executions <../intermediate/rpc_async_execution.html>`__
   tutorial demonstrates how to implement RPC batch processing using the
   `@rpc.functions.async_execution <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.functions.async_execution>`__
   decorator, which can help speed up inference and training. It uses RL and PS
   examples similar to those in tutorials 1 and 2 above.

index.rst

Lines changed: 15 additions & 7 deletions
@@ -297,6 +297,13 @@ Welcome to PyTorch Tutorials
 
 .. Parallel-and-Distributed-Training
 
+.. customcarditem::
+   :header: PyTorch Distributed Overview
+   :card_description: Briefly go over all concepts and features in the distributed package. Use this document to find the distributed training technology that can best serve your application.
+   :image: _static/img/thumbnails/cropped/PyTorch-Distributed-Overview.png
+   :link: beginner/dist_overview.html
+   :tags: Parallel-and-Distributed-Training
+
 .. customcarditem::
    :header: Single-Machine Model Parallel Best Practices
    :card_description: Learn how to implement model parallel, a distributed training technique which splits a single model onto different GPUs, rather than replicating the entire model on each GPU
@@ -311,6 +318,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/ddp_tutorial.html
    :tags: Parallel-and-Distributed-Training
 
+.. customcarditem::
+   :header: (advanced) PyTorch 1.0 Distributed Trainer with Amazon AWS
+   :card_description: Set up the distributed package of PyTorch, use the different communication strategies, and go over some of the internals of the package.
+   :image: _static/img/thumbnails/cropped/advanced-PyTorch-1point0-Distributed-Trainer-with-Amazon-AWS.png
+   :link: beginner/aws_distributed_training_tutorial.html
+   :tags: Parallel-and-Distributed-Training
+
 .. customcarditem::
    :header: Writing Distributed Applications with PyTorch
    :card_description: Set up the distributed package of PyTorch, use the different communication strategies, and go over some of the internals of the package.
@@ -325,13 +339,6 @@ Welcome to PyTorch Tutorials
    :link: intermediate/rpc_tutorial.html
    :tags: Parallel-and-Distributed-Training
 
-.. customcarditem::
-   :header: (advanced) PyTorch 1.0 Distributed Trainer with Amazon AWS
-   :card_description: Set up the distributed package of PyTorch, use the different communication strategies, and go over some of the internals of the package.
-   :image: _static/img/thumbnails/cropped/advanced-PyTorch-1point0-Distributed-Trainer-with-Amazon-AWS.png
-   :link: beginner/aws_distributed_training_tutorial.html
-   :tags: Parallel-and-Distributed-Training
-
 .. customcarditem::
    :header: Implementing a Parameter Server Using Distributed RPC Framework
    :card_description: Walk through a simple example of implementing a parameter server using PyTorch’s Distributed RPC framework.
@@ -513,6 +520,7 @@ Additional Resources
    :hidden:
    :caption: Parallel and Distributed Training
 
+   beginner/dist_overview
    intermediate/model_parallel_tutorial
    intermediate/ddp_tutorial
    intermediate/dist_tuto

intermediate_source/ddp_tutorial.rst

Lines changed: 10 additions & 3 deletions
@@ -2,6 +2,13 @@ Getting Started with Distributed Data Parallel
 =================================================
 **Author**: `Shen Li <https://mrshenli.github.io/>`_
 
+Prerequisites:
+
+- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
+- `DistributedDataParallel API documents <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
+- `DistributedDataParallel notes <https://pytorch.org/docs/master/notes/ddp.html>`__
+
+
 `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
 (DDP) implements data parallelism at the module level which can run across
 multiple machines. Applications using DDP should spawn multiple processes and
@@ -202,9 +209,9 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
     loss_fn(outputs, labels).backward()
     optimizer.step()
 
-    # Use a barrier() to make sure that all processes have finished reading the
-    # checkpoint
-    dist.barrier()
+    # Not necessary to use a dist.barrier() to guard the file deletion below
+    # as the AllReduce ops in the backward pass of DDP already served as
+    # a synchronization.
 
     if rank == 0:
         os.remove(CHECKPOINT_PATH)

intermediate_source/dist_pipeline_parallel_tutorial.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Distributed Pipeline Parallelism Using RPC
 
 Prerequisites:
 
+- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
 - `Single-Machine Model Parallel Best Practices <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
 - `Getting started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
 - RRef helper functions:

intermediate_source/dist_tuto.rst

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,10 @@ Writing Distributed Applications with PyTorch
 =============================================
 **Author**: `Séb Arnold <https://seba1511.com>`_
 
+Prerequisites:
+
+- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
+
 In this short tutorial, we will be going over the distributed package
 of PyTorch. We'll see how to set up the distributed setting, use the
 different communication strategies, and go over some of the internals of

intermediate_source/rpc_async_execution.rst

Lines changed: 3 additions & 2 deletions
@@ -5,8 +5,9 @@ Implementing Batch RPC Processing Using Asynchronous Executions
 
 Prerequisites:
 
-- `Getting started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__
-- `Implementing a Parameter Server using Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html>`__
+- `PyTorch Distributed Overview <../beginner/dist_overview.html>`__
+- `Getting started with Distributed RPC Framework <rpc_tutorial.html>`__
+- `Implementing a Parameter Server using Distributed RPC Framework <rpc_param_server_tutorial.html>`__
 - `RPC Asynchronous Execution Decorator <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.functions.async_execution>`__
 
 This tutorial demonstrates how to build batch-processing RPC applications with
