@@ -31,12 +31,14 @@ Defining schema and backend implementations
-------------------------------------------

The general principle behind the dispatcher is that it divides the
- implementation of an operator into multiple kernels, each of which
- implements functionality for a specific *dispatch key*; for example,
- CPU, CUDA or Autograd. The end effect is that when you call
- an operator, we first execute the Autograd kernel, and then we
- redispatch to the CPU or CUDA kernel depending on the device
- types of the passed in tensors.
+ implementation of an operator into multiple kernels, each of which implements
+ functionality for a specific *dispatch key*; for example, CPU, CUDA or Autograd.
+ The dispatcher determines the highest priority dispatch key at the time you
+ call an operator (it does this by looking at both the tensor arguments and
+ some thread local state), and transfers control to the kernel for that
+ dispatch key. The end effect is that when you call an operator, we first
+ execute the Autograd kernel, and then we redispatch to the CPU or CUDA kernel
+ depending on the device types of the passed-in tensors.

Let's take a look at the various parts involved in making this
happen. First, we must define the schema for the operator in question.
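
For orientation, here is a rough sketch of what such a schema definition
looks like when registered with ``TORCH_LIBRARY``; the tutorial's actual code
is the ``TORCH_LIBRARY`` block in ``op.cpp``, and ``myops::myadd`` mirrors the
example used throughout.

.. code-block:: cpp

   #include <torch/library.h>

   // Declare the operator's name and type signature only. Kernels are
   // registered separately, one per dispatch key.
   TORCH_LIBRARY(myops, m) {
     m.def("myadd(Tensor self, Tensor other) -> Tensor");
   }
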
@@ -58,10 +60,12 @@ For concreteness, here is a really simple implementation of addition on CPU:
   :start-after: BEGIN myadd_cpu
   :end-before: END myadd_cpu

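As an illustrative stand-in for the ``myadd_cpu`` block included above (not
the tutorial's exact kernel), a minimal element-wise add over contiguous
``float`` CPU tensors might look like this:

.. code-block:: cpp

   #include <ATen/ATen.h>

   at::Tensor myadd_cpu(const at::Tensor& self_, const at::Tensor& other_) {
     TORCH_CHECK(self_.sizes() == other_.sizes());
     // Work on contiguous buffers so we can index with a flat loop.
     at::Tensor self = self_.contiguous();
     at::Tensor other = other_.contiguous();
     at::Tensor result = at::empty(self.sizes(), self.options());
     const float* self_ptr = self.data_ptr<float>();
     const float* other_ptr = other.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = self_ptr[i] + other_ptr[i];
     }
     return result;
   }
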
- We'd like to register this function as an implementation of ``myops::myadd``, but we
- don't want to register it as a catch-all kernel to be run in all cases; we
- only want it to be run when we call ``myops::myadd`` at the backend on CPU tensors.
- To do this, we can use the ``TORCH_LIBRARY_IMPL`` macro:
+ We'd like to register this function as an implementation of ``myops::myadd``.
+ However, the simple way of registering it (``def("myadd", myadd_cpu)``) would
+ register the kernel to run in all cases, even if the tensor is not a CPU
+ tensor! (Internally, we refer to these as "catch-all" kernels, since they
+ catch all cases.) To ensure that ``myadd_cpu`` is only run for
+ CPU tensors, we can use the ``TORCH_LIBRARY_IMPL`` macro:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
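
The block that the ``literalinclude`` above pulls from ``op.cpp`` is roughly
of this shape (a sketch, not the verbatim tutorial code):

.. code-block:: cpp

   // Bind myadd_cpu to myops::myadd for the CPU dispatch key only.
   TORCH_LIBRARY_IMPL(myops, CPU, m) {
     m.impl("myadd", myadd_cpu);
   }
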
@@ -71,10 +75,8 @@ To do this, we can use the ``TORCH_LIBRARY_IMPL`` macro:
The ``TORCH_LIBRARY_IMPL`` lets us register implementations for operators on
a specific dispatch key (in this case, CPU). Each call to ``impl``
associates a CPU kernel with the corresponding operator (which we previously
- defined in the ``TORCH_LIBRARY`` block). You can have as many
- ``TORCH_LIBRARY_IMPL`` blocks for a namespace as you like; so for example,
- if we also have a CUDA implementation ``myadd_cuda``, we can register it
- with:
+ defined in the ``TORCH_LIBRARY`` block). If we also have a CUDA implementation ``myadd_cuda``,
+ we can register it in a separate ``TORCH_LIBRARY_IMPL`` block:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp

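The CUDA registration is symmetric to the CPU one; as a sketch:

.. code-block:: cpp

   // Bind the CUDA kernel to the same operator under the CUDA dispatch key.
   TORCH_LIBRARY_IMPL(myops, CUDA, m) {
     m.impl("myadd", myadd_cuda);
   }
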
These registrations can be split across files or even across library boundaries; so
for example, you could have these two ``TORCH_LIBRARY_IMPL`` blocks compiled
- into a separate ``myops_cpu`` and ``myops_cuda`` dynamic library.
+ into separate ``myops_cpu`` and ``myops_cuda`` dynamic libraries. Generally
+ speaking, the structure of your registrations will look like this:
+
+ 1. A single ``TORCH_LIBRARY`` that lists every custom operator in your namespace
+    in a centralized place.
+ 2. A ``TORCH_LIBRARY_IMPL`` per dispatch key that registers implementations for
+    that key (e.g., CPU or CUDA). If you like, you can further subdivide
+    ``TORCH_LIBRARY_IMPL`` blocks into a block per operator. This is convenient
+    if you have a separate file per operator implementation, but don't want to
+    expose the operators in a header; you can just put the registration in the
+    cpp file that defines your operator (see the sketch after this list).
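
For example, the per-operator layout in point 2 above might look like the
following sketch, where the file name and kernel body are illustrative:

.. code-block:: cpp

   // myadd_cpu.cpp (illustrative file name)
   #include <ATen/ATen.h>
   #include <torch/library.h>

   // The kernel stays at file scope; nothing needs to be exported in a header.
   static at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other) {
     return self + other;  // placeholder body; see op.cpp for the real kernel
   }

   // The registration lives right next to the kernel it registers.
   TORCH_LIBRARY_IMPL(myops, CPU, m) {
     m.impl("myadd", myadd_cpu);
   }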

.. note::

2. Call the dispatch function ``myadd`` to call back into the dispatcher.

Without (1), your calls will infinite loop (and stack overflow), because
- ``myadd`` will send you back to the autograd implementation! With (1),
- the redispatch will skip over autograd and go to the next handlers,
- which will either be CPU and CUDA.
+ ``myadd`` will send you back to this function (as the highest priority dispatch
+ key would still be autograd). With (1), autograd is excluded from the set of
+ dispatch keys under consideration, and we will go to the next handlers, which
+ will be either CPU or CUDA.
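
As a rough sketch (not the tutorial's verbatim code), an autograd kernel
following (1) and (2) might look like the snippet below. ``myadd`` is the
dispatch wrapper defined earlier in the tutorial; the class and function names
here are illustrative, and the exact guard type has varied across PyTorch
releases (older code used ``at::AutoNonVariableTypeMode``).

.. code-block:: cpp

   #include <torch/torch.h>

   // Forward declaration of the dispatch wrapper from earlier in the tutorial.
   torch::Tensor myadd(const torch::Tensor& self, const torch::Tensor& other);

   class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
    public:
     static torch::Tensor forward(
         torch::autograd::AutogradContext* ctx,
         torch::Tensor self, torch::Tensor other) {
       // (1) Exclude autograd from the dispatch keys under consideration,
       //     so the redispatch below does not land back in this kernel.
       at::AutoDispatchBelowADInplaceOrView guard;
       // (2) Redispatch through the dispatcher; control ends up at the
       //     CPU or CUDA kernel, depending on the inputs.
       return myadd(self, other);
     }

     static torch::autograd::variable_list backward(
         torch::autograd::AutogradContext* ctx,
         torch::autograd::variable_list grad_outputs) {
       // The gradient of a + b is 1 with respect to both inputs.
       return {grad_outputs[0], grad_outputs[0]};
     }
   };

   torch::Tensor myadd_autograd(const torch::Tensor& self, const torch::Tensor& other) {
     return MyAddFunction::apply(self, other);
   }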

We can now register this function in the same way we registered the CPU/CUDA
functions: