@@ -92,28 +92,30 @@ vanilla Pytorch model::
 
 Because the ``forward`` method of this module uses control flow that is
 dependent on the input, it is not suitable for tracing. Instead, we can convert
-it to a ``ScriptModule`` by subclassing it from ``torch.jit.ScriptModule`` and
-adding a ``@torch.jit.script_method`` annotation to the model's ``forward``
-method::
-
-    import torch
-
-    class MyModule(torch.jit.ScriptModule):
-        def __init__(self, N, M):
-            super(MyModule, self).__init__()
-            self.weight = torch.nn.Parameter(torch.rand(N, M))
-
-        @torch.jit.script_method
-        def forward(self, input):
-            if bool(input.sum() > 0):
-                output = self.weight.mv(input)
-            else:
-                output = self.weight + input
-            return output
-
-    my_script_module = MyModule(2, 3)
-
-Creating a new ``MyModule`` object now directly produces an instance of
+it to a ``ScriptModule``.
+To convert the module to a ``ScriptModule``, compile it
+with ``torch.jit.script`` as follows::
+
+    class MyModule(torch.nn.Module):
+        def __init__(self, N, M):
+            super(MyModule, self).__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, input):
+            if input.sum() > 0:
+                output = self.weight.mv(input)
+            else:
+                output = self.weight + input
+            return output
+
+    my_module = MyModule(10, 20)
+    sm = torch.jit.script(my_module)
+
+If you need to exclude some methods of your ``nn.Module``
+because they use Python features that TorchScript does not support yet,
+you can annotate them with ``@torch.jit.ignore``.
+
+``my_module`` is an instance of
 ``ScriptModule`` that is ready for serialization.
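As a hedged sketch of the scripting path described above (assuming a recent PyTorch build; the ``debug_stats`` helper is a hypothetical addition, included only to illustrate ``@torch.jit.ignore``), the scripted module can be called just like the original:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.ignore
    def debug_stats(self, input):
        # Plain Python, skipped by the TorchScript compiler.
        print("input sum:", float(input.sum()))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)

# Both branches of the data-dependent control flow survive scripting.
print(sm(torch.ones(20)).shape)    # mv branch: torch.Size([10])
print(sm(-torch.ones(20)).shape)   # add branch: torch.Size([10, 20])
```

Note that, unlike tracing, both branches of the ``if`` remain in the compiled graph, so either can be taken at run time.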
 
 Step 2: Serializing Your Script Module to a File
@@ -152,32 +154,38 @@ do:
 
 .. code-block:: cpp
 
-  #include <torch/script.h> // One-stop header.
-
-  #include <iostream>
-  #include <memory>
-
-  int main(int argc, const char* argv[]) {
-    if (argc != 2) {
-      std::cerr << "usage: example-app <path-to-exported-script-module>\n";
-      return -1;
+  #include <torch/script.h> // One-stop header.
+
+  #include <iostream>
+  #include <memory>
+
+  int main(int argc, const char* argv[]) {
+    if (argc != 2) {
+      std::cerr << "usage: example-app <path-to-exported-script-module>\n";
+      return -1;
+    }
+
+
+    torch::jit::script::Module module;
+    try {
+      // Deserialize the ScriptModule from a file using torch::jit::load().
+      module = torch::jit::load(argv[1]);
+    }
+    catch (const c10::Error& e) {
+      std::cerr << "error loading the model\n";
+      return -1;
+    }
+
+    std::cout << "ok\n";
   }
 
-    // Deserialize the ScriptModule from a file using torch::jit::load().
-    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
-
-    assert(module != nullptr);
-    std::cout << "ok\n";
-  }
 
 The ``<torch/script.h>`` header encompasses all relevant includes from the
 LibTorch library necessary to run the example. Our application accepts the file
 path to a serialized PyTorch ``ScriptModule`` as its only command line argument
 and then proceeds to deserialize the module using the ``torch::jit::load()``
-function, which takes this file path as input. In return we receive a shared
-pointer to a ``torch::jit::script::Module``, the equivalent to a
-``torch.jit.ScriptModule`` in C++. For now, we only verify that this pointer is
-not null. We will examine how to execute it in a moment.
+function, which takes this file path as input. In return we receive a
+``torch::jit::script::Module`` object. We will examine how to execute it in a moment.
 
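The Python API exposes the same entry point as ``torch.jit.load``; as a hedged round-trip sketch using a stand-in one-layer module (the file path is illustrative):

```python
import os
import tempfile

import torch

# Stand-in module; a serialized model from the tutorial loads the same way.
model = torch.jit.script(torch.nn.Linear(3, 1))
path = os.path.join(tempfile.mkdtemp(), "model.pt")
model.save(path)

# Mirrors the C++ torch::jit::load() call in the snippet above.
loaded = torch.jit.load(path)
print(loaded(torch.ones(1, 3)).shape)  # torch.Size([1, 1])
```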
 Depending on LibTorch and Building the Application
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -300,8 +308,7 @@ application's ``main()`` function:
   inputs.push_back(torch::ones({1, 3, 224, 224}));
 
   // Execute the model and turn its output into a tensor.
-  at::Tensor output = module->forward(inputs).toTensor();
-
+  at::Tensor output = module.forward(inputs).toTensor();
   std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
 
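As a hedged Python analogue of the two C++ lines above (a small ``Conv2d`` stands in for the tutorial's ResNet, and ``narrow`` plays the role of the C++ ``slice`` call):

```python
import torch

# Stand-in for the scripted model that the C++ program loads.
model = torch.jit.script(torch.nn.Conv2d(3, 8, kernel_size=3))

# Same input shape as torch::ones({1, 3, 224, 224}).
inputs = torch.ones(1, 3, 224, 224)
output = model(inputs)

# Equivalent of output.slice(/*dim=*/1, /*start=*/0, /*end=*/5):
print(output.narrow(1, 0, 5).shape)  # torch.Size([1, 5, 222, 222])
```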
 The first two lines set up the inputs to our model. We create a vector of
@@ -344,7 +351,7 @@ Looks like a good match!
 
 .. tip::
 
-   To move your model to GPU memory, you can write ``model->to(at::kCUDA);``.
+   To move your model to GPU memory, you can write ``model.to(at::kCUDA);``.
    Make sure the inputs to a model living in CUDA memory are also in CUDA memory
    by calling ``tensor.to(at::kCUDA)``, which will return a new tensor in CUDA
    memory.
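In Python the analogous move is ``model.to`` plus ``tensor.to``; a hedged sketch that falls back to the CPU when no GPU is present (``Linear`` is a stand-in model):

```python
import torch

model = torch.jit.script(torch.nn.Linear(4, 2))  # stand-in model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and inputs must live on the same device, as the tip notes.
model.to(device)
x = torch.ones(1, 4).to(device)
out = model(x)
print(out.shape)  # torch.Size([1, 2])
```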