
Commit 3e32d22

Port custom ops tutorial to new registration API, increase testability.
Signed-off-by: Edward Z. Yang <[email protected]>
1 parent 3740027 commit 3e32d22

File tree

4 files changed: +145 -127 lines changed


advanced_source/torch_script_custom_ops.rst

Lines changed: 105 additions & 125 deletions
@@ -124,7 +124,7 @@ like this:
124124
:end-before: END output_tensor
125125

126126
We use the ``.ptr<float>()`` method on the OpenCV ``Mat`` class to get a raw
127-
pointer to the underlying data (just like ``.data<float>()`` for the PyTorch
127+
pointer to the underlying data (just like ``.data_ptr<float>()`` for the PyTorch
128128
tensor earlier). We also specify the output shape of the tensor, which we
129129
hardcoded as ``8 x 8``. The output of ``torch::from_blob`` is then a
130130
``torch::Tensor``, pointing to the memory owned by the OpenCV matrix.
@@ -145,40 +145,28 @@ Registering the Custom Operator with TorchScript
145145
Now that we have implemented our custom operator in C++, we need to *register* it
146146
with the TorchScript runtime and compiler. This will allow the TorchScript
147147
compiler to resolve references to our custom operator in TorchScript code.
148-
Registration is very simple. For our case, we need to write:
148+
If you have ever used the pybind11 library, our syntax for registration
149+
resembles the pybind11 syntax very closely. To register a single function,
150+
we write:
149151

150152
.. literalinclude:: ../advanced_source/torch_script_custom_ops/op.cpp
151153
:language: cpp
152154
:start-after: BEGIN registry
153155
:end-before: END registry
154156

155-
somewhere in the global scope of our ``op.cpp`` file. This creates a global
156-
variable ``registry``, which will register our operator with TorchScript in its
157-
constructor (i.e. exactly once per program). We specify the name of the
158-
operator, and a pointer to its implementation (the function we wrote earlier).
159-
The name consists of two parts: a *namespace* (``my_ops``) and a name for the
160-
particular operator we are registering (``warp_perspective``). The namespace and
161-
operator name are separated by two colons (``::``).
157+
somewhere at the top level of our ``op.cpp`` file. The ``TORCH_LIBRARY`` macro
158+
creates a function that will be called when your program starts. The name
159+
of your library (``my_ops``) is given as the first argument (it should not
160+
be in quotes). The second argument (``m``) defines a variable of type
161+
``torch::Library``, which is the main interface to register your operators.
162+
The method ``Library::def`` actually creates an operator named ``warp_perspective``,
163+
exposing it to both Python and TorchScript. You can define as many operators
164+
as you like by making multiple calls to ``def``.
162165

163-
.. tip::
164-
165-
If you want to register more than one operator, you can chain calls to
166-
``.op()`` after the constructor:
167-
168-
.. code-block:: cpp
169-
170-
static auto registry =
171-
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective)
172-
.op("my_ops::another_op", &another_op)
173-
.op("my_ops::and_another_op", &and_another_op);
174-
175-
Behind the scenes, ``RegisterOperators`` will perform a number of fairly
176-
complicated C++ template metaprogramming magic tricks to infer the argument and
177-
return value types of the function pointer we pass it (``&warp_perspective``).
178-
This information is used to form a *function schema* for our operator. A
179-
function schema is a structured representation of an operator -- a kind of
180-
"signature" or "prototype" -- used by the TorchScript compiler to verify
181-
correctness in TorchScript programs.
166+
Behind the scenes, the ``def`` function is actually doing quite a bit of work:
167+
it is using template metaprogramming to inspect the type signature of your
168+
function and translate it into an operator schema which specifies the operator's
169+
type within TorchScript's type system.
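Once the shared library from the build step below has been loaded, the operator created by this ``def`` call is reachable from Python under the library's namespace. A minimal sketch, assuming the library was built to ``build/libwarp_perspective.so`` as in the smoke test further down::

    import torch

    # Assumes the CMake build described below has already produced the shared library.
    torch.ops.load_library("build/libwarp_perspective.so")

    # TORCH_LIBRARY(my_ops, ...) makes the operator available as torch.ops.my_ops.warp_perspective.
    op = torch.ops.my_ops.warp_perspective
    print(op)                                               # a built-in method handle
    print(op(torch.randn(32, 32), torch.rand(3, 3)).shape)  # torch.Size([8, 8]) for this example op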
182170

183171
Building the Custom Operator
184172
----------------------------
@@ -189,7 +177,16 @@ we can load into Python for research and experimentation, or into C++ for
189177
inference in a no-Python environment. There exist multiple ways to build our
190178
operator, using either pure CMake, or Python alternatives like ``setuptools``.
191179
For brevity, the paragraphs below only discuss the CMake approach. The appendix
192-
of this tutorial dives into the Python based alternatives.
180+
of this tutorial dives into other alternatives.
181+
182+
Environment setup
183+
*****************
184+
185+
We need an installation of PyTorch and OpenCV. The easiest and most platform
186+
independent way to get both is via Conda::
187+
188+
conda install -c pytorch pytorch
189+
conda install opencv
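As a quick sanity check (a sketch, assuming both conda packages were installed into the currently active environment), the two installs can be verified from Python::

    import torch
    import cv2  # Python bindings shipped with the conda opencv package

    # If both imports succeed, the CMake build below should be able to locate
    # the corresponding headers and libraries as well.
    print(torch.__version__)
    print(cv2.__version__)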
193190

194191
Building with CMake
195192
*******************
@@ -203,29 +200,11 @@ a directory structure that looks like this::
203200
op.cpp
204201
CMakeLists.txt
205202

206-
Also, make sure to grab the latest version of the LibTorch distribution, which
207-
packages PyTorch's C++ libraries and CMake build files, from `pytorch.org
208-
<https://pytorch.org/get-started/locally>`_. Place the unzipped distribution
209-
somewhere accessible in your file system. The following paragraphs will refer to
210-
that location as ``/path/to/libtorch``. The contents of our ``CMakeLists.txt``
211-
file should then be the following:
203+
The contents of our ``CMakeLists.txt`` file should then be the following:
212204

213205
.. literalinclude:: ../advanced_source/torch_script_custom_ops/CMakeLists.txt
214206
:language: cpp
215207

216-
.. warning::
217-
218-
This setup makes some assumptions about the build environment, particularly
219-
what pertains to the installation of OpenCV. The above ``CMakeLists.txt`` file
220-
was tested inside a Docker container running Ubuntu Xenial with
221-
``libopencv-dev`` installed via ``apt``. If it does not work for you and you
222-
feel stuck, please use the ``Dockerfile`` in the `accompanying tutorial
223-
repository <https://github.com/pytorch/extension-script>`_ to
224-
build an isolated, reproducible environment in which to play around with the
225-
code from this tutorial. If you run into further troubles, please file an
226-
issue in the tutorial repository or post a question in `our forum
227-
<https://discuss.pytorch.org/>`_.
228-
229208
To now build our operator, we can run the following commands from our
230209
``warp_perspective`` folder:
231210

@@ -268,24 +247,18 @@ To now build our operator, we can run the following commands from our
268247
[100%] Built target warp_perspective
269248
270249
which will place a ``libwarp_perspective.so`` shared library file in the
271-
``build`` folder. In the ``cmake`` command above, you should replace
272-
``/path/to/libtorch`` with the path to your unzipped LibTorch distribution.
250+
``build`` folder. In the ``cmake`` command above, we use the helper
251+
variable ``torch.utils.cmake_prefix_path`` to conveniently tell us where
252+
the cmake files for our PyTorch install are.
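If it helps, the value of that helper variable can be printed directly (a sketch; the exact path differs per installation) and passed to ``cmake`` as ``-DCMAKE_PREFIX_PATH=...``::

    import torch.utils

    # Directory containing PyTorch's CMake package configuration files, e.g. for
    # cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
    print(torch.utils.cmake_prefix_path)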
273253
274254
We will explore how to use and call our operator in detail further below, but to
275255
get an early sensation of success, we can try running the following code in
276256
Python:
277257
278-
.. code-block:: python
279-
280-
>>> import torch
281-
>>> torch.ops.load_library("/path/to/libwarp_perspective.so")
282-
>>> print(torch.ops.my_ops.warp_perspective)
283-
284-
Here, ``/path/to/libwarp_perspective.so`` should be a relative or absolute path
285-
to the ``libwarp_perspective.so`` shared library we just built. If all goes
286-
well, this should print something like
258+
.. literalinclude:: ../advanced_source/torch_script_custom_ops/smoke_test.py
259+
:language: python
287260
288-
.. code-block:: python
261+
If all goes well, this should print something like::
289262
290263
<built-in method my_ops::warp_perspective of PyCapsule object at 0x7f618fc6fa50>
291264
@@ -302,10 +275,9 @@ TorchScript code.
302275
You already saw how to import your operator into Python:
303276
``torch.ops.load_library()``. This function takes the path to a shared library
304277
containing custom operators, and loads it into the current process. Loading the
305-
shared library will also execute the constructor of the global
306-
``RegisterOperators`` object we placed into our custom operator implementation
307-
file. This will register our custom operator with the TorchScript compiler and
308-
allow us to use that operator in TorchScript code.
278+
shared library will also execute the ``TORCH_LIBRARY`` block. This will register
279+
our custom operator with the TorchScript compiler and allow us to use that
280+
operator in TorchScript code.
309281
310282
You can refer to your loaded operator as ``torch.ops.<namespace>.<function>``,
311283
where ``<namespace>`` is the namespace part of your operator name, and
@@ -316,11 +288,16 @@ While this function can be used in scripted or traced TorchScript modules, we
316288
can also just use it in vanilla eager PyTorch and pass it regular PyTorch
317289
tensors:
318290
291+
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
292+
:language: python
293+
:prepend: import torch
294+
:start-after: BEGIN preamble
295+
:end-before: END preamble
296+
297+
producing:
298+
319299
.. code-block:: python
320300
321-
>>> import torch
322-
>>> torch.ops.load_library("libwarp_perspective.so")
323-
>>> torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3))
324301
tensor([[0.0000, 0.3218, 0.4611, ..., 0.4636, 0.4636, 0.4636],
325302
[0.3746, 0.0978, 0.5005, ..., 0.4636, 0.4636, 0.4636],
326303
[0.3245, 0.0169, 0.0000, ..., 0.4458, 0.4458, 0.4458],
@@ -332,90 +309,92 @@ tensors:
332309
333310
.. note::
334311
335-
What happens behind the scenes is that the first time you access
336-
``torch.ops.namespace.function`` in Python, the TorchScript compiler (in C++
337-
land) will see if a function ``namespace::function`` has been registered, and
338-
if so, return a Python handle to this function that we can subsequently use to
339-
call into our C++ operator implementation from Python. This is one noteworthy
340-
difference between TorchScript custom operators and C++ extensions: C++
341-
extensions are bound manually using pybind11, while TorchScript custom ops are
342-
bound on the fly by PyTorch itself. Pybind11 gives you more flexibility with
343-
regards to what types and classes you can bind into Python and is thus
344-
recommended for purely eager code, but it is not supported for TorchScript
345-
ops.
312+
What happens behind the scenes is that the first time you access
313+
``torch.ops.namespace.function`` in Python, the TorchScript compiler (in C++
314+
land) will see if a function ``namespace::function`` has been registered, and
315+
if so, return a Python handle to this function that we can subsequently use to
316+
call into our C++ operator implementation from Python. This is one noteworthy
317+
difference between TorchScript custom operators and C++ extensions: C++
318+
extensions are bound manually using pybind11, while TorchScript custom ops are
319+
bound on the fly by PyTorch itself. Pybind11 gives you more flexibility with
320+
regards to what types and classes you can bind into Python and is thus
321+
recommended for purely eager code, but it is not supported for TorchScript
322+
ops.
346323
347324
From here on, you can use your custom operator in scripted or traced code just
348325
as you would other functions from the ``torch`` package. In fact, "standard
349326
library" functions like ``torch.matmul`` go through largely the same
350327
registration path as custom operators, which makes custom operators really
351328
first-class citizens when it comes to how and where they can be used in
352-
TorchScript.
329+
TorchScript. (One difference, however, is that standard library functions
330+
have custom-written Python argument parsing logic that differs from
331+
``torch.ops`` argument parsing.)
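The operator also composes with ``torch.jit.script`` in the same way as with tracing below; a minimal sketch (the helper function name is ours, and the library path assumes the ``build/`` layout used in the smoke test)::

    import torch

    torch.ops.load_library("build/libwarp_perspective.so")

    @torch.jit.script
    def warp_then_matmul(x, y):
        # The custom op is resolved by the TorchScript compiler like any torch.* function.
        x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
        return x.matmul(y)

    print(warp_then_matmul.graph)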
353332
354333
Using the Custom Operator with Tracing
355334
**************************************
356335
357336
Let's start by embedding our operator in a traced function. Recall that for
358337
tracing, we start with some vanilla PyTorch code:
359338
360-
.. code-block:: python
361-
362-
def compute(x, y, z):
363-
return x.matmul(y) + torch.relu(z)
339+
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
340+
:language: python
341+
:start-after: BEGIN compute
342+
:end-before: END compute
364343
365344
and then call ``torch.jit.trace`` on it. We further pass ``torch.jit.trace``
366345
some example inputs, which it will forward to our implementation to record the
367346
sequence of operations that occur as the inputs flow through it. The result of
368347
this is effectively a "frozen" version of the eager PyTorch program, which the
369348
TorchScript compiler can further analyze, optimize and serialize:
370349
371-
.. code-block:: python
350+
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
351+
:language: python
352+
:start-after: BEGIN trace
353+
:end-before: END trace
372354
373-
>>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
374-
>>> trace = torch.jit.trace(compute, inputs)
375-
>>> print(trace.graph)
376-
graph(%x : Float(4, 8)
377-
%y : Float(8, 5)
378-
%z : Float(4, 5)) {
379-
%3 : Float(4, 5) = aten::matmul(%x, %y)
380-
%4 : Float(4, 5) = aten::relu(%z)
381-
%5 : int = prim::Constant[value=1]()
382-
%6 : Float(4, 5) = aten::add(%3, %4, %5)
383-
return (%6);
384-
}
355+
Producing::
356+
357+
graph(%x : Float(4:8, 8:1),
358+
%y : Float(8:5, 5:1),
359+
%z : Float(4:5, 5:1)):
360+
%3 : Float(4:5, 5:1) = aten::matmul(%x, %y) # test.py:10:0
361+
%4 : Float(4:5, 5:1) = aten::relu(%z) # test.py:10:0
362+
%5 : int = prim::Constant[value=1]() # test.py:10:0
363+
%6 : Float(4:5, 5:1) = aten::add(%3, %4, %5) # test.py:10:0
364+
return (%6)
385365
386366
Now, the exciting revelation is that we can simply drop our custom operator into
387367
our PyTorch trace as if it were ``torch.relu`` or any other ``torch`` function:
388368
389-
.. code-block:: python
390-
391-
torch.ops.load_library("libwarp_perspective.so")
392-
393-
def compute(x, y, z):
394-
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
395-
return x.matmul(y) + torch.relu(z)
369+
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
370+
:language: python
371+
:start-after: BEGIN compute2
372+
:end-before: END compute2
396373
397374
and then trace it as before:
398375
399-
.. code-block:: python
400-
401-
>>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(8, 5)]
402-
>>> trace = torch.jit.trace(compute, inputs)
403-
>>> print(trace.graph)
404-
graph(%x.1 : Float(4, 8)
405-
%y : Float(8, 5)
406-
%z : Float(8, 5)) {
407-
%3 : int = prim::Constant[value=3]()
408-
%4 : int = prim::Constant[value=6]()
409-
%5 : int = prim::Constant[value=0]()
410-
%6 : int[] = prim::Constant[value=[0, -1]]()
411-
%7 : Float(3, 3) = aten::eye(%3, %4, %5, %6)
412-
%x : Float(8, 8) = my_ops::warp_perspective(%x.1, %7)
413-
%11 : Float(8, 5) = aten::matmul(%x, %y)
414-
%12 : Float(8, 5) = aten::relu(%z)
415-
%13 : int = prim::Constant[value=1]()
416-
%14 : Float(8, 5) = aten::add(%11, %12, %13)
417-
return (%14);
418-
}
376+
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
377+
:language: python
378+
:start-after: BEGIN trace2
379+
:end-before: END trace2
380+
381+
Producing::
382+
383+
graph(%x.1 : Float(4:8, 8:1),
384+
%y : Float(8:5, 5:1),
385+
%z : Float(8:5, 5:1)):
386+
%3 : int = prim::Constant[value=3]() # test.py:25:0
387+
%4 : int = prim::Constant[value=6]() # test.py:25:0
388+
%5 : int = prim::Constant[value=0]() # test.py:25:0
389+
%6 : Device = prim::Constant[value="cpu"]() # test.py:25:0
390+
%7 : bool = prim::Constant[value=0]() # test.py:25:0
391+
%8 : Float(3:3, 3:1) = aten::eye(%3, %4, %5, %6, %7) # test.py:25:0
392+
%x : Float(8:8, 8:1) = my_ops::warp_perspective(%x.1, %8) # test.py:25:0
393+
%10 : Float(8:5, 5:1) = aten::matmul(%x, %y) # test.py:26:0
394+
%11 : Float(8:5, 5:1) = aten::relu(%z) # test.py:26:0
395+
%12 : int = prim::Constant[value=1]() # test.py:26:0
396+
%13 : Float(8:5, 5:1) = aten::add(%10, %11, %12) # test.py:26:0
397+
return (%13)
419398
420399
Integrating TorchScript custom ops into traced PyTorch code is as easy as this!
421400
@@ -947,8 +926,9 @@ custom TorchScript operator as a string. For this, use
947926
return output.clone();
948927
}
949928
950-
static auto registry =
951-
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
929+
TORCH_LIBRARY(my_ops, m) {
930+
m.def("warp_perspective", &warp_perspective);
931+
}
952932
"""
953933
954934
torch.utils.cpp_extension.load_inline(
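For context, a typical ``load_inline`` call for this use case (the call above is truncated by the diff) looks roughly like the sketch below; the variable name ``op_source`` for the C++ string and the OpenCV linker flags are assumptions about the elided code and the local build environment::

    import torch.utils.cpp_extension

    torch.utils.cpp_extension.load_inline(
        name="warp_perspective",
        cpp_sources=op_source,   # the C++ string above, including its TORCH_LIBRARY block
        extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],  # assumed; depends on the OpenCV install
        is_python_module=False,  # we only need the side effect of registering the operator
        verbose=True,
    )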

advanced_source/torch_script_custom_ops/op.cpp

Lines changed: 3 additions & 2 deletions
@@ -30,6 +30,7 @@ torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
3030
// END warp_perspective
3131

3232
// BEGIN registry
33-
static auto registry =
34-
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
33+
TORCH_LIBRARY(my_ops, m) {
34+
m.def("warp_perspective", warp_perspective);
35+
}
3536
// END registry

advanced_source/torch_script_custom_ops/smoke_test.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
1+
import torch
2+
torch.ops.load_library("build/libwarp_perspective.so")
3+
print(torch.ops.my_ops.warp_perspective)

advanced_source/torch_script_custom_ops/test.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
1+
import torch
2+
3+
4+
print("BEGIN preamble")
5+
torch.ops.load_library("build/libwarp_perspective.so")
6+
print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))
7+
print("END preamble")
8+
9+
10+
# BEGIN compute
11+
def compute(x, y, z):
12+
return x.matmul(y) + torch.relu(z)
13+
# END compute
14+
15+
16+
print("BEGIN trace")
17+
inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
18+
trace = torch.jit.trace(compute, inputs)
19+
print(trace.graph)
20+
print("END trace")
21+
22+
23+
# BEGIN compute2
24+
def compute(x, y, z):
25+
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
26+
return x.matmul(y) + torch.relu(z)
27+
# END compute2
28+
29+
30+
print("BEGIN trace2")
31+
inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(8, 5)]
32+
trace = torch.jit.trace(compute, inputs)
33+
print(trace.graph)
34+
print("END trace2")
