Skip to content

Commit 28528e0

Browse files
author
Lara
committed
adress remaining comments
1 parent 8c4e5a6 commit 28528e0

File tree

1 file changed

+14
-25
lines changed

1 file changed

+14
-25
lines changed

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ def _initialize_weights(self):
8484
# was not trained fully for good accuracy and is used here for
8585
# demonstration purposes only.
8686
#
87+
# It is important to call ``torch_model.eval()`` or ``torch_model.train(False)``
88+
# before exporting the model, to turn the model to inference mode.
89+
# This is required since operators like dropout or batchnorm behave
90+
# differently in inference and training mode.
91+
#
8792

8893
# Load pretrained model weights
8994
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
@@ -95,8 +100,8 @@ def _initialize_weights(self):
95100
map_location = None
96101
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
97102

98-
# set the train mode to false since we will only run the forward pass.
99-
torch_model.train(False)
103+
# set the model to inference mode
104+
torch_model.eval()
100105

101106

102107
######################################################################
@@ -142,7 +147,11 @@ def _initialize_weights(self):
142147
# the same values when run in onnxruntime.
143148
#
144149
# But before verifying the model's output with onnxruntime, we will check
145-
# the onnx model with onnx's API. This will verify the model's structure
150+
# the onnx model with onnx's API.
151+
# First, ``onnx.load("super_resolution.onnx")`` will load the saved model and
152+
# will output a onnx.ModelProto structure (a top-level file/container format for bundling a ML model.
153+
# For more information `onnx.proto documentation <https://github.com/onnx/onnx/blob/master/onnx/onnx.proto>`__.).
154+
# Then, ``onnx.checker.check_model(onnx_model)`` will verify the model's structure
146155
# and confirm that the model has a valid schema.
147156
# The validity of the ONNX graph is verified by checking the model's
148157
# version, the graph's structure, as well as the nodes and their inputs
@@ -196,28 +205,8 @@ def to_numpy(tensor):
196205

197206

198207
######################################################################
199-
# Transfering SRResNet using ONNX
200-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201-
#
202-
203-
204-
######################################################################
205-
# Using the same process as above, we also transferred an interesting new
206-
# model "SRResNet" for super-resolution presented in `this
207-
# paper <https://arxiv.org/pdf/1609.04802.pdf>`__ (thanks to the authors
208-
# at Twitter for providing us code and pretrained parameters for the
209-
# purpose of this tutorial). The model definition and a pre-trained model
210-
# can be found
211-
# `here <https://gist.github.com/prigoyal/b245776903efbac00ee89699e001c9bd>`__.
212-
# Below is what SRResNet model input, output looks like. |SRResNet|
213-
#
214-
# .. |SRResNet| image:: /_static/img/SRResNet.png
215-
#
216-
217-
218-
######################################################################
219-
# Running the model using ONNXRuntime
220-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
208+
# Running the model on an image using ONNXRuntime
209+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
221210
#
222211

223212

0 commit comments

Comments
 (0)