@@ -84,6 +84,11 @@ def _initialize_weights(self):
8484# was not trained fully for good accuracy and is used here for
8585# demonstration purposes only.
8686#
87+ # It is important to call ``torch_model.eval()`` or ``torch_model.train(False)``
88+ # before exporting the model, to turn the model to inference mode.
89+ # This is required since operators like dropout or batchnorm behave
90+ # differently in inference and training mode.
91+ #
8792
8893# Load pretrained model weights
8994model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
@@ -95,8 +100,8 @@ def _initialize_weights(self):
95100 map_location = None
96101torch_model .load_state_dict (model_zoo .load_url (model_url , map_location = map_location ))
97102
98- # set the train mode to false since we will only run the forward pass.
99- torch_model .train ( False )
103+ # set the model to inference mode
104+ torch_model .eval ( )
100105
101106
102107######################################################################
@@ -142,7 +147,11 @@ def _initialize_weights(self):
142147# the same values when run in onnxruntime.
143148#
144149# But before verifying the model's output with onnxruntime, we will check
145- # the onnx model with onnx's API. This will verify the model's structure
150+ # the onnx model with onnx's API.
151+ # First, ``onnx.load("super_resolution.onnx")`` will load the saved model and
152+ # will output a onnx.ModelProto structure (a top-level file/container format for bundling a ML model.
153+ # For more information `onnx.proto documentation <https://github.com/onnx/onnx/blob/master/onnx/onnx.proto>`__.).
154+ # Then, ``onnx.checker.check_model(onnx_model)`` will verify the model's structure
146155# and confirm that the model has a valid schema.
147156# The validity of the ONNX graph is verified by checking the model's
148157# version, the graph's structure, as well as the nodes and their inputs
@@ -196,28 +205,8 @@ def to_numpy(tensor):
196205
197206
198207######################################################################
199- # Transfering SRResNet using ONNX
200- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201- #
202-
203-
204- ######################################################################
205- # Using the same process as above, we also transferred an interesting new
206- # model "SRResNet" for super-resolution presented in `this
207- # paper <https://arxiv.org/pdf/1609.04802.pdf>`__ (thanks to the authors
208- # at Twitter for providing us code and pretrained parameters for the
209- # purpose of this tutorial). The model definition and a pre-trained model
210- # can be found
211- # `here <https://gist.github.com/prigoyal/b245776903efbac00ee89699e001c9bd>`__.
212- # Below is what SRResNet model input, output looks like. |SRResNet|
213- #
214- # .. |SRResNet| image:: /_static/img/SRResNet.png
215- #
216-
217-
218- ######################################################################
219- # Running the model using ONNXRuntime
220- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
208+ # Running the model on an image using ONNXRuntime
209+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
221210#
222211
223212
0 commit comments