
Commit ccadd5f

James Reed authored and soumith committed
Intro to TorchScript tutorial (#592)
1 parent 8816366 commit ccadd5f

File tree

3 files changed

+389
-0
lines changed


_static/img/torchscript.png

44.3 KB
Lines changed: 384 additions & 0 deletions
@@ -0,0 +1,384 @@
"""
Introduction to TorchScript
===========================

*James Reed ([email protected]), Michael Suo ([email protected])*, rev2

In this tutorial we will cover:

1. The basics of model authoring in PyTorch, including:

-  Modules
-  Defining ``forward`` functions
-  Composing modules into a hierarchy of modules

2. Methods for converting PyTorch modules to TorchScript, our
   high-performance deployment runtime

-  Tracing an existing module
-  Using scripting to directly compile a module
-  How to compose both approaches
-  Saving and loading TorchScript modules

"""

import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)

######################################################################
# Basics of PyTorch Model Authoring
# ---------------------------------
#
# Let’s start out by defining a simple ``Module``. A ``Module`` is the
# basic unit of composition in PyTorch. It contains:
#
# 1. A constructor, which prepares the module for invocation
# 2. A set of ``Parameters`` and sub-\ ``Modules``. These are initialized
#    by the constructor and can be used by the module during invocation.
# 3. A ``forward`` function. This is the code that is run when the module
#    is invoked.
#
# Let’s examine a small example:
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

######################################################################
# So we’ve:
#
# 1. Created a class that subclasses ``torch.nn.Module``.
# 2. Defined a constructor. The constructor doesn’t do much, just calls
#    the constructor for ``super``.
# 3. Defined a ``forward`` function, which takes two inputs and returns
#    two outputs. The actual contents of the ``forward`` function are not
#    really important, but it’s sort of a fake `RNN
#    cell <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__,
#    that is, a function that is applied in a loop.
#
# We instantiated the module, and made ``x`` and ``h``, which are just 3x4
# matrices of random values. Then we invoked the cell with
# ``my_cell(x, h)``. This in turn calls our ``forward`` function.
#
# Let’s do something a little more interesting:
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

######################################################################
# We’ve redefined our module ``MyCell``, but this time we’ve added a
# ``self.linear`` attribute, and we invoke ``self.linear`` in the forward
# function.
#
# What exactly is happening here? ``torch.nn.Linear`` is a ``Module`` from
# the PyTorch standard library. Just like ``MyCell``, it can be invoked
# using the call syntax. We are building a hierarchy of ``Module``\ s.
#
# ``print`` on a ``Module`` will give a visual representation of the
# ``Module``\ ’s subclass hierarchy. In our example, we can see our
# ``Linear`` subclass and its parameters.
#
# By composing ``Module``\ s in this way, we can succinctly and readably
# author models with reusable components.
#
# You may have noticed ``grad_fn`` on the outputs. This is a detail of
# PyTorch’s method of automatic differentiation, called
# `autograd <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__.
# In short, this system allows us to compute derivatives through
# potentially complex programs. The design allows for a massive amount of
# flexibility in model authoring.
#
# Now let’s examine said flexibility:
#

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

######################################################################
# We’ve once again redefined our ``MyCell`` class, but here we’ve defined
# ``MyDecisionGate``. This module utilizes **control flow**. Control flow
# consists of things like loops and ``if``-statements.
#
# Many frameworks take the approach of computing symbolic derivatives
# given a full program representation. However, in PyTorch, we use a
# gradient tape. We record operations as they occur, and replay them
# backwards in computing derivatives. In this way, the framework does not
# have to explicitly define derivatives for all constructs in the
# language.
#
# .. figure:: https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/dynamic_graph.gif
#    :alt: How autograd works
#
#    How autograd works
#

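######################################################################
# As a small illustration (a minimal sketch added here, not a definitive
# recipe), we can watch the tape in action: running the cell records the
# operations that actually executed, and calling ``backward`` replays them
# to produce gradients for the ``Linear`` parameters.
#

new_h, _ = my_cell(x, h)                 # forward pass records operations on the tape
new_h.sum().backward()                   # replay the tape backwards to compute gradients
print(my_cell.linear.weight.grad.shape)  # gradients now populate the parameters
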
######################################################################
# Basics of TorchScript
# ---------------------
#
# Now let’s take our running example and see how we can apply TorchScript.
#
# In short, TorchScript provides tools to capture the definition of your
# model, even in light of the flexible and dynamic nature of PyTorch.
# Let’s begin by examining what we call **tracing**.
#
# Tracing ``Modules``
# ~~~~~~~~~~~~~~~~~~~
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

######################################################################
# We’ve rewound a bit and taken the second version of our ``MyCell``
# class. As before, we’ve instantiated it, but this time, we’ve called
# ``torch.jit.trace``, passed in the ``Module``, and passed in *example
# inputs* the network might see.
#
# What exactly has this done? It has invoked the ``Module``, recorded the
# operations that occurred when the ``Module`` was run, and created an
# instance of ``torch.jit.ScriptModule`` (of which ``TracedModule`` is an
# instance).
#
# TorchScript records its definitions in an Intermediate Representation
# (or IR), commonly referred to in deep learning as a *graph*. We can
# examine the graph with the ``.graph`` property:
#

print(traced_cell.graph)

######################################################################
# However, this is a very low-level representation and most of the
# information contained in the graph is not useful for end users. Instead,
# we can use the ``.code`` property to give a Python-syntax interpretation
# of the code:
#

print(traced_cell.code)

######################################################################
# So **why** did we do all this? There are several reasons:
#
# 1. TorchScript code can be invoked in its own interpreter, which is
#    basically a restricted Python interpreter. This interpreter does not
#    acquire the Global Interpreter Lock, and so many requests can be
#    processed on the same instance simultaneously.
# 2. This format allows us to save the whole model to disk and load it
#    into another environment, such as in a server written in a language
#    other than Python.
# 3. TorchScript gives us a representation in which we can do compiler
#    optimizations on the code to provide more efficient execution.
# 4. TorchScript allows us to interface with many backend/device runtimes
#    that require a broader view of the program than individual operators.
#
# We can see that invoking ``traced_cell`` produces the same results as
# the Python module:
#

print(my_cell(x, h))
print(traced_cell(x, h))

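######################################################################
# As a quick hedged check (an addition for illustration, not part of the
# original text), the traced module should agree with eager execution
# numerically, since it replays exactly the operations that were recorded:
#

assert torch.allclose(my_cell(x, h)[0], traced_cell(x, h)[0])
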
######################################################################
# Using Scripting to Convert Modules
# ----------------------------------
#
# There’s a reason we used version two of our module, and not the one with
# the control-flow-laden submodule. Let’s examine that now:
#

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)

######################################################################
# Looking at the ``.code`` output, we can see that the ``if-else`` branch
# is nowhere to be found! Why? Tracing does exactly what we said it would:
# run the code, record the operations *that happen*, and construct a
# ``ScriptModule`` that does exactly that. Unfortunately, things like
# control flow are erased.
#
# How can we faithfully represent this module in TorchScript? We provide a
# **script compiler**, which does direct analysis of your Python source
# code to transform it into TorchScript. Let’s convert ``MyDecisionGate``
# using the script compiler:
#

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)

######################################################################
# Hooray! We’ve now faithfully captured the behavior of our program in
# TorchScript. Let’s now try running the program:
#

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)

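######################################################################
# As a quick sanity check (an illustrative addition, assuming an
# all-positive and an all-negative input), the scripted gate really does
# preserve both branches of the ``if``-statement:
#

print(scripted_gate(torch.ones(3, 4)))    # sum > 0: input returned unchanged
print(scripted_gate(-torch.ones(3, 4)))   # sum < 0: negated input returned
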
######################################################################
# Mixing Scripting and Tracing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Some situations call for using tracing rather than scripting (e.g. a
# module has many architectural decisions that are made based on constant
# Python values that we would like not to appear in TorchScript). In this
# case, scripting can be composed with tracing: ``torch.jit.script`` will
# inline the code for a traced module, and tracing will inline the code
# for a scripted module.
#
# An example of the first case:
#

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

######################################################################
# And an example of the second case:
#

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

######################################################################
# This way, scripting and tracing can each be used when the situation
# calls for them, and the two can be used together.
#
# Saving and Loading models
# -------------------------
#
# We provide APIs to save and load TorchScript modules to/from disk in an
# archive format. This format includes code, parameters, attributes, and
# debug information, meaning that the archive is a freestanding
# representation of the model that can be loaded in an entirely separate
# process. Let’s save and load our wrapped RNN module:
#

traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)
print(loaded.code)

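######################################################################
# As a hedged follow-up (illustrative only), the loaded module behaves
# like the original: it accepts the same inputs and runs without the
# original Python source being present.
#

print(loaded(torch.rand(10, 3, 4)))
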
######################################################################
# As you can see, serialization preserves the module hierarchy and the
# code we’ve been examining throughout. The model can also be loaded, for
# example, `into
# C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`__ for
# Python-free execution.
#
# Further Reading
# ~~~~~~~~~~~~~~~
#
# We’ve completed our tutorial! For a more involved demonstration, check
# out the NeurIPS demo for converting machine translation models using
# TorchScript:
# https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ
#

index.rst

Lines changed: 5 additions & 0 deletions

@@ -200,6 +200,11 @@ Extending PyTorch

Production Usage
----------------------

.. customgalleryitem::
   :tooltip: Introduction to TorchScript
   :description: :doc:`beginner/Intro_to_TorchScript`
   :figure: _static/img/torchscript.png

.. customgalleryitem::
   :tooltip: Loading a PyTorch model in C++
   :description: :doc:`advanced/cpp_export`