and `onnxruntime <https://github.com/microsoft/onnxruntime>`__.
You can get binary builds of onnx and onnxruntime with
``pip install onnx onnxruntime``.
Note that ONNX Runtime is compatible with Python versions 3.5 to 3.7.

``NOTE``: This tutorial needs the PyTorch master branch, which can be installed by following
the instructions `here <https://github.com/pytorch/pytorch#from-source>`__.
######################################################################
# Exporting a model in PyTorch works via tracing or scripting. This
# tutorial will use a model exported by tracing as an example.
# To export a model, we call the ``torch.onnx.export()`` function.
# This will execute the model, recording a trace of what operators
# are used to compute the outputs.
# Because ``export`` runs the model, we need to provide an input
# tensor ``x``. The values in this tensor can be random as long as it
# is of the right type and size.
# Note that the input size will be fixed in the exported ONNX graph for
# all the input's dimensions, unless specified as dynamic axes.
# In this example we export the model with an input of batch_size 1,
# but then specify the first dimension as dynamic in the ``dynamic_axes``
# parameter in ``torch.onnx.export()``.
# The exported model will thus accept inputs of size [batch_size, 1, 224, 224]
# where batch_size can be variable.
#
# To learn more details about PyTorch's export interface, check out the
# `torch.onnx documentation <https://pytorch.org/docs/master/onnx.html>`__.
#

# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

######################################################################
# We also computed ``torch_out``, the model's output after running it,
# which we will use to verify that the model we exported computes
# the same values when run in onnxruntime.
#
# But before verifying the model's output with onnxruntime, we will check
# the onnx model with onnx's API. This will verify the model's structure
# and confirm that the model has a valid schema.
# The validity of the ONNX graph is verified by checking the model's
# version, the graph's structure, as well as the nodes and their inputs
# and outputs.
#

import onnx

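# A minimal sketch of the check, assuming the model was saved to
# ``super_resolution.onnx`` in the export step above: load it back
# and run the ONNX checker on it.
onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)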


######################################################################
# Now let's compute the output using ONNXRuntime's Python APIs.
# This part can normally be done in a separate process or on another
# machine, but we will continue in the same process so that we can
# verify that onnxruntime and PyTorch are computing the same value
# for the network.
#
# In order to run the model with ONNXRuntime, we need to create an
# inference session for the model with the chosen configuration
# parameters (here we use the default config).
# Once the session is created, we evaluate the model using the ``run()`` API.
# The output of this call is a list containing the outputs of the model
# computed by ONNXRuntime.
#

import onnxruntime
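# A sketch of the steps just described, assuming the default session
# configuration and the ``super_resolution.onnx`` file exported earlier;
# ``to_numpy`` is a small helper also used further below.
import numpy as np

ort_session = onnxruntime.InferenceSession("super_resolution.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

# Because we exported the first dimension as a dynamic axis, the same
# session also accepts other batch sizes (batch size 3 here is an
# arbitrary choice for a quick sanity check):
batch_inputs = {ort_session.get_inputs()[0].name: to_numpy(torch.randn(3, 1, 224, 224))}
assert ort_session.run(None, batch_inputs)[0].shape[0] == 3


######################################################################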
# Next, let's load an image and pre-process it using the standard PIL
# python library. Note that this preprocessing is the standard practice of
# processing data for training/testing neural networks.
#
# We first resize the image to fit the size of the model's input (224x224).
# Then we split the image into its Y, Cb, and Cr components.
# These components represent a greyscale image (Y), and
# the blue-difference (Cb) and red-difference (Cr) chroma components.
# The human eye is more sensitive to the Y component, so that is the
# component we are interested in and will be transforming, as sketched
# in the code below.
# After extracting the Y component, we convert it to a tensor which
# will be the input of our model.
#

from PIL import Image
import torchvision.transforms as transforms
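
# A sketch of the pre-processing described above; the cat image path is
# illustrative, and any RGB image will do.
img = Image.open("./_static/img/cat.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)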


######################################################################
# Now, as a next step, let's take the tensor representing the
# greyscale resized cat image and run the super-resolution model in
# ONNXRuntime as explained previously.
#

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
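# run the model and grab its (single) output tensor
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]


######################################################################
# At this point, the output of the model is a tensor.
# Now, we'll process the output of the model to construct back the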
# final output image from the output tensor, and save the image.
# The post-processing steps have been adopted from the PyTorch
# implementation of the super-resolution model
# `here <https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py>`__.
#

img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')