Merged
93 changes: 93 additions & 0 deletions neural_compressor/adaptor/onnxrt.yaml
@@ -379,6 +379,16 @@
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'GatherElements': {
'weight': *uint8_asym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'GatherND': {
'weight': *uint8_asym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'MatMul': {
'weight': *int8_sym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor,
@@ -422,6 +432,7 @@
'Mod': *default_static_qlinear_qdq_minmax,
'ReduceMax': *default_static_qlinear_qdq_minmax,
'ReduceMin': *default_static_qlinear_qdq_minmax,
'Tile': *default_static_qlinear_qdq_minmax,
},
'dynamic': *ref_1_9_dynamic
}
@@ -436,6 +447,88 @@
recipes:
<<: *default_optimization

-
version:
name: '1.13.0'
int8: &ref_1_13 {
'static': {
'FusedConv': {
'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'Conv': {
'weight': *int8_sym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor,
'mode': ['QDQ', 'QLinear']
},
'Gather': {
'weight': *uint8_asym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'GatherElements': {
'weight': *uint8_asym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'GatherND': {
'weight': *uint8_asym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'MatMul': {
'weight': *int8_sym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor,
'mode': ['QDQ', 'QLinear']
},
'Gemm': {
'weight': *int8_sym_perchanneltensor_minmax,
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'EmbedLayerNormalization': {
'weight': *uint8_asym_pertensor_minmax, # QDQ: *int8_sym_pertensor_minmax
'activation': *uint8_asym_pertensor_minmax,
'mode': ['QDQ', 'QLinear']
},
'Attention': *default_static_qlinear_qdq_minmax,
'Mul': *default_static_qlinear,
'Relu': *default_static_qlinear_qdq_minmax,
'Clip': *default_static_qlinear_qdq_minmax,
'LeakyRelu': *default_static_qlinear_qdq_minmax,
'Sigmoid': *default_static_qlinear_qdq_minmax,
'MaxPool': *default_static_qlinear_qdq_minmax,
'GlobalAveragePool': *default_static_qlinear_qdq_minmax,
'Pad': *default_static_qlinear_qdq_minmax,
'Split': *default_static_qlinear_qdq_minmax,
'Add': *default_static_qlinear,
'Squeeze': *default_static_qlinear_qdq_minmax,
'Reshape': *default_static_qlinear_qdq_minmax,
'Concat': *default_static_qlinear_qdq_minmax,
'AveragePool': *default_static_qlinear_qdq_minmax,
'Unsqueeze': *default_static_qlinear_qdq_minmax,
'Transpose': *default_static_qlinear_qdq_minmax,
'ArgMax': *default_static_qlinear,
'Resize': *default_static_qlinear_qdq_minmax,
'Abs': *default_static_qlinear_qdq_minmax,
'Shrink': *default_static_qlinear_qdq_minmax,
'Sign': *default_static_qlinear_qdq_minmax,
'Flatten': *default_static_qlinear_qdq_minmax,
'Expand': *default_static_qlinear_qdq_minmax,
'Slice': *default_static_qlinear_qdq_minmax,
'Mod': *default_static_qlinear_qdq_minmax,
'ReduceMax': *default_static_qlinear_qdq_minmax,
'ReduceMin': *default_static_qlinear_qdq_minmax,
'Tile': *default_static_qlinear_qdq_minmax,
'CenterCropPad': *default_static_qlinear_qdq_minmax,
},
'dynamic': *ref_1_9_dynamic
}
weight_only_integer: *cap_weight_only
recipes:
<<: *default_optimization

-
version:
name: 'default'
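The capability entries above lean on YAML anchors and aliases: every *name reference resolves to a mapping defined at a matching &name anchor earlier in onnxrt.yaml, which this diff does not show. A minimal sketch of how one alias expands at parse time, with an assumed anchor definition standing in for the real one:

import yaml

# Assumed anchor definition; the real one lives earlier in onnxrt.yaml.
snippet = """
uint8_asym_pertensor_minmax: &uint8_asym_pertensor_minmax {
    'dtype': ['uint8'], 'scheme': ['asym'],
    'granularity': ['per_tensor'], 'algorithm': ['minmax']
}
static:
    'GatherND': {
        'weight': *uint8_asym_pertensor_minmax,
        'activation': *uint8_asym_pertensor_minmax,
        'mode': ['QDQ', 'QLinear']
    }
"""

caps = yaml.safe_load(snippet)
# Both fields now point at one expanded config dict.
print(caps["static"]["GatherND"]["activation"]["algorithm"])  # ['minmax']

This is why a one-line entry such as 'Tile': *default_static_qlinear_qdq_minmax is enough to declare full static-quantization support for a new op.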
3 changes: 2 additions & 1 deletion neural_compressor/adaptor/ox_utils/operators/direct_q8.py
@@ -20,7 +20,8 @@


@op_registry(
op_types="Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand, Slice, " "SpaceToDepth, DepthToSpace, Upsample"
op_types="Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand, Slice, "
"SpaceToDepth, DepthToSpace, Upsample, Tile, CenterCropPad"
)
class Direct8BitOperator(Operator):
"""Direct8Bit Operator."""
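Tile and CenterCropPad only rearrange or replicate values, so they can reuse the existing pass-through operator class. A hedged sketch of the op_registry decorator pattern assumed here (illustrative re-implementation, not the library's actual code):

# Sketch: op_registry splits the comma-separated op_types string and maps
# each ONNX op type to the decorated class, so quantization dispatch finds
# Direct8BitOperator for both new ops.
OPERATORS = {}

def op_registry(op_types):
    def decorator(cls):
        for op_type in op_types.split(","):
            OPERATORS[op_type.strip()] = cls
        return cls
    return decorator

@op_registry(op_types="Tile, CenterCropPad")
class Direct8BitOperator:
    """Passes already-quantized tensors through unchanged."""

print(OPERATORS["CenterCropPad"].__name__)  # Direct8BitOperator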
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/ox_utils/operators/gather.py
@@ -22,7 +22,7 @@
from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg


@op_registry(op_types="Gather")
@op_registry(op_types="Gather, GatherElements, GatherND")
class GatherOperator(Operator):
"""Gather Operator."""

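Routing GatherElements and GatherND through the existing GatherOperator works because all three are pure indexing ops: they move elements without doing arithmetic on them, so dequantization commutes with the gather and the quantized tensor can flow through directly. A small numpy check of that property (a reasoning sketch, not library code):

import numpy as np

scale, zero_point = 0.1, 128
q = np.array([130, 120, 140, 125], dtype=np.uint8)  # quantized values
idx = np.array([2, 0], dtype=np.int64)              # gather indices

gather_then_dequant = (q[idx].astype(np.float32) - zero_point) * scale
dequant_then_gather = ((q.astype(np.float32) - zero_point) * scale)[idx]
assert np.allclose(gather_then_dequant, dequant_then_gather)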
254 changes: 254 additions & 0 deletions test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
@@ -1562,6 +1562,260 @@ def test_reducemin_reducemax(self):
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

def test_tile(self):
# test Tile nodes: MatMul-Tile-MatMul
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 1])

matmul1_weight = helper.make_tensor(
"matmul1_weight", TensorProto.FLOAT, [1, 5], np.random.random((1, 5)).reshape(5).tolist()
)
matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [2, 3, 4, 5])
matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")

repeats = helper.make_tensor("repeats", TensorProto.INT64, [4], [2, 2, 2, 2])
tile_output = helper.make_tensor_value_info("tile_output", TensorProto.FLOAT, [4, 6, 8, 10])
tile_node = onnx.helper.make_node(
"Tile",
["matmul1_output", "repeats"],
["tile_output"],
name="Tile_1",
)

matmul2_weight = helper.make_tensor(
"matmul2_weight", TensorProto.FLOAT, [10, 1], np.random.random((10, 1)).reshape(10).tolist()
)
matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [4, 6, 8, 1])
matmul2_node = onnx.helper.make_node(
"MatMul", ["tile_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
)

initializers = [matmul1_weight, matmul2_weight, repeats]
graph = helper.make_graph(
[matmul1_node, tile_node, matmul2_node],
"TestTile_test_model",
[input_tensor],
[matmul2_output],
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
model.ir_version = 7

q_config = {"Matmul_0": self.static_q_config, "Tile_1": self.static_q_config, "Matmul_2": self.static_q_config}
quantize_params = {
"input": [np.uint8(10.0), np.float32(0)],
"matmul1_weight": [np.uint8(10.0), np.float32(0)],
"matmul1_output": [np.uint8(10.0), np.float32(0)],
"matmul2_weight": [np.uint8(10.0), np.float32(0)],
"matmul2_output": [np.uint8(10.0), np.float32(0)],
"tile_output": [np.uint8(10.0), np.float32(0)],
}
quantizable_op_types = ["MatMul", "Tile"]
q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)
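# Why these counts (a reasoning sketch, not from the library docs): in
# QLinear mode Tile runs directly on the quantized uint8 tensor, so only
# the model-input QuantizeLinear and the model-output DequantizeLinear
# remain (1 each). In QDQ mode every quantized activation edge gets an
# explicit Q/DQ pair (input, matmul1_output, tile_output, matmul2_output
# -> 4 Q and 4 DQ) and each quantized weight initializer gets its own
# DequantizeLinear (2 more DQ), matching the 4 Q / 6 DQ asserted above.
# The same arithmetic explains the counts in the CenterCropPad, GatherND,
# and GatherElements tests below.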

def test_centercroppad(self):
# test CenterCropPad nodes: MatMul-CenterCropPad-MatMul
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [20, 10, 1])

matmul1_weight = helper.make_tensor(
"matmul1_weight", TensorProto.FLOAT, [1, 3], np.random.random((1, 3)).reshape(3).tolist()
)
matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [20, 10, 3])
matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")

centercroppad_output = helper.make_tensor_value_info("centercroppad_output", TensorProto.FLOAT, [10, 7, 3])
shape = helper.make_tensor("shape", TensorProto.INT64, [3], [10, 7, 3])
centercroppad_node = onnx.helper.make_node(
"CenterCropPad",
["matmul1_output", "shape"],
["centercroppad_output"],
name="Centercroppad_1",
)

matmul2_weight = helper.make_tensor(
"matmul2_weight", TensorProto.FLOAT, [3, 1], np.random.random((3, 1)).reshape(3).tolist()
)
matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [10, 7, 1])
matmul2_node = onnx.helper.make_node(
"MatMul", ["centercroppad_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
)

initializers = [matmul1_weight, shape, matmul2_weight]
graph = helper.make_graph(
[matmul1_node, centercroppad_node, matmul2_node],
"TestCenterCropPad_test_model",
[input_tensor],
[matmul2_output],
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 8

q_config = {
"Matmul_0": self.static_q_config,
"Centercroppad_1": self.static_q_config,
"Matmul_2": self.static_q_config,
}
quantize_params = {
"input": [np.uint8(10.0), np.float32(0)],
"matmul1_weight": [np.uint8(10.0), np.float32(0)],
"matmul1_output": [np.uint8(10.0), np.float32(0)],
"matmul2_weight": [np.uint8(10.0), np.float32(0)],
"matmul2_output": [np.uint8(10.0), np.float32(0)],
"centercroppad_output": [np.uint8(10.0), np.float32(0)],
}
quantizable_op_types = ["MatMul", "CenterCropPad"]
q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

def test_gathernd(self):
# test GatherND nodes: MatMul-GatherND-MatMul
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 2, 1])

matmul1_weight = helper.make_tensor(
"matmul1_weight", TensorProto.FLOAT, [1, 2], np.random.random((1, 2)).reshape(2).tolist()
)
matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [2, 2, 2])
matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")

gathernd_output = helper.make_tensor_value_info("gathernd_output", TensorProto.FLOAT, [2, 1, 2])
indices = helper.make_tensor("indices", TensorProto.INT64, [2, 1, 2], [0, 1, 1, 0])
gathernd_node = onnx.helper.make_node(
"GatherND",
["matmul1_output", "indices"],
["gathernd_output"],
name="Gathernd_1",
)

matmul2_weight = helper.make_tensor(
"matmul2_weight", TensorProto.FLOAT, [2, 1], np.random.random((2, 1)).reshape(2).tolist()
)
matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [2, 1, 1])
matmul2_node = onnx.helper.make_node(
"MatMul", ["gathernd_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
)

initializers = [matmul1_weight, indices, matmul2_weight]
graph = helper.make_graph(
[matmul1_node, gathernd_node, matmul2_node],
"TestGatherND_test_model",
[input_tensor],
[matmul2_output],
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
model.ir_version = 7

q_config = {
"Matmul_0": self.static_q_config,
"Matmul_2": self.static_q_config,
"Gathernd_1": self.static_q_config,
}

quantize_params = {
"input": [np.uint8(10.0), np.float32(0)],
"matmul1_weight": [np.uint8(10.0), np.float32(0)],
"matmul1_output": [np.uint8(10.0), np.float32(0)],
"matmul2_weight": [np.uint8(10.0), np.float32(0)],
"matmul2_output": [np.uint8(10.0), np.float32(0)],
"gathernd_output": [np.uint8(10.0), np.float32(0)],
}
quantizable_op_types = ["MatMul", "GatherND"]
q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

def test_gatherelements(self):
# test GatherElements nodes: MatMul-GatherElements-MatMul
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 1])

matmul1_weight = helper.make_tensor(
"matmul1_weight", TensorProto.FLOAT, [1, 3], np.random.random((1, 3)).reshape(3).tolist()
)
matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [3, 3])
matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")

gatherelements_output = helper.make_tensor_value_info("gatherelements_output", TensorProto.FLOAT, [2, 3])
indices = helper.make_tensor("indices", TensorProto.INT64, [2, 3], [-1, -2, 0, -2, 0, 0])
gatherelements_node = onnx.helper.make_node(
"GatherElements",
["matmul1_output", "indices"],
["gatherelements_output"],
name="Gatherelements_1",
)

matmul2_weight = helper.make_tensor(
"matmul2_weight", TensorProto.FLOAT, [3, 1], np.random.random((3, 1)).reshape(3).tolist()
)
matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [2, 1])
matmul2_node = onnx.helper.make_node(
"MatMul", ["gatherelements_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
)

initializers = [matmul1_weight, indices, matmul2_weight]
graph = helper.make_graph(
[matmul1_node, gatherelements_node, matmul2_node],
"TestGatherElements_test_model",
[input_tensor],
[matmul2_output],
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
model.ir_version = 7

q_config = {
"Matmul_0": self.static_q_config,
"Matmul_2": self.static_q_config,
"Gatherelements_1": self.static_q_config,
}

quantize_params = {
"input": [np.uint8(10.0), np.float32(0)],
"matmul1_weight": [np.uint8(10.0), np.float32(0)],
"matmul1_output": [np.uint8(10.0), np.float32(0)],
"matmul2_weight": [np.uint8(10.0), np.float32(0)],
"matmul2_output": [np.uint8(10.0), np.float32(0)],
"gatherelements_output": [np.uint8(10.0), np.float32(0)],
}
quantizable_op_types = ["MatMul", "GatherElements"]
q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)

q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
self.assertIsNotNone(session)


class TestCastONNXRT(unittest.TestCase):
@classmethod