diff --git a/neural_compressor/adaptor/ox_utils/calibration.py b/neural_compressor/adaptor/ox_utils/calibration.py
index bb6060096cd..a7a7756aaa1 100644
--- a/neural_compressor/adaptor/ox_utils/calibration.py
+++ b/neural_compressor/adaptor/ox_utils/calibration.py
@@ -139,17 +139,22 @@ def augment_graph(self, activation_only=False, weight_only=False):
                                (node.name in self.white_nodes)
             if should_be_dump:
                 if not weight_only and not activation_only:
-                    tensors_to_dump.update(node.input)
+                    tensors_to_dump.update([input for input in node.input if len(input) != 0])
+                    tensors_to_dump.update([output for output in node.output if len(output) != 0])
                     tensors_to_dump.update(node.output)
                 elif weight_only:
                     for input in node.input:
                         if self.already_quantized and \
-                            input.replace('_dequantized', '_quantized') in initializers:
+                            input.replace('_dequantized', '_quantized') in initializers and \
+                            len(input) != 0:
                             tensors_to_dump.add(input)
-                        elif not self.already_quantized and input in initializers:
+                        elif not self.already_quantized and \
+                            input in initializers and \
+                            len(input) != 0:
                             tensors_to_dump.add(input)
                 elif activation_only:
-                    tensors_to_dump.update([node.input[0]])
+                    if len(node.input[0]) != 0:
+                        tensors_to_dump.update([node.input[0]])
 
         model_inputs = [i.name for i in model.graph.input]
         for tensor in tensors_to_dump:
@@ -525,6 +530,8 @@ def dump_tensor(self, activation=True, weight=False, format=None):
                 for i in range(iters):
                     if node.op_type in ['Attention', 'QAttention'] and tensor_name not in node.input[:2]:
                         continue
+                    if node.op_type in ['MatMul', 'QLinearMatMul'] and tensor_name != node.input[0]:
+                        continue
                     if is_qdq:
                         map_node_activation[i][node_name] = \
                             {tensor_name.replace('_dequantized', '').replace('_' + node_name, ''): tensors[i]}
diff --git a/neural_compressor/adaptor/ox_utils/quantizer.py b/neural_compressor/adaptor/ox_utils/quantizer.py
index 04558648fe3..9749391ebb2 100644
--- a/neural_compressor/adaptor/ox_utils/quantizer.py
+++ b/neural_compressor/adaptor/ox_utils/quantizer.py
@@ -149,7 +149,7 @@ def quantize_model(self):
         """Quantize onnx model."""
         # step 1: insert q-dq, cast-cast pairs
         self.insert_qdq()
-        
+
         # step 2: remove redundant pairs -> qdq model
         self.remove_redundant_pairs()
 
@@ -158,7 +158,7 @@
 
         self.merge_dedicated_qdq_pair()
 
-        self.model.remove_unused_constant()
+        self.model.remove_unused_nodes()
 
         self.model.model.producer_name = __producer__
         self.model.model.producer_version = __version__
diff --git a/neural_compressor/adaptor/ox_utils/smooth_quant.py b/neural_compressor/adaptor/ox_utils/smooth_quant.py
index edf5d20c615..7bca8cf00c9 100644
--- a/neural_compressor/adaptor/ox_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/ox_utils/smooth_quant.py
@@ -171,7 +171,7 @@ def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm'
         if folding:
             self._fold_scale(scales)
         self.model.topological_sort()
-        self.model.remove_unused_constant()
+        self.model.remove_unused_nodes()
         return self.model
 
     def recover(self):
diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py
index 27583181a0a..f6dd37c605f 100644
--- a/neural_compressor/model/onnx_model.py
+++ b/neural_compressor/model/onnx_model.py
@@ -329,32 +329,45 @@ def get_scale_zero(self, tensor):
         if not tensor.endswith('_quantized'):
             logger.debug("Find {} in the quantized graph is not quantized.".format(tensor))
             return None, None
+
+        def _searcher(tensor_name):
+            """Search scale and zero point tensor recursively."""
+            node = self._input_name_to_nodes[tensor_name][0]
+            parent = self._output_name_to_node[tensor_name] if tensor_name in self._output_name_to_node else None
+            direct_int8 = ['Reshape', 'Transpose', 'Squeeze', 'Unsqueeze', 'MaxPool', 'Pad', 'Split']
+            if parent is not None and parent.op_type in direct_int8:
+                fp32_tensor_name = \
+                    parent.input[0].replace('_quantized', '')\
+                    .replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+            elif node.op_type in ['Gather']: # pragma: no cover
+                fp32_tensor_name = \
+                    node.output[0].replace('_quantized', '')\
+                    .replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+            else:
+                fp32_tensor_name = \
+                    tensor_name.replace('_quantized', '')\
+                    .replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+            scale = fp32_tensor_name + '_scale'
+            scale_tensor = self.get_initializer(scale)
+            zo = fp32_tensor_name + '_zero_point'
+            zo_tensor = self.get_initializer(zo)
+
+            if scale_tensor is None or zo_tensor is None:
+                if parent is not None:
+                    scale_tensor, zo_tensor = _searcher(parent.input[0])
+            return scale_tensor, zo_tensor
+
         node = self._input_name_to_nodes[tensor][0]
-        parent = self._output_name_to_node[tensor] if tensor in self._output_name_to_node else None
-        direct_int8 = ['Reshape', 'Transpose', 'Squeeze', 'Unsqueeze', 'MaxPool', 'Pad']
-        if parent is not None and parent.op_type in direct_int8:
-            fp32_tensor_name = \
-                parent.input[0].replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
-        elif node.op_type in ['Gather']:
-            fp32_tensor_name = \
-                node.output[0].replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
-        else:
-            fp32_tensor_name = \
-                tensor.replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
-        scale = fp32_tensor_name + '_scale'
-        scale_tensor = self.get_initializer(scale)
-        zo = fp32_tensor_name + '_zero_point'
-        zo_tensor = self.get_initializer(zo)
-
         #TODO check if scale_tensor and zero_point is needed
         # for bias of qlinearconv, scale and zero_point is not needed
         if (node.op_type == 'QLinearConv' and tensor == node.input[-1]) or \
            (node.op_type == 'QGemm' and tensor == node.input[-3]):
-            pass
+            return None, None
         else:
+            scale_tensor, zo_tensor = _searcher(tensor)
             assert scale_tensor, 'missing scale for tensor {}'.format(tensor)
             assert zo_tensor, 'missing zero point for tensor {}'.format(tensor)
-        return scale_tensor, zo_tensor
+            return scale_tensor, zo_tensor
 
     def save_model_to_file(self, output_path, use_external_data_format=False):
         """Save model to external data, which is needed for model size > 2GB."""
@@ -405,8 +418,8 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name,
             if node.op_type not in black_optype:
                 ONNXModel.replace_node_output(node, old_output_name, new_output_name)
 
-    def remove_unused_constant(self):
-        """Remove unused constant."""
+    def remove_unused_nodes(self):
+        """Remove unused nodes."""
         unused_nodes = []
         nodes = self.nodes()
         for node in nodes:
@@ -419,6 +432,23 @@
                 self.get_children(node)[0].output[0] not in self._input_name_to_nodes:
                 unused_nodes.append(node)
                 unused_nodes.extend(self.get_children(node))
+            else:
+                # remove the node if it does not serve as the input or output of any other nodes
+                unused = True
+                for output in node.output:
+                    if output in self._input_name_to_nodes or \
+                        output in self.output():
+                        unused = False
+                        break
+                for input in node.input:
+                    if self.get_initializer(input) is not None:
+                        continue
+                    elif input in self._output_name_to_node or \
+                        input in self.input():
+                        unused = False
+                        break
+                if unused:
+                    unused_nodes.append(node)
         self.remove_nodes(unused_nodes)
 
         ununsed_weights = []
@@ -615,7 +645,7 @@ def export(self, save_path, conf):
             self.remove_nodes(remove_nodes)
             self.add_initializers(inits)
             self.update()
-            self.remove_unused_constant()
+            self.remove_unused_nodes()
             self.topological_sort()
             self.save(save_path)
         else:
diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py
index ee1337a313e..4759773b08d 100644
--- a/test/model/test_onnx_model.py
+++ b/test/model/test_onnx_model.py
@@ -5,8 +5,9 @@
 import unittest
 import numpy as np
 
-sys.path.append('..')
 from neural_compressor.model.onnx_model import ONNXModel
+from neural_compressor.data import Datasets, DATALOADERS
+from neural_compressor import quantization, PostTrainingQuantConfig
 
 def get_onnx_model():
     model = torchvision.models.resnet18()
@@ -109,6 +110,48 @@ def setUp(self):
         model = helper.make_model(graph)
         self.q_model = ONNXModel(model)
 
+        # MatMul
+        #    |
+        #   Add
+        #    |
+        # Reshape
+        #    |
+        # Reshape
+        #    |
+        # MatMul
+        #    |
+        #   Add
+
+        input = onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, [2, 4])
+
+        W1 = onnx.helper.make_tensor_value_info('W1', onnx.TensorProto.FLOAT, [4, 5])
+        w1 = generate_input_initializer([4, 5], np.float32, 'W1')
+        B1 = onnx.helper.make_tensor_value_info('b1', onnx.TensorProto.FLOAT, [5])
+        b1 = generate_input_initializer([5], np.float32, 'b1')
+        shape = numpy_helper.from_array(np.array((2, 5)).astype(np.int64), name='shape')
+        W2 = onnx.helper.make_tensor_value_info('W2', onnx.TensorProto.FLOAT, [5, 6])
+        w2 = generate_input_initializer([5, 6], np.float32, 'W2')
+        B2 = onnx.helper.make_tensor_value_info('b2', onnx.TensorProto.FLOAT, [6])
+        b2 = generate_input_initializer([6], np.float32, 'b2')
+        output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [2, 6])
+
+        node1 = onnx.helper.make_node('MatMul', inputs=['input', 'W1'], outputs=['y1'])
+        node2 = onnx.helper.make_node('Add', inputs=['y1', 'b1'], outputs=['y1_add_b1'])
+        node3 = onnx.helper.make_node('Reshape', inputs=['y1_add_b1', 'shape'], outputs=['y2'])
+        node4 = onnx.helper.make_node('Reshape', inputs=['y2', 'shape'], outputs=['y3'])
+        node5 = onnx.helper.make_node('MatMul', inputs=['y3', 'W2'], outputs=['y4'])
+        node6 = onnx.helper.make_node('Add', inputs=['y4', 'b2'], outputs=['output'])
+
+        graph = onnx.helper.make_graph([node1, node2, node3, node4, node5, node6], 'test_matmul_reshape_graph', [input, W1, B1, W2, B2], [output])
+        graph.initializer.add().CopyFrom(w1)
+        graph.initializer.add().CopyFrom(b1)
+        graph.initializer.add().CopyFrom(w2)
+        graph.initializer.add().CopyFrom(b2)
+        graph.initializer.add().CopyFrom(shape)
+
+        model = onnx.helper.make_model(graph, **{'opset_imports': [onnx.helper.make_opsetid('', 14)]})
+        self.matmul_reshape_model = model
+
     def test_nodes(self):
         self.assertEqual(len(self.model.nodes()), 6)
         nodes_name = [node.name for node in self.model.nodes()]
@@ -254,9 +297,29 @@ def test_find_nodes_by_initializer(self):
         self.assertEqual(nodes[0].name, "Conv1")
 
     def test_get_scale_zero(self):
-        input_scale, input_zero = self.q_model.get_scale_zero('B_quantized')
-        weight_scale, weight_zero = self.q_model.get_scale_zero('C_quantized')
-        bias_scale, bias_zero = self.q_model.get_scale_zero('E')
+        import time
+        result = [0.1]
+        def sub_eval(model, result):
+            time.sleep(0.001 * len(result))
+            return result[0]
+
+        def eval(model):
+            return sub_eval(model, result)
+
+        dataset = Datasets("onnxrt_qdq")["dummy"]((4, 4), low=0., high=0., dtype='float32')
+        dataloader = DATALOADERS["onnxrt_qdq"](dataset, 2)
+        config = PostTrainingQuantConfig()
+        q_model = quantization.fit(self.matmul_reshape_model, config,
+                                   calib_dataloader=dataloader, eval_func=eval)
+        q_model.save('test.onnx')
+        scale, zp = q_model.get_scale_zero('y3_QuantizeInput_quantized')
+        self.assertEqual(scale.name, 'y1_add_b1_scale')
+        self.assertEqual(zp.name, 'y1_add_b1_zero_point')
+
+        scale, zp = q_model.get_scale_zero('input_quantized')
+        self.assertEqual(scale.name, 'input_scale')
+        self.assertEqual(zp.name, 'input_zero_point')
+
 
     def test_save(self):
         self.model.save_model_to_file('./test_model_6.onnx', use_external_data_format=True)
@@ -268,5 +331,15 @@ def test_find_by_name(self):
         initializer = find_by_name('X1', self.model.initializer())
         self.assertIsNone(initializer)
 
+    def test_remove_unused_nodes(self):
+        self.assertEqual(len(self.model.nodes()), 6)
+        node_to_add = onnx.helper.make_node('Relu', ['output1'], ['output2'], keepdims=0, name='added_relu')
+        self.model.add_node(node_to_add)
+        self.assertEqual(len(self.model.nodes()), 7)
+        self.model.remove_unused_nodes()
+        self.assertEqual(len(self.model.nodes()), 6)
+
+
+
 if __name__ == "__main__":
     unittest.main()
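
The following is a minimal, self-contained sketch (not part of the patch) of the dangling-node check that the new `else` branch in `ONNXModel.remove_unused_nodes` performs; the helper names `build_name_maps` and `is_dangling` are illustrative assumptions, not repo APIs. A node is treated as unused only when none of its outputs feed another node or a graph output, and none of its non-initializer inputs are produced by another node or supplied by a graph input.

```python
# Illustrative sketch only: mirrors the generalized "unused node" predicate added in
# ONNXModel.remove_unused_nodes(); helper names here are hypothetical.

def build_name_maps(graph):
    """Map tensor names to their consumer nodes and to their producer node."""
    input_name_to_nodes, output_name_to_node = {}, {}
    for node in graph.node:
        for inp in node.input:
            input_name_to_nodes.setdefault(inp, []).append(node)
        for out in node.output:
            output_name_to_node[out] = node
    return input_name_to_nodes, output_name_to_node

def is_dangling(node, graph, input_name_to_nodes, output_name_to_node):
    """Return True if `node` is connected to nothing else in the graph."""
    graph_inputs = {i.name for i in graph.input}
    graph_outputs = {o.name for o in graph.output}
    initializers = {init.name for init in graph.initializer}

    # Keep the node if any output is consumed by another node or exposed as a graph output.
    for out in node.output:
        if out in input_name_to_nodes or out in graph_outputs:
            return False

    # Keep the node if any input is produced by another node or fed by a graph input;
    # initializer-only inputs do not count as a connection.
    for inp in node.input:
        if inp in initializers:
            continue
        if inp in output_name_to_node or inp in graph_inputs:
            return False

    return True
```

Initializer-only inputs deliberately do not count as a connection, so a node fed exclusively by constants and consumed by nothing is still treated as dead; this generalizes the old Constant-specific cleanup, and it is why the added `test_remove_unused_nodes` case expects the stray Relu to be dropped while the original six nodes survive.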