95 changes: 51 additions & 44 deletions advanced_source/cpp_export.rst
@@ -92,28 +92,30 @@ vanilla Pytorch model::

Because the ``forward`` method of this module uses control flow that is
dependent on the input, it is not suitable for tracing. Instead, we can convert
it to a ``ScriptModule`` by subclassing it from ``torch.jit.ScriptModule`` and
adding a ``@torch.jit.script_method`` annotation to the model's ``forward``
method::

    import torch

    class MyModule(torch.jit.ScriptModule):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))

        @torch.jit.script_method
        def forward(self, input):
            if bool(input.sum() > 0):
                output = self.weight.mv(input)
            else:
                output = self.weight + input
            return output

    my_script_module = MyModule(2, 3)

Creating a new ``MyModule`` object now directly produces an instance of
it to a ``ScriptModule``. To convert the module to a ``ScriptModule``,
compile it with ``torch.jit.script`` as follows::

    class MyModule(torch.nn.Module):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))

        def forward(self, input):
            if input.sum() > 0:
                output = self.weight.mv(input)
            else:
                output = self.weight + input
            return output

    my_module = MyModule(10, 20)
    sm = torch.jit.script(my_module)
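
As a quick sanity check, you can print the TorchScript code the compiler generated for ``forward``; this is a sketch of one way to inspect the result::

    print(sm.code)  # pretty-printed TorchScript for forward, including both branches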

If you need to exclude some methods in your ``nn.Module``
because they use Python features that TorchScript doesn't support yet,
you can annotate those with ``@torch.jit.ignore``.
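
For example, a minimal sketch of such an annotation (the ``describe`` helper below is a hypothetical stand-in for a method that uses unsupported Python features; it is not part of the original example)::

    import torch

    class MyModule(torch.nn.Module):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))

        @torch.jit.ignore
        def describe(self):
            # Left as ordinary Python; torch.jit.script will not try to compile it.
            print("weight shape:", self.weight.shape)

        def forward(self, input):
            return self.weight.mv(input)

    sm = torch.jit.script(MyModule(10, 20))  # `describe` is skipped by the compiler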

``sm`` is an instance of
``ScriptModule`` that is ready for serialization.

Step 2: Serializing Your Script Module to a File
@@ -152,32 +154,38 @@ do:

.. code-block:: cpp

#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
  #include <torch/script.h> // One-stop header.

  #include <iostream>
  #include <memory>

  int main(int argc, const char* argv[]) {
    if (argc != 2) {
      std::cerr << "usage: example-app <path-to-exported-script-module>\n";
      return -1;
    }

    torch::jit::script::Module module;
    try {
      // Deserialize the ScriptModule from a file using torch::jit::load().
      module = torch::jit::load(argv[1]);
    }
    catch (const c10::Error& e) {
      std::cerr << "error loading the model\n";
      return -1;
    }

    std::cout << "ok\n";
  }

// Deserialize the ScriptModule from a file using torch::jit::load().
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

assert(module != nullptr);
std::cout << "ok\n";
}

The ``<torch/script.h>`` header encompasses all relevant includes from the
LibTorch library necessary to run the example. Our application accepts the file
path to a serialized PyTorch ``ScriptModule`` as its only command line argument
and then proceeds to deserialize the module using the ``torch::jit::load()``
function, which takes this file path as input. In return we receive a shared
pointer to a ``torch::jit::script::Module``, the equivalent to a
``torch.jit.ScriptModule`` in C++. For now, we only verify that this pointer is
not null. We will examine how to execute it in a moment.
function, which takes this file path as input. In return we receive a ``torch::jit::script::Module``
object. We will examine how to execute it in a moment.

Depending on LibTorch and Building the Application
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -300,8 +308,7 @@ application's ``main()`` function:
inputs.push_back(torch::ones({1, 3, 224, 224}));

// Execute the model and turn its output into a tensor.
at::Tensor output = module->forward(inputs).toTensor();

at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

The first two lines set up the inputs to our model. We create a vector of
@@ -344,7 +351,7 @@ Looks like a good match!

.. tip::

To move your model to GPU memory, you can write ``model->to(at::kCUDA);``.
To move your model to GPU memory, you can write ``model.to(at::kCUDA);``.
Make sure the inputs to a model living in CUDA memory are also in CUDA memory
by calling ``tensor.to(at::kCUDA)``, which will return a new tensor in CUDA
memory.
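
For example, a minimal sketch of the CUDA variant of the earlier forward pass (assuming a CUDA-enabled LibTorch build and that ``module`` has already been loaded with ``torch::jit::load()``):

.. code-block:: cpp

  // Move the module's parameters and buffers to GPU memory.
  module.to(at::kCUDA);

  // Create the input on the GPU as well, so it matches the model's device.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));

  // Run the forward pass; the resulting tensor also lives in CUDA memory.
  at::Tensor output = module.forward(inputs).toTensor();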