Merged
17 changes: 9 additions & 8 deletions advanced_source/cpp_extension.rst
@@ -147,23 +147,22 @@ For the "ahead of time" flavor, we build our C++ extension by writing a
 ``setup.py`` script that uses setuptools to compile our C++ code. For the LLTM, it
 looks as simple as this::
 
-    from setuptools import setup
-    from torch.utils.cpp_extension import CppExtension, BuildExtension
+    from setuptools import setup, Extension
+    from torch.utils import cpp_extension
 
     setup(name='lltm_cpp',
-          ext_modules=[CppExtension('lltm', ['lltm.cpp'])],
-          cmdclass={'build_ext': BuildExtension})
-
+          ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
+          cmdclass={'build_ext': cpp_extension.BuildExtension})
 
 In this code, :class:`CppExtension` is a convenience wrapper around
 :class:`setuptools.Extension` that passes the correct include paths and sets
 the language of the extension to C++. The equivalent vanilla :mod:`setuptools`
 code would simply be::
 
-    setuptools.Extension(
+    Extension(
        name='lltm_cpp',
        sources=['lltm.cpp'],
-       include_dirs=torch.utils.cpp_extension.include_paths(),
+       include_dirs=cpp_extension.include_paths(),
        language='c++')
 
 :class:`BuildExtension` performs a number of required configuration steps and
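Putting the post-change pieces together, the complete ``setup.py`` the hunk above produces would read roughly as follows (a sketch assembled from the diff; building it requires PyTorch and a C++ toolchain to be installed):

```python
# setup.py -- ahead-of-time build of the lltm_cpp extension
from setuptools import setup
from torch.utils import cpp_extension

setup(name='lltm_cpp',
      # CppExtension wires in PyTorch's include paths and sets language='c++'
      ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
      # BuildExtension handles the compiler/linker flags PyTorch needs
      cmdclass={'build_ext': cpp_extension.BuildExtension})
```

Running ``python setup.py install`` (or ``pip install .``) in the directory containing ``lltm.cpp`` then builds and installs the module.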
@@ -413,7 +412,7 @@ see::
 If we call ``help()`` on the function or module, we can see that its signature
 matches our C++ code::
 
-    In[4] help(lltm.forward)
+    In[4] help(lltm_cpp.forward)
     forward(...) method of builtins.PyCapsule instance
     forward(arg0: torch::Tensor, arg1: torch::Tensor, arg2: torch::Tensor, arg3: torch::Tensor, arg4: torch::Tensor) -> List[torch::Tensor]
@@ -473,6 +472,8 @@ small benchmark to see how much performance we gained from rewriting our op in
 C++. We'll run the LLTM forwards and backwards a few times and measure the
 duration::
 
+    import time
+
     import torch
 
     batch_size = 16
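The benchmark that the hunk above sets up (hence the added ``import time``) follows a common pattern: repeat the operation many times, accumulate wall-clock time with ``time.time()``, and report an average. A minimal sketch of that timing loop, using a trivial ``step`` stand-in (hypothetical, not part of the tutorial) instead of the LLTM forward/backward pass:

```python
import time

def step():
    # Stand-in for one LLTM forward + backward pass.
    return sum(i * i for i in range(1000))

iterations = 100
start = time.time()
for _ in range(iterations):
    step()
elapsed = time.time() - start

print('average per iteration: {:.6f} s'.format(elapsed / iterations))
```

Averaging over many iterations amortizes per-call timer overhead; for CUDA kernels the tutorial's approach additionally needs synchronization before reading the clock, since GPU launches are asynchronous.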