(6 binary files changed; contents not shown)
16 changes: 11 additions & 5 deletions testdata/dnn/onnx/generate_quantized_onnx_models.py
@@ -5,9 +5,9 @@
 import torch.nn.functional as F
 import numpy as np
 import os
-import onnx
+import onnx # version >= 1.12.0
 import onnxruntime as rt
-from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
+from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat

 class DataReader(CalibrationDataReader):
     def __init__(self, model_path, batchsize=5):
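The new comment on the onnx import pins the minimum version at 1.12.0. If an explicit check is preferred over a comment, something like the following sketch (not part of this change) would fail fast on older installs:

import onnx

# Fail fast if the installed onnx predates the 1.12.0 minimum noted above.
# Tuple comparison on the major/minor components is a simplification; it
# ignores patch levels and pre-release suffixes.
if tuple(map(int, onnx.__version__.split(".")[:2])) < (1, 12):
    raise RuntimeError("onnx >= 1.12.0 required, found " + onnx.__version__)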
@@ -20,16 +20,16 @@ def __init__(self, model_path, batchsize=5):
     def get_next(self):
         return next(self.enum_data_dicts, None)

-def quantize_and_save_model(name, input, model, act_type="uint8", wt_type="uint8", per_channel=False):
+def quantize_and_save_model(name, input, model, act_type="uint8", wt_type="uint8", per_channel=False, ops_version=13, quanFormat=QuantFormat.QOperator):
     float_model_path = os.path.join("models", "dummy.onnx")
     quantized_model_path = os.path.join("models", name + ".onnx")
     type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}

     model.eval()
-    torch.onnx.export(model, input, float_model_path, export_params=True, opset_version=12)
+    torch.onnx.export(model, input, float_model_path, export_params=True, opset_version=ops_version)

     dr = DataReader(float_model_path)
-    quantize_static(float_model_path, quantized_model_path, dr, per_channel=per_channel,
+    quantize_static(float_model_path, quantized_model_path, dr, quant_format=quanFormat, per_channel=per_channel,
                     activation_type=type_dict[act_type], weight_type=type_dict[wt_type])

     os.remove(float_model_path)
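The body of DataReader.__init__ falls outside the diff context above. For orientation, a CalibrationDataReader for a script like this typically opens the float model with onnxruntime, reads the input name and shape, and queues a few random calibration feeds; the sketch below is an assumption about that elided body, not the repository's exact code:

import numpy as np
import onnxruntime as rt
from onnxruntime.quantization import CalibrationDataReader

class DataReader(CalibrationDataReader):
    def __init__(self, model_path, batchsize=5):
        # Read the input signature from the float model. Assumes a single
        # input with fixed (non-symbolic) dimensions, as in the dummy
        # models exported by this script.
        sess = rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        input_name = sess.get_inputs()[0].name
        input_shape = sess.get_inputs()[0].shape
        # Queue `batchsize` random calibration feeds matching that signature.
        feeds = [{input_name: np.random.rand(*input_shape).astype(np.float32)}
                 for _ in range(batchsize)]
        self.enum_data_dicts = iter(feeds)

    def get_next(self):
        return next(self.enum_data_dicts, None)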
@@ -53,10 +53,16 @@ def quantize_and_save_model(name, input, model, act_type="uint8", wt_type="uint8", per_channel=False):

 input = Variable(torch.randn(1, 3, 10, 10))
 conv = nn.Conv2d(3, 5, kernel_size=3, stride=2, padding=1)
+# generate QOperator quantized models
 quantize_and_save_model("quantized_conv_uint8_weights", input, conv)
 quantize_and_save_model("quantized_conv_int8_weights", input, conv, wt_type="int8")
 quantize_and_save_model("quantized_conv_per_channel_weights", input, conv, per_channel=True)

+# generate QDQ quantized models
+quantize_and_save_model("quantized_conv_uint8_weights_qdq", input, conv, quanFormat=QuantFormat.QDQ)
+quantize_and_save_model("quantized_conv_int8_weights_qdq", input, conv, wt_type="int8", quanFormat=QuantFormat.QDQ)
+quantize_and_save_model("quantized_conv_per_channel_weights_qdq", input, conv, per_channel=True, quanFormat=QuantFormat.QDQ)
+
 input = Variable(torch.randn(1, 3))
 linear = nn.Linear(3, 4, bias=True)
 quantize_and_save_model("quantized_matmul_uint8_weights", input, linear)
(3 binary files changed; contents not shown)