@@ -124,7 +124,7 @@ like this:
124124 :end-before: END output_tensor
125125
126126We use the ``.ptr<float>() `` method on the OpenCV ``Mat `` class to get a raw
127- pointer to the underlying data (just like ``.data<float>() `` for the PyTorch
127+ pointer to the underlying data (just like ``.data_ptr<float>() `` for the PyTorch
128128tensor earlier). We also specify the output shape of the tensor, which we
129129hardcoded as ``8 x 8 ``. The output of ``torch::from_blob `` is then a
130130``torch::Tensor ``, pointing to the memory owned by the OpenCV matrix.
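
For orientation, here is a minimal sketch of this pattern (illustrative only;
the function and variable names are assumptions, not the tutorial's actual
``op.cpp``):

.. code-block:: cpp

  #include <opencv2/opencv.hpp>
  #include <torch/script.h>

  // Sketch: wrap the OpenCV output matrix in a tensor without copying, then
  // clone so the returned tensor owns its memory. Assumes `output_mat` is a
  // contiguous, single-channel float (CV_32FC1) matrix of size 8 x 8.
  torch::Tensor wrap_output(cv::Mat output_mat) {
    torch::Tensor output =
        torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
    // `output` merely aliases memory owned by `output_mat`, so clone it to get
    // a tensor that remains valid after the cv::Mat is destroyed.
    return output.clone();
  }
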
@@ -145,40 +145,28 @@ Registering the Custom Operator with TorchScript
145145Now that we have implemented our custom operator in C++, we need to *register* it
146146with the TorchScript runtime and compiler. This will allow the TorchScript
147147compiler to resolve references to our custom operator in TorchScript code.
148- Registration is very simple. For our case, we need to write:
148+ If you have ever used the pybind11 library, our registration syntax will
149+ look very familiar. To register a single function,
150+ we write:
149151
150152.. literalinclude :: ../advanced_source/torch_script_custom_ops/op.cpp
151153 :language: cpp
152154 :start-after: BEGIN registry
153155 :end-before: END registry
154156
155- somewhere in the global scope of our ``op.cpp `` file. This creates a global
156- variable ``registry ``, which will register our operator with TorchScript in its
157- constructor (i.e. exactly once per program). We specify the name of the
158- operator, and a pointer to its implementation (the function we wrote earlier).
159- The name consists of two parts: a *namespace * (``my_ops ``) and a name for the
160- particular operator we are registering (``warp_perspective ``). The namespace and
161- operator name are separated by two colons (``:: ``).
157+ somewhere at the top level of our ``op.cpp `` file. The ``TORCH_LIBRARY `` macro
158+ creates a function that will be called when your program starts. The name
159+ of your library (``my_ops ``) is given as the first argument (it should not
160+ be in quotes). The second argument (``m ``) defines a variable of type
161+ ``torch::Library ``, which is the main interface for registering your operators.
162+ The method ``Library::def `` actually creates an operator named ``warp_perspective ``,
163+ exposing it to both Python and TorchScript. You can define as many operators
164+ as you like by making multiple calls to ``def ``.
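
As an illustration, here is a sketch of registering several operators at once
(``another_op`` and ``and_another_op`` are hypothetical functions, included
only to show repeated calls to ``def``):

.. code-block:: cpp

  #include <torch/script.h>

  // Hypothetical extra operators; only declarations are shown here, their
  // definitions would live elsewhere in op.cpp.
  torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp);
  torch::Tensor another_op(torch::Tensor input);
  torch::Tensor and_another_op(torch::Tensor input);

  // A single TORCH_LIBRARY block for the `my_ops` namespace; each call to
  // `def` registers one operator.
  TORCH_LIBRARY(my_ops, m) {
    m.def("warp_perspective", &warp_perspective);
    m.def("another_op", &another_op);
    m.def("and_another_op", &and_another_op);
  }
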
162165
163- .. tip ::
164-
165- If you want to register more than one operator, you can chain calls to
166- ``.op() `` after the constructor:
167-
168- .. code-block :: cpp
169-
170- static auto registry =
171- torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective)
172- .op("my_ops::another_op", &another_op)
173- .op("my_ops::and_another_op", &and_another_op);
174-
175- Behind the scenes, ``RegisterOperators `` will perform a number of fairly
176- complicated C++ template metaprogramming magic tricks to infer the argument and
177- return value types of the function pointer we pass it (``&warp_perspective ``).
178- This information is used to form a *function schema * for our operator. A
179- function schema is a structured representation of an operator -- a kind of
180- "signature" or "prototype" -- used by the TorchScript compiler to verify
181- correctness in TorchScript programs.
166+ Behind the scenes, the ``def `` function is actually doing quite a bit of work:
167+ it uses template metaprogramming to inspect the type signature of your
168+ function and translate it into an operator schema, which specifies the
169+ operator's type within TorchScript's type system.
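
The inferred schema for our operator will look roughly like
``my_ops::warp_perspective(Tensor, Tensor) -> Tensor`` (argument names may be
auto-generated during inference). If you prefer to spell the argument names and
types out yourself, ``def`` also accepts a hand-written schema string; a sketch
under that assumption:

.. code-block:: cpp

  #include <torch/script.h>

  torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp);

  // Sketch (assumption, not taken from the tutorial sources): spell the schema
  // out by hand instead of letting `def` infer it from the function pointer.
  TORCH_LIBRARY(my_ops, m) {
    m.def("warp_perspective(Tensor image, Tensor warp) -> Tensor",
          &warp_perspective);
  }
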
182170
183171Building the Custom Operator
184172----------------------------
@@ -189,7 +177,16 @@ we can load into Python for research and experimentation, or into C++ for
189177inference in a no-Python environment. There exist multiple ways to build our
190178operator, using either pure CMake, or Python alternatives like ``setuptools ``.
191179For brevity, the paragraphs below only discuss the CMake approach. The appendix
192- of this tutorial dives into the Python based alternatives.
180+ of this tutorial dives into other alternatives.
181+
182+ Environment setup
183+ *****************
184+
185+ We need an installation of PyTorch and OpenCV. The easiest and most
186+ platform-independent way to get both is via Conda::
187+
188+ conda install -c pytorch pytorch
189+ conda install opencv
193190
194191Building with CMake
195192*******************
@@ -203,29 +200,11 @@ a directory structure that looks like this::
203200 op.cpp
204201 CMakeLists.txt
205202
206- Also, make sure to grab the latest version of the LibTorch distribution, which
207- packages PyTorch's C++ libraries and CMake build files, from `pytorch.org
208- <https://pytorch.org/get-started/locally> `_. Place the unzipped distribution
209- somewhere accessible in your file system. The following paragraphs will refer to
210- that location as ``/path/to/libtorch ``. The contents of our ``CMakeLists.txt ``
211- file should then be the following:
203+ The contents of our ``CMakeLists.txt `` file should then be the following:
212204
213205.. literalinclude :: ../advanced_source/torch_script_custom_ops/CMakeLists.txt
214206 :language: cpp
215207
216- .. warning ::
217-
218- This setup makes some assumptions about the build environment, particularly
219- what pertains to the installation of OpenCV. The above ``CMakeLists.txt `` file
220- was tested inside a Docker container running Ubuntu Xenial with
221- ``libopencv-dev `` installed via ``apt ``. If it does not work for you and you
222- feel stuck, please use the ``Dockerfile `` in the `accompanying tutorial
223- repository <https://github.com/pytorch/extension-script> `_ to
224- build an isolated, reproducible environment in which to play around with the
225- code from this tutorial. If you run into further troubles, please file an
226- issue in the tutorial repository or post a question in `our forum
227- <https://discuss.pytorch.org/> `_.
228-
229208To now build our operator, we can run the following commands from our
230209``warp_perspective `` folder:
231210
@@ -268,24 +247,18 @@ To now build our operator, we can run the following commands from our
268247 [100%] Built target warp_perspective
269248
270249which will place a ``libwarp_perspective.so`` shared library file in the
271- ``build`` folder. In the ``cmake`` command above, you should replace
272- ``/path/to/libtorch`` with the path to your unzipped LibTorch distribution.
250+ ``build`` folder. In the ``cmake`` command above, we use the helper
251+ variable ``torch.utils.cmake_prefix_path`` to conveniently tell us where
252+ the CMake configuration files for our PyTorch installation live.
273253
274254We will explore how to use and call our operator in detail further below, but to
275255get an early sensation of success, we can try running the following code in
276256Python:
277257
278- .. code-block:: python
279-
280- >>> import torch
281- >>> torch.ops.load_library("/path/to/libwarp_perspective.so")
282- >>> print(torch.ops.my_ops.warp_perspective)
283-
284- Here, ``/path/to/libwarp_perspective.so`` should be a relative or absolute path
285- to the ``libwarp_perspective.so`` shared library we just built. If all goes
286- well, this should print something like
258+ .. literalinclude:: ../advanced_source/torch_script_custom_ops/smoke_test.py
259+ :language: python
287260
288- .. code-block:: python
261+ If all goes well, this should print something like::
289262
290263 <built-in method my_ops::warp_perspective of PyCapsule object at 0x7f618fc6fa50>
291264
@@ -302,10 +275,9 @@ TorchScript code.
302275You already saw how to import your operator into Python:
303276``torch.ops.load_library()``. This function takes the path to a shared library
304277containing custom operators, and loads it into the current process. Loading the
305- shared library will also execute the constructor of the global
306- ``RegisterOperators`` object we placed into our custom operator implementation
307- file. This will register our custom operator with the TorchScript compiler and
308- allow us to use that operator in TorchScript code.
278+ shared library will also execute the ``TORCH_LIBRARY`` block. This will register
279+ our custom operator with the TorchScript compiler and allow us to use that
280+ operator in TorchScript code.
309281
310282You can refer to your loaded operator as ``torch.ops.<namespace>.<function>``,
311283where ``<namespace>`` is the namespace part of your operator name, and
@@ -316,11 +288,16 @@ While this function can be used in scripted or traced TorchScript modules, we
316288can also just use it in vanilla eager PyTorch and pass it regular PyTorch
317289tensors:
318290
291+ .. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
292+ :language: python
293+ :prepend: import torch
294+ :start-after: BEGIN preamble
295+ :end-before: END preamble
296+
297+ producing:
298+
319299.. code-block:: python
320300
321- >>> import torch
322- >>> torch.ops.load_library("libwarp_perspective.so")
323- >>> torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3))
324301 tensor([[0.0000, 0.3218, 0.4611, ..., 0.4636, 0.4636, 0.4636],
325302 [0.3746, 0.0978, 0.5005, ..., 0.4636, 0.4636, 0.4636],
326303 [0.3245, 0.0169, 0.0000, ..., 0.4458, 0.4458, 0.4458],
@@ -332,90 +309,92 @@ tensors:
332309
333310.. note::
334311
335- What happens behind the scenes is that the first time you access
336- ``torch.ops.namespace.function`` in Python, the TorchScript compiler (in C++
337- land) will see if a function ``namespace::function`` has been registered, and
338- if so, return a Python handle to this function that we can subsequently use to
339- call into our C++ operator implementation from Python. This is one noteworthy
340- difference between TorchScript custom operators and C++ extensions: C++
341- extensions are bound manually using pybind11, while TorchScript custom ops are
342- bound on the fly by PyTorch itself. Pybind11 gives you more flexibility with
343- regards to what types and classes you can bind into Python and is thus
344- recommended for purely eager code, but it is not supported for TorchScript
345- ops.
312+ What happens behind the scenes is that the first time you access
313+ ``torch.ops.namespace.function`` in Python, the TorchScript compiler (in C++
314+ land) will see if a function ``namespace::function`` has been registered, and
315+ if so, return a Python handle to this function that we can subsequently use to
316+ call into our C++ operator implementation from Python. This is one noteworthy
317+ difference between TorchScript custom operators and C++ extensions: C++
318+ extensions are bound manually using pybind11, while TorchScript custom ops are
319+ bound on the fly by PyTorch itself. Pybind11 gives you more flexibility with
320+ regards to what types and classes you can bind into Python and is thus
321+ recommended for purely eager code, but it is not supported for TorchScript
322+ ops.
346323
347324From here on, you can use your custom operator in scripted or traced code just
348325as you would other functions from the ``torch`` package. In fact, "standard
349326library" functions like ``torch.matmul`` go through largely the same
350327registration path as custom operators, which makes custom operators really
351328first-class citizens when it comes to how and where they can be used in
352- TorchScript.
329+ TorchScript. (One difference, however, is that standard library functions
330+ have custom-written Python argument parsing logic that differs from
331+ the argument parsing used by ``torch.ops``.)
353332
354333Using the Custom Operator with Tracing
355334**************************************
356335
357336Let's start by embedding our operator in a traced function. Recall that for
358337tracing, we start with some vanilla PyTorch code:
359338
360- .. code-block :: python
361-
362- def compute(x, y, z):
363- return x.matmul(y) + torch.relu(z)
339+ .. literalinclude :: ../advanced_source/torch_script_custom_ops/test.py
340+ :language: python
341+ :start-after: BEGIN compute
342+ :end-before: END compute
364343
365344and then call ``torch.jit.trace`` on it. We further pass ``torch.jit.trace``
366345some example inputs, which it will forward to our implementation to record the
367346sequence of operations that occur as the inputs flow through it. The result of
368347this is effectively a "frozen" version of the eager PyTorch program, which the
369348TorchScript compiler can further analyze, optimize and serialize:
370349
371- .. code-block:: python
350+ .. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
351+ :language: python
352+ :start-after: BEGIN trace
353+ :end-before: END trace
372354
373- >>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
374- >>> trace = torch.jit.trace(compute, inputs)
375- >>> print(trace.graph)
376- graph(%x : Float(4, 8)
377- %y : Float(8, 5)
378- %z : Float(4, 5)) {
379- %3 : Float(4, 5) = aten::matmul(%x, %y)
380- %4 : Float(4, 5) = aten::relu(%z)
381- %5 : int = prim::Constant[value=1]()
382- %6 : Float(4, 5) = aten::add(%3, %4, %5)
383- return (%6);
384- }
355+ Producing::
356+
357+ graph(%x : Float(4:8, 8:1),
358+ %y : Float(8:5, 5:1),
359+ %z : Float(4:5, 5:1)):
360+ %3 : Float(4:5, 5:1) = aten::matmul(%x, %y) # test.py:10:0
361+ %4 : Float(4:5, 5:1) = aten::relu(%z) # test.py:10:0
362+ %5 : int = prim::Constant[value=1]() # test.py:10:0
363+ %6 : Float(4:5, 5:1) = aten::add(%3, %4, %5) # test.py:10:0
364+ return (%6)
385365
386366Now, the exciting revelation is that we can simply drop our custom operator into
387367our PyTorch trace as if it were ``torch.relu`` or any other ``torch`` function:
388368
389- .. code-block:: python
390-
391- torch.ops.load_library("libwarp_perspective.so")
392-
393- def compute(x, y, z):
394- x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
395- return x.matmul(y) + torch.relu(z)
369+ .. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
370+ :language: python
371+ :start-after: BEGIN compute2
372+ :end-before: END compute2
396373
397374and then trace it as before:
398375
399- .. code-block:: python
400-
401- >>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(8, 5)]
402- >>> trace = torch.jit.trace(compute, inputs)
403- >>> print(trace.graph)
404- graph(%x.1 : Float(4, 8)
405- %y : Float(8, 5)
406- %z : Float(8, 5)) {
407- %3 : int = prim::Constant[value=3]()
408- %4 : int = prim::Constant[value=6]()
409- %5 : int = prim::Constant[value=0]()
410- %6 : int[] = prim::Constant[value=[0, -1]]()
411- %7 : Float(3, 3) = aten::eye(%3, %4, %5, %6)
412- %x : Float(8, 8) = my_ops::warp_perspective(%x.1, %7)
413- %11 : Float(8, 5) = aten::matmul(%x, %y)
414- %12 : Float(8, 5) = aten::relu(%z)
415- %13 : int = prim::Constant[value=1]()
416- %14 : Float(8, 5) = aten::add(%11, %12, %13)
417- return (%14);
418- }
376+ .. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
377+ :language: python
378+ :start-after: BEGIN trace2
379+ :end-before: END trace2
380+
381+ Producing::
382+
383+ graph(%x.1 : Float(4:8, 8:1),
384+ %y : Float(8:5, 5:1),
385+ %z : Float(8:5, 5:1)):
386+ %3 : int = prim::Constant[value=3]() # test.py:25:0
387+ %4 : int = prim::Constant[value=6]() # test.py:25:0
388+ %5 : int = prim::Constant[value=0]() # test.py:25:0
389+ %6 : Device = prim::Constant[value="cpu"]() # test.py:25:0
390+ %7 : bool = prim::Constant[value=0]() # test.py:25:0
391+ %8 : Float(3:3, 3:1) = aten::eye(%3, %4, %5, %6, %7) # test.py:25:0
392+ %x : Float(8:8, 8:1) = my_ops::warp_perspective(%x.1, %8) # test.py:25:0
393+ %10 : Float(8:5, 5:1) = aten::matmul(%x, %y) # test.py:26:0
394+ %11 : Float(8:5, 5:1) = aten::relu(%z) # test.py:26:0
395+ %12 : int = prim::Constant[value=1]() # test.py:26:0
396+ %13 : Float(8:5, 5:1) = aten::add(%10, %11, %12) # test.py:26:0
397+ return (%13)
419398
420399Integrating TorchScript custom ops into traced PyTorch code is as easy as this!
421400
@@ -947,8 +926,9 @@ custom TorchScript operator as a string. For this, use
947926 return output.clone();
948927 }
949928
950- static auto registry =
951- torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
929+ TORCH_LIBRARY(my_ops, m) {
930+ m.def("warp_perspective", &warp_perspective);
931+ }
952932 """
953933
954934 torch.utils.cpp_extension.load_inline(