
Commit 28f044e

Author: Jessica Lin

Merge pull request #1021 from jamesr66a/fork_join
Add TorchScript fork/join tutorial

2 parents f7d7360 + 353174f commit 28f044e

File tree: 3 files changed, +286 -0 lines changed

Image file: 228 KB (binary, not rendered)

Lines changed: 278 additions & 0 deletions
Dynamic Parallelism in TorchScript
==================================

In this tutorial, we introduce the syntax for doing *dynamic inter-op parallelism*
in TorchScript. This parallelism has the following properties:

* dynamic - The number of parallel tasks created and their workload can depend on the control flow of the program.
* inter-op - The parallelism is concerned with running TorchScript program fragments in parallel. This is distinct from *intra-op parallelism*, which is concerned with splitting up individual operators and running subsets of the operator's work in parallel.

Basic Syntax
------------
The two important APIs for dynamic parallelism are:

* ``torch.jit.fork(fn : Callable[..., T], *args, **kwargs) -> torch.jit.Future[T]``
* ``torch.jit.wait(fut : torch.jit.Future[T]) -> T``

A good way to demonstrate how these work is by way of an example:
.. code-block:: python

    import torch

    def foo(x):
        return torch.neg(x)

    @torch.jit.script
    def example(x):
        # Call `foo` using parallelism:
        # First, we "fork" off a task. This task will run `foo` with argument `x`
        future = torch.jit.fork(foo, x)

        # Call `foo` normally
        x_normal = foo(x)

        # Second, we "wait" on the task. Since the task may be running in
        # parallel, we have to "wait" for its result to become available.
        # Notice that by having lines of code between the "fork()" and "wait()"
        # call for a given Future, we can overlap computations so that they
        # run in parallel.
        x_parallel = torch.jit.wait(future)

        return x_normal, x_parallel

    print(example(torch.ones(1))) # (-1., -1.)

``fork()`` takes the callable ``fn`` along with the arguments to that callable,
``args`` and ``kwargs``, and creates an asynchronous task for the execution of ``fn``.
``fn`` can be a function, method, or Module instance. ``fork()`` returns a
reference to the value of the result of this execution, called a ``Future``.
Because ``fork`` returns immediately after creating the async task, ``fn`` may
not have been executed by the time the line of code after the ``fork()`` call
is executed. Thus, ``wait()`` is used to wait for the async task to complete
and return the value.
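
The example above forks a free function, but the same call works with a ``Module``
instance (its ``forward`` is what runs asynchronously) or a method. Below is a
minimal sketch of forking a submodule; the ``ForkedLinear`` module and its sizes
are made up purely for illustration:

.. code-block:: python

    import torch

    class ForkedLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x : torch.Tensor) -> torch.Tensor:
            # Fork a Module instance: this schedules `self.linear`'s forward()
            # as an asynchronous task.
            fut = torch.jit.fork(self.linear, x)

            # Other work here overlaps with the forked task
            y = torch.relu(x)

            return torch.jit.wait(fut) + y

    mod = torch.jit.script(ForkedLinear())
    print(mod(torch.ones(2, 4)))
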
These constructs can be used to overlap the execution of statements within a
function (shown in the worked example section) or be composed with other language
constructs like loops:

.. code-block:: python

    import torch
    from typing import List

    def foo(x):
        return torch.neg(x)

    @torch.jit.script
    def example(x):
        futures : List[torch.jit.Future[torch.Tensor]] = []
        for _ in range(100):
            futures.append(torch.jit.fork(foo, x))

        results = []
        for future in futures:
            results.append(torch.jit.wait(future))

        return torch.sum(torch.stack(results))

    print(example(torch.ones([])))

.. note::

    When we initialized an empty list of Futures, we needed to add an explicit
    type annotation to ``futures``. In TorchScript, empty containers default
    to assuming they contain Tensor values, so we annotate the list constructor
    as being of type ``List[torch.jit.Future[torch.Tensor]]``.

This example uses ``fork()`` to launch 100 instances of the function ``foo``,
waits on the 100 tasks to complete, then sums the results, returning ``-100.0``.
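
Because the parallelism is dynamic, the number of forked tasks can itself depend
on runtime values. Here is a small sketch; the positive-sum check is an arbitrary
stand-in for whatever condition your program actually cares about:

.. code-block:: python

    import torch
    from typing import List

    def foo(x):
        return torch.neg(x)

    @torch.jit.script
    def example(xs : List[torch.Tensor]):
        futures : List[torch.jit.Future[torch.Tensor]] = []
        # Only fork a task for inputs that need processing; how many tasks get
        # created is decided by runtime control flow.
        for x in xs:
            if bool(x.sum() > 0):
                futures.append(torch.jit.fork(foo, x))

        results : List[torch.Tensor] = []
        for fut in futures:
            results.append(torch.jit.wait(fut))
        return results

    print(example([torch.ones(2), -torch.ones(2), torch.ones(2)]))
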
Applied Example: Ensemble of Bidirectional LSTMs
------------------------------------------------

Let's try to apply parallelism to a more realistic example and see what sort
of performance we can get out of it. First, let's define the baseline model: an
ensemble of bidirectional LSTM layers.
.. code-block:: python

    import torch, time

    # In RNN parlance, the dimensions we care about are:
    # # of time-steps (T)
    # Batch size (B)
    # Hidden size/number of "channels" (C)
    T, B, C = 50, 50, 1024

    # A module that defines a single "bidirectional LSTM". This is simply two
    # LSTMs applied to the same sequence, but one in reverse
    class BidirectionalRecurrentLSTM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cell_f = torch.nn.LSTM(input_size=C, hidden_size=C)
            self.cell_b = torch.nn.LSTM(input_size=C, hidden_size=C)

        def forward(self, x : torch.Tensor) -> torch.Tensor:
            # Forward layer
            output_f, _ = self.cell_f(x)

            # Backward layer. Flip input in the time dimension (dim 0), apply the
            # layer, then flip the outputs in the time dimension
            x_rev = torch.flip(x, dims=[0])
            output_b, _ = self.cell_b(x_rev)
            output_b_rev = torch.flip(output_b, dims=[0])

            return torch.cat((output_f, output_b_rev), dim=2)


    # An "ensemble" of `BidirectionalRecurrentLSTM` modules. The modules in the
    # ensemble are run one-by-one on the same input, then their results are
    # stacked and summed together, returning the combined result.
    class LSTMEnsemble(torch.nn.Module):
        def __init__(self, n_models):
            super().__init__()
            self.n_models = n_models
            self.models = torch.nn.ModuleList([
                BidirectionalRecurrentLSTM() for _ in range(self.n_models)])

        def forward(self, x : torch.Tensor) -> torch.Tensor:
            results = []
            for model in self.models:
                results.append(model(x))
            return torch.stack(results).sum(dim=0)

    # For a head-to-head comparison to what we're going to do with fork/wait, let's
    # instantiate the model and compile it with TorchScript
    ens = torch.jit.script(LSTMEnsemble(n_models=4))

    # Normally you would pull this input out of an embedding table, but for the
    # purpose of this demo let's just use random data.
    x = torch.rand(T, B, C)

    # Let's run the model once to warm up things like the memory allocator
    ens(x)

    x = torch.rand(T, B, C)

    # Let's see how fast it runs!
    s = time.time()
    ens(x)
    print('Inference took', time.time() - s, 'seconds')

On my machine, this network runs in ``2.05`` seconds. We can do a lot better!

Parallelizing Forward and Backward Layers
-----------------------------------------

A very simple thing we can do is parallelize the forward and backward layers
within ``BidirectionalRecurrentLSTM``. For this, the structure of the computation
is static, so we don't actually even need any loops. Let's rewrite the ``forward``
method of ``BidirectionalRecurrentLSTM`` like so:
.. code-block:: python

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        # Forward layer - fork() so this can run in parallel to the backward
        # layer
        future_f = torch.jit.fork(self.cell_f, x)

        # Backward layer. Flip input in the time dimension (dim 0), apply the
        # layer, then flip the outputs in the time dimension
        x_rev = torch.flip(x, dims=[0])
        output_b, _ = self.cell_b(x_rev)
        output_b_rev = torch.flip(output_b, dims=[0])

        # Retrieve the output from the forward layer. Note this needs to happen
        # *after* the stuff we want to parallelize with
        output_f, _ = torch.jit.wait(future_f)

        return torch.cat((output_f, output_b_rev), dim=2)

In this example, ``forward()`` delegates execution of ``cell_f`` to another thread,
while it continues to execute ``cell_b``. This causes the execution of both
cells to be overlapped with each other.

Running the script again with this simple modification yields a runtime of
``1.71`` seconds, an improvement of ``17%``!
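
As an aside, how much overlap you actually get depends on the size of the thread
pool that executes forked tasks (the inter-op pool). If you want to experiment
with it, PyTorch exposes the pool size; a small sketch (note that the pool size
can only be set once, before any inter-op work has run):

.. code-block:: python

    import torch

    # Inspect the size of the inter-op thread pool used for forked tasks
    print(torch.get_num_interop_threads())

    # Optionally change it; this must happen before any parallel work starts
    # torch.set_num_interop_threads(8)
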
Aside: Visualizing Parallelism
------------------------------

We're not done optimizing our model, but it's worth introducing the tooling we
have for visualizing performance. One important tool is the `PyTorch profiler <https://pytorch.org/docs/stable/autograd.html#profiler>`_.

Let's use the profiler along with the Chrome trace export functionality to
visualize the performance of our parallelized model:
.. code-block:: python

    with torch.autograd.profiler.profile() as prof:
        ens(x)
    prof.export_chrome_trace('parallel.json')

This snippet of code will write out a file named ``parallel.json``. If you
navigate Google Chrome to ``chrome://tracing``, click the ``Load`` button, and
load in that JSON file, you should see a timeline like the following:
.. image:: https://i.imgur.com/rm5hdG9.png

The horizontal axis of the timeline represents time and the vertical axis
represents threads of execution. As we can see, we are running two ``lstm``
instances at a time. This is the result of our hard work parallelizing the
bidirectional layers!
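
If you prefer to stay in the terminal, the same ``prof`` object can also print an
aggregated summary table instead of (or in addition to) the Chrome trace; a quick
sketch:

.. code-block:: python

    # Print a per-operator summary of the profile, sorted by total CPU time
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
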
Parallelizing Models in the Ensemble
------------------------------------

You may have noticed that there is a further parallelization opportunity in our
code: we can also run the models contained in ``LSTMEnsemble`` in parallel with
each other. The way to do that is simple enough; this is how we should change
the ``forward`` method of ``LSTMEnsemble``:
.. code-block:: python

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        # Launch tasks for each model
        # (note: `List` comes from `from typing import List` at the top of the file)
        futures : List[torch.jit.Future[torch.Tensor]] = []
        for model in self.models:
            futures.append(torch.jit.fork(model, x))

        # Collect the results from the launched tasks
        results : List[torch.Tensor] = []
        for future in futures:
            results.append(torch.jit.wait(future))

        return torch.stack(results).sum(dim=0)

Or, if you value brevity, we can use list comprehensions:

.. code-block:: python

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        futures = [torch.jit.fork(model, x) for model in self.models]
        results = [torch.jit.wait(fut) for fut in futures]
        return torch.stack(results).sum(dim=0)

As described in the intro, we've used loops to fork off tasks for each of the
models in our ensemble. We've then used another loop to wait for all of the
tasks to complete. This provides even more overlap of computation.

With this small update, the script runs in ``1.4`` seconds, for a total speedup
of ``32%``! Pretty good for two lines of code.
We can also use the Chrome tracer again to see what's going on:

.. image:: https://i.imgur.com/kA0gyQm.png

We can now see that all ``LSTM`` instances are being run fully in parallel.
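
For reference, here is the fully parallelized version with both changes applied
in one place. This is simply the pieces from the sections above stitched
together; your exact timings will of course vary from machine to machine:

.. code-block:: python

    import torch, time

    T, B, C = 50, 50, 1024

    class BidirectionalRecurrentLSTM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cell_f = torch.nn.LSTM(input_size=C, hidden_size=C)
            self.cell_b = torch.nn.LSTM(input_size=C, hidden_size=C)

        def forward(self, x : torch.Tensor) -> torch.Tensor:
            # Forward layer runs asynchronously while we compute the backward layer
            future_f = torch.jit.fork(self.cell_f, x)
            x_rev = torch.flip(x, dims=[0])
            output_b, _ = self.cell_b(x_rev)
            output_b_rev = torch.flip(output_b, dims=[0])
            output_f, _ = torch.jit.wait(future_f)
            return torch.cat((output_f, output_b_rev), dim=2)

    class LSTMEnsemble(torch.nn.Module):
        def __init__(self, n_models):
            super().__init__()
            self.n_models = n_models
            self.models = torch.nn.ModuleList([
                BidirectionalRecurrentLSTM() for _ in range(self.n_models)])

        def forward(self, x : torch.Tensor) -> torch.Tensor:
            # Fork one task per ensemble member, then wait on all of them
            futures = [torch.jit.fork(model, x) for model in self.models]
            results = [torch.jit.wait(fut) for fut in futures]
            return torch.stack(results).sum(dim=0)

    ens = torch.jit.script(LSTMEnsemble(n_models=4))
    x = torch.rand(T, B, C)
    ens(x)  # warm-up run

    s = time.time()
    ens(x)
    print('Inference took', time.time() - s, 'seconds')
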
Conclusion
----------

In this tutorial, we learned about ``fork()`` and ``wait()``, the basic APIs
for doing dynamic, inter-op parallelism in TorchScript. We saw a few typical
patterns for using these functions to parallelize the execution of
functions, methods, or ``Modules`` in TorchScript code. Finally, we worked through
an example of optimizing a model using this technique and explored the performance
measurement and visualization tooling available in PyTorch.

index.rst

Lines changed: 8 additions & 0 deletions

@@ -244,6 +244,13 @@ Welcome to PyTorch Tutorials
    :link: advanced/torch_script_custom_classes.html
    :tags: Frontend-APIs,TorchScript,C++

+.. customcarditem::
+   :header: Dynamic Parallelism in TorchScript
+   :card_description: This tutorial introduces the syntax for doing *dynamic inter-op parallelism* in TorchScript.
+   :image: _static/img/thumbnails/cropped/TorchScript-Parallelism.jpg
+   :link: advanced/torch-script-parallelism.html
+   :tags: Frontend-APIs,TorchScript,C++
+
 .. customcarditem::
    :header: Autograd in C++ Frontend
    :card_description: The autograd package helps build flexible and dynamic neural networks. In this tutorial, explore several examples of doing autograd in the PyTorch C++ frontend
@@ -471,6 +478,7 @@ Additional Resources
    advanced/cpp_extension
    advanced/torch_script_custom_ops
    advanced/torch_script_custom_classes
+   advanced/torch-script-parallelism
    advanced/cpp_autograd

 .. toctree::
