diff --git a/neural_compressor/adaptor/onnxrt.yaml b/neural_compressor/adaptor/onnxrt.yaml
index a55250e9da9..c02ac89ffb0 100644
--- a/neural_compressor/adaptor/onnxrt.yaml
+++ b/neural_compressor/adaptor/onnxrt.yaml
@@ -379,6 +379,16 @@
         'activation': *uint8_asym_pertensor_minmax,
         'mode': ['QDQ', 'QLinear']
       },
+      'GatherElements': {
+        'weight': *uint8_asym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'GatherND': {
+        'weight': *uint8_asym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
       'MatMul': {
         'weight': *int8_sym_perchanneltensor_minmax,
         'activation': *uint8_asym_pertensor,
@@ -422,6 +432,7 @@
       'Mod': *default_static_qlinear_qdq_minmax,
       'ReduceMax': *default_static_qlinear_qdq_minmax,
       'ReduceMin': *default_static_qlinear_qdq_minmax,
+      'Tile': *default_static_qlinear_qdq_minmax,
     },
     'dynamic': *ref_1_9_dynamic
   }
@@ -436,6 +447,88 @@
   recipes:
     <<: *default_optimization
 
+-
+  version:
+    name: '1.13.0'
+  int8: &ref_1_13 {
+    'static': {
+      'FusedConv': {
+        'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'Conv': {
+        'weight': *int8_sym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'Gather': {
+        'weight': *uint8_asym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'GatherElements': {
+        'weight': *uint8_asym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'GatherND': {
+        'weight': *uint8_asym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'MatMul': {
+        'weight': *int8_sym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'Gemm': {
+        'weight': *int8_sym_perchanneltensor_minmax,
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'EmbedLayerNormalization': {
+        'weight': *uint8_asym_pertensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+        'activation': *uint8_asym_pertensor_minmax,
+        'mode': ['QDQ', 'QLinear']
+      },
+      'Attention': *default_static_qlinear_qdq_minmax,
+      'Mul': *default_static_qlinear,
+      'Relu': *default_static_qlinear_qdq_minmax,
+      'Clip': *default_static_qlinear_qdq_minmax,
+      'LeakyRelu': *default_static_qlinear_qdq_minmax,
+      'Sigmoid': *default_static_qlinear_qdq_minmax,
+      'MaxPool': *default_static_qlinear_qdq_minmax,
+      'GlobalAveragePool': *default_static_qlinear_qdq_minmax,
+      'Pad': *default_static_qlinear_qdq_minmax,
+      'Split': *default_static_qlinear_qdq_minmax,
+      'Add': *default_static_qlinear,
+      'Squeeze': *default_static_qlinear_qdq_minmax,
+      'Reshape': *default_static_qlinear_qdq_minmax,
+      'Concat': *default_static_qlinear_qdq_minmax,
+      'AveragePool': *default_static_qlinear_qdq_minmax,
+      'Unsqueeze': *default_static_qlinear_qdq_minmax,
+      'Transpose': *default_static_qlinear_qdq_minmax,
+      'ArgMax': *default_static_qlinear,
+      'Resize': *default_static_qlinear_qdq_minmax,
+      'Abs': *default_static_qlinear_qdq_minmax,
+      'Shrink': *default_static_qlinear_qdq_minmax,
+      'Sign': *default_static_qlinear_qdq_minmax,
+      'Flatten': *default_static_qlinear_qdq_minmax,
+      'Expand': *default_static_qlinear_qdq_minmax,
+      'Slice': *default_static_qlinear_qdq_minmax,
+      'Mod': *default_static_qlinear_qdq_minmax,
+      'ReduceMax': *default_static_qlinear_qdq_minmax,
+      'ReduceMin': *default_static_qlinear_qdq_minmax,
+      'Tile': *default_static_qlinear_qdq_minmax,
+      'CenterCropPad': *default_static_qlinear_qdq_minmax,
+    },
+    'dynamic': *ref_1_9_dynamic
+  }
+  weight_only_integer: *cap_weight_only
+  recipes:
+    <<: *default_optimization
+
 -
   version:
     name: 'default'
diff --git a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py
index f40a154070e..a3d52089ce6 100644
--- a/neural_compressor/adaptor/ox_utils/operators/direct_q8.py
+++ b/neural_compressor/adaptor/ox_utils/operators/direct_q8.py
@@ -20,7 +20,8 @@
 
 
 @op_registry(
-    op_types="Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand, Slice, " "SpaceToDepth, DepthToSpace, Upsample"
+    op_types="Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand, Slice, "
+    "SpaceToDepth, DepthToSpace, Upsample, Tile, CenterCropPad"
 )
 class Direct8BitOperator(Operator):
     """Direct8Bit Operator."""
diff --git a/neural_compressor/adaptor/ox_utils/operators/gather.py b/neural_compressor/adaptor/ox_utils/operators/gather.py
index 74360cd32af..58368a9ac02 100644
--- a/neural_compressor/adaptor/ox_utils/operators/gather.py
+++ b/neural_compressor/adaptor/ox_utils/operators/gather.py
@@ -22,7 +22,7 @@
 from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg
 
 
-@op_registry(op_types="Gather")
+@op_registry(op_types="Gather, GatherElements, GatherND")
 class GatherOperator(Operator):
     """Gather Operator."""
diff --git a/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py b/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
index b1cef9e099a..2ebe0e10551 100644
--- a/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
+++ b/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
@@ -1562,6 +1562,260 @@ def test_reducemin_reducemax(self):
         session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
         self.assertIsNotNone(session)
 
+    def test_tile(self):
+        # test Tile nodes: MatMul-Tile-MatMul
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 1])
+
+        matmul1_weight = helper.make_tensor(
+            "matmul1_weight", TensorProto.FLOAT, [1, 5], np.random.random((1, 5)).reshape(5).tolist()
+        )
+        matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [2, 3, 4, 5])
+        matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")
+
+        repeats = helper.make_tensor("repeats", TensorProto.INT64, [4], [2, 2, 2, 2])
+        tile_output = helper.make_tensor_value_info("tile_output", TensorProto.FLOAT, [4, 6, 8, 10])
+        tile_node = onnx.helper.make_node(
+            "Tile",
+            ["matmul1_output", "repeats"],
+            ["tile_output"],
+            name="Tile_1",
+        )
+
+        matmul2_weight = helper.make_tensor(
+            "matmul2_weight", TensorProto.FLOAT, [10, 1], np.random.random((10, 1)).reshape(10).tolist()
+        )
+        matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [4, 6, 8, 1])
+        matmul2_node = onnx.helper.make_node(
+            "MatMul", ["tile_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
+        )
+
+        initializers = [matmul1_weight, matmul2_weight, repeats]
+        graph = helper.make_graph(
+            [matmul1_node, tile_node, matmul2_node],
+            "TestTile_test_model",
+            [input_tensor],
+            [matmul2_output],
+            initializer=initializers,
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        model.ir_version = 7
+
+        q_config = {"Matmul_0": self.static_q_config, "Tile_1": self.static_q_config, "Matmul_2": self.static_q_config}
+        quantize_params = {
+            "input": [np.uint8(10.0), np.float32(0)],
+            "matmul1_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul1_output": [np.uint8(10.0), np.float32(0)],
+            "matmul2_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul2_output": [np.uint8(10.0), np.float32(0)],
+            "tile_output": [np.uint8(10.0), np.float32(0)],
+        }
+        quantizable_op_types = ["MatMul", "Tile"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+        q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+    def test_centercroppad(self):
+        # test CenterCropPad nodes: MatMul-CenterCropPad-MatMul
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [20, 10, 1])
+
+        matmul1_weight = helper.make_tensor(
+            "matmul1_weight", TensorProto.FLOAT, [1, 3], np.random.random((1, 3)).reshape(3).tolist()
+        )
+        matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [20, 10, 3])
+        matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")
+
+        centercroppad_output = helper.make_tensor_value_info("centercroppad_output", TensorProto.FLOAT, [10, 7, 3])
+        shape = helper.make_tensor("shape", TensorProto.INT64, [3], [10, 7, 3])
+        centercroppad_node = onnx.helper.make_node(
+            "CenterCropPad",
+            ["matmul1_output", "shape"],
+            ["centercroppad_output"],
+            name="Centercroppad_1",
+        )
+
+        matmul2_weight = helper.make_tensor(
+            "matmul2_weight", TensorProto.FLOAT, [3, 1], np.random.random((3, 1)).reshape(3).tolist()
+        )
+        matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [10, 7, 1])
+        matmul2_node = onnx.helper.make_node(
+            "MatMul", ["centercroppad_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
+        )
+
+        initializers = [matmul1_weight, shape, matmul2_weight]
+        graph = helper.make_graph(
+            [matmul1_node, centercroppad_node, matmul2_node],
+            "TestCenterCropPad_test_model",
+            [input_tensor],
+            [matmul2_output],
+            initializer=initializers,
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+        model.ir_version = 8
+
+        q_config = {
+            "Matmul_0": self.static_q_config,
+            "Centercroppad_1": self.static_q_config,
+            "Matmul_2": self.static_q_config,
+        }
+        quantize_params = {
+            "input": [np.uint8(10.0), np.float32(0)],
+            "matmul1_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul1_output": [np.uint8(10.0), np.float32(0)],
+            "matmul2_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul2_output": [np.uint8(10.0), np.float32(0)],
+            "centercroppad_output": [np.uint8(10.0), np.float32(0)],
+        }
+        quantizable_op_types = ["MatMul", "CenterCropPad"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+        q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+    def test_gathernd(self):
+        # test GatherND nodes: MatMul-GatherND-MatMul
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 2, 1])
+
+        matmul1_weight = helper.make_tensor(
+            "matmul1_weight", TensorProto.FLOAT, [1, 2], np.random.random((1, 2)).reshape(2).tolist()
+        )
+        matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [2, 2, 2])
+        matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")
+
+        gathernd_output = helper.make_tensor_value_info("gathernd_output", TensorProto.FLOAT, [2, 1, 2])
+        indices = helper.make_tensor("indices", TensorProto.INT64, [2, 1, 2], [0, 1, 1, 0])
+        gathernd_node = onnx.helper.make_node(
+            "GatherND",
+            ["matmul1_output", "indices"],
+            ["gathernd_output"],
+            name="Gathernd_1",
+        )
+
+        matmul2_weight = helper.make_tensor(
+            "matmul2_weight", TensorProto.FLOAT, [2, 1], np.random.random((2, 1)).reshape(2).tolist()
+        )
+        matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [2, 1, 1])
+        matmul2_node = onnx.helper.make_node(
+            "MatMul", ["gathernd_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
+        )
+
+        initializers = [matmul1_weight, indices, matmul2_weight]
+        graph = helper.make_graph(
+            [matmul1_node, gathernd_node, matmul2_node],
+            "TestGatherND_test_model",
+            [input_tensor],
+            [matmul2_output],
+            initializer=initializers,
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        model.ir_version = 7
+
+        q_config = {
+            "Matmul_0": self.static_q_config,
+            "Matmul_2": self.static_q_config,
+            "Gathernd_1": self.static_q_config,
+        }
+
+        quantize_params = {
+            "input": [np.uint8(10.0), np.float32(0)],
+            "matmul1_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul1_output": [np.uint8(10.0), np.float32(0)],
+            "matmul2_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul2_output": [np.uint8(10.0), np.float32(0)],
+            "gathernd_output": [np.uint8(10.0), np.float32(0)],
+        }
+        quantizable_op_types = ["MatMul", "GatherND"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+        q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+    def test_gatherelements(self):
+        # test GatherElements nodes: MatMul-GatherElements-MatMul
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 1])
+
+        matmul1_weight = helper.make_tensor(
+            "matmul1_weight", TensorProto.FLOAT, [1, 3], np.random.random((1, 3)).reshape(3).tolist()
+        )
+        matmul1_output = helper.make_tensor_value_info("matmul1_output", TensorProto.FLOAT, [3, 3])
+        matmul1_node = onnx.helper.make_node("MatMul", ["input", "matmul1_weight"], ["matmul1_output"], name="Matmul_0")
+
+        gatherelements_output = helper.make_tensor_value_info("gatherelements_output", TensorProto.FLOAT, [2, 3])
+        indices = helper.make_tensor("indices", TensorProto.INT64, [2, 3], [-1, -2, 0, -2, 0, 0])
+        gatherelements_node = onnx.helper.make_node(
+            "GatherElements",
+            ["matmul1_output", "indices"],
+            ["gatherelements_output"],
+            name="Gatherelements_1",
+        )
+
+        matmul2_weight = helper.make_tensor(
+            "matmul2_weight", TensorProto.FLOAT, [3, 1], np.random.random((3, 1)).reshape(3).tolist()
+        )
+        matmul2_output = helper.make_tensor_value_info("matmul2_output", TensorProto.FLOAT, [2, 1])
+        matmul2_node = onnx.helper.make_node(
+            "MatMul", ["gatherelements_output", "matmul2_weight"], ["matmul2_output"], name="Matmul_2"
+        )
+
+        initializers = [matmul1_weight, indices, matmul2_weight]
+        graph = helper.make_graph(
+            [matmul1_node, gatherelements_node, matmul2_node],
+            "TestGatherElements_test_model",
+            [input_tensor],
+            [matmul2_output],
+            initializer=initializers,
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        model.ir_version = 7
+
+        q_config = {
+            "Matmul_0": self.static_q_config,
+            "Matmul_2": self.static_q_config,
+            "Gatherelements_1": self.static_q_config,
+        }
+
+        quantize_params = {
+            "input": [np.uint8(10.0), np.float32(0)],
+            "matmul1_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul1_output": [np.uint8(10.0), np.float32(0)],
+            "matmul2_weight": [np.uint8(10.0), np.float32(0)],
+            "matmul2_output": [np.uint8(10.0), np.float32(0)],
+            "gatherelements_output": [np.uint8(10.0), np.float32(0)],
+        }
+        quantizable_op_types = ["MatMul", "GatherElements"]
+        q_model = self.qlinear_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 1)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 1)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
+        q_model = self.qdq_test(model, q_config, quantize_params, quantizable_op_types)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["DequantizeLinear"], 6)
+        self.assertEqual(Counter([node.op_type for node in q_model.model.graph.node])["QuantizeLinear"], 4)
+        session = ort.InferenceSession(q_model.model.SerializeToString(), providers=["CPUExecutionProvider"])
+        self.assertIsNotNone(session)
+
 
 class TestCastONNXRT(unittest.TestCase):
     @classmethod
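With the capability entries and operator registrations above in place, models containing the newly supported ops (Tile, CenterCropPad, GatherElements, GatherND) go through the regular post-training static quantization flow. The snippet below is an illustrative sketch only, not part of this diff: the model path, input shape, calibration data, and dataloader contract are assumptions, and it uses the high-level 2.x `quantization.fit` API rather than the test helpers exercised above.

```python
# Illustrative sketch (not part of this diff): static PTQ of an ONNX model that
# contains one of the newly supported ops, e.g. Tile. Model path, input shape,
# and calibration data are placeholder assumptions.
import numpy as np
from neural_compressor import PostTrainingQuantConfig, quantization


class DummyCalibDataloader:
    """Toy calibration loader: iterable of (input, label) pairs with a
    batch_size attribute, which the calibration loop is assumed to expect."""

    batch_size = 1

    def __iter__(self):
        for _ in range(10):
            # Shape matches the single "input" tensor of the hypothetical model.
            yield np.random.random((2, 3, 4, 1)).astype(np.float32), None


conf = PostTrainingQuantConfig(
    approach="static",   # static PTQ, matching the YAML capability entries above
    quant_format="QDQ",  # QDQ format; the QLinear/QOperator path is the other mode listed in the YAML
)
q_model = quantization.fit(
    model="model_with_tile.onnx",  # hypothetical model containing a Tile node
    conf=conf,
    calib_dataloader=DummyCalibDataloader(),
)
q_model.save("model_with_tile_int8.onnx")
```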