Skip to content

Commit 8c4e5a6

Browse files
author
Lara
committed
adress PR comments
1 parent c6ec282 commit 8c4e5a6

File tree

2 files changed

+56
-25
lines changed

2 files changed

+56
-25
lines changed

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
and `onnxruntime <https://github.com/microsoft/onnxruntime>`__.
1717
You can get binary builds of onnx and onnxrunimte with
1818
``pip install onnx onnxruntime``.
19+
Note that ONNXRuntime is compatible with Python versions 3.5 to 3.7.
1920
2021
``NOTE``: This tutorial needs PyTorch master branch which can be installed by following
2122
the instructions `here <https://github.com/pytorch/pytorch#from-source>`__
@@ -101,40 +102,52 @@ def _initialize_weights(self):
101102
######################################################################
102103
# Exporting a model in PyTorch works via tracing or scripting. This
103104
# tutorial will use as an example a model exported by tracing.
104-
# To export a model, you call the ``torch.onnx.export()`` function.
105+
# To export a model, we call the ``torch.onnx.export()`` function.
105106
# This will execute the model, recording a trace of what operators
106107
# are used to compute the outputs.
107-
# Because ``_export`` runs the model, we need to provide an input
108+
# Because ``export`` runs the model, we need to provide an input
108109
# tensor ``x``. The values in this can be random as long as it is the
109110
# right type and size.
110-
#
111+
# Note that the input size will be fixed in the exported ONNX graph for
112+
# all the input's dimensions, unless specified as a dynamic axes.
113+
# In this example we export the model with an input of batch_size 1,
114+
# but then specify the first dimension as dynamic in the ``dynamic_axes``
115+
# parameter in ``torch.onnx.export()``.
116+
# The exported model will thus accept inputs of size [batch_size, 1, 224, 224]
117+
# where batch_size can be variable.
118+
#
111119
# To learn more details about PyTorch's export interface, check out the
112120
# `torch.onnx documentation <https://pytorch.org/docs/master/onnx.html>`__.
113121
#
114122

115123
# Input to the model
116124
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
125+
torch_out = torch_model(x)
117126

118127
# Export the model
119-
torch_out = torch.onnx._export(torch_model, # model being run
120-
x, # model input (or a tuple for multiple inputs)
121-
"super_resolution.onnx", # where to save the model (can be a file or file-like object)
122-
export_params=True, # store the trained parameter weights inside the model file
123-
opset_version=10, # the onnx version to export the model to
124-
do_constant_folding=True, # wether to execute constant folding for optimization
125-
input_names = ['input'], # the model's input names
126-
output_names = ['output'], # the model's output names
127-
dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes
128-
'output' : {0 : 'batch_size'}})
128+
torch.onnx.export(torch_model, # model being run
129+
x, # model input (or a tuple for multiple inputs)
130+
"super_resolution.onnx", # where to save the model (can be a file or file-like object)
131+
export_params=True, # store the trained parameter weights inside the model file
132+
opset_version=10, # the onnx version to export the model to
133+
do_constant_folding=True, # wether to execute constant folding for optimization
134+
input_names = ['input'], # the model's input names
135+
output_names = ['output'], # the model's output names
136+
dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes
137+
'output' : {0 : 'batch_size'}})
129138

130139
######################################################################
131-
# ``torch_out`` is the output after executing the model. Normally you can
132-
# ignore this output, but here we will use it to verify that the model we
133-
# exported computes the same values when run in onnxruntime.
140+
# We also computed ``torch_out``, the output after of the model,
141+
# which we will use to verify that the model we exported computes
142+
# the same values when run in onnxruntime.
134143
#
135144
# But before verifying the model's output with onnxruntime, we will check
136145
# the onnx model with onnx's API. This will verify the model's structure
137-
# and confirm that the model has a valid schema.
146+
# and confirm that the model has a valid schema.
147+
# The validity of the ONNX graph is verified by checking the model's
148+
# version, the graph's structure, as well as the nodes and their inputs
149+
# and outputs.
150+
#
138151

139152
import onnx
140153

@@ -143,10 +156,18 @@ def _initialize_weights(self):
143156

144157

145158
######################################################################
146-
# Now let's create an onnxruntime session. This part can normally be
147-
# done in a separate process or on another machine, but we will
148-
# continue in the same process so that we can verify that onnxruntime
149-
# and PyTorch are computing the same value for the network:
159+
# Now let's compute the output using ONNXRuntime's Python APIs.
160+
# This part can normally be done in a separate process or on another
161+
# machine, but we will continue in the same process so that we can
162+
# verify that onnxruntime and PyTorch are computing the same value
163+
# for the network.
164+
#
165+
# In order to run the model with ONNXRuntime, we need to create an
166+
# inference session for the model with the chosen configuration
167+
# parameters (here we use the default config).
168+
# Once the session is created, we evaluate the model using the run() api.
169+
# The output of this call is a list containing the outputs of the model
170+
# computed by ONNXRuntime.
150171
#
151172

152173
import onnxruntime
@@ -217,6 +238,15 @@ def to_numpy(tensor):
217238
# python library. Note that this preprocessing is the standard practice of
218239
# processing data for training/testing neural networks.
219240
#
241+
# We first resize the image to fit the size of the model's input (224x224).
242+
# Then we split the image into its Y, Cb, and Cr components.
243+
# These components represent a greyscale image (Y), and
244+
# the blue-difference (Cb) and red-difference (Cr) chroma components.
245+
# The Y component being more sensitive to the human eye, we are
246+
# interested in this component which we will be transforming.
247+
# After extracting the Y component, we convert it to a tensor which
248+
# will be the input of our model.
249+
#
220250

221251
from PIL import Image
222252
import torchvision.transforms as transforms
@@ -235,8 +265,9 @@ def to_numpy(tensor):
235265

236266

237267
######################################################################
238-
# Now, as a next step, let's take the resized cat image and run the
239-
# super-resolution model in ONNXRuntime.
268+
# Now, as a next step, let's take the tensor representing the
269+
# greyscale resized cat image and run the super-resolution model in
270+
# ONNXRuntime as explained previously.
240271
#
241272

242273
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
@@ -250,7 +281,7 @@ def to_numpy(tensor):
250281
# final output image from the output tensor, and save the image.
251282
# The post-processing steps have been adopted from PyTorch
252283
# implementation of super-resolution model
253-
# `here <https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py>`__
284+
# `here <https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py>`__.
254285
#
255286

256287
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

custom_directives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def run(self):
9292
intro, _ = sphinx_gallery.gen_rst.extract_intro_and_title(abs_fname, blocks[0][1])
9393

9494
thumbnail_rst = sphinx_gallery.backreferences._thumbnail_div(
95-
dirname, dirname, basename, intro)
95+
dirname, basename, intro)
9696

9797
if 'figure' in self.options:
9898
rel_figname, figname = env.relfn2path(self.options['figure'])

0 commit comments

Comments
 (0)