neural_compressor/adaptor/ox_utils/calibration.py (15 changes: 11 additions & 4 deletions)
@@ -139,17 +139,22 @@ def augment_graph(self, activation_only=False, weight_only=False):
                               (node.name in self.white_nodes)
            if should_be_dump:
                if not weight_only and not activation_only:
-                   tensors_to_dump.update(node.input)
-                   tensors_to_dump.update(node.output)
+                   tensors_to_dump.update([input for input in node.input if len(input) != 0])
+                   tensors_to_dump.update([output for output in node.output if len(output) != 0])
                elif weight_only:
                    for input in node.input:
                        if self.already_quantized and \
-                           input.replace('_dequantized', '_quantized') in initializers:
+                           input.replace('_dequantized', '_quantized') in initializers and \
+                           len(input) != 0:
                            tensors_to_dump.add(input)
-                       elif not self.already_quantized and input in initializers:
+                       elif not self.already_quantized and \
+                           input in initializers and \
+                           len(input) != 0:
                            tensors_to_dump.add(input)
                elif activation_only:
-                   tensors_to_dump.update([node.input[0]])
+                   if len(node.input[0]) != 0:
+                       tensors_to_dump.update([node.input[0]])

        model_inputs = [i.name for i in model.graph.input]
        for tensor in tensors_to_dump:
@@ -525,6 +530,8 @@ def dump_tensor(self, activation=True, weight=False, format=None):
                for i in range(iters):
                    if node.op_type in ['Attention', 'QAttention'] and tensor_name not in node.input[:2]:
                        continue
+                   if node.op_type in ['MatMul', 'QLinearMatMul'] and tensor_name != node.input[0]:
+                       continue
                    if is_qdq:
                        map_node_activation[i][node_name] = \
                            {tensor_name.replace('_dequantized', '').replace('_' + node_name, ''): tensors[i]}
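
Why the new len(...) != 0 guards: ONNX encodes an omitted optional input or output as an empty-string entry in node.input / node.output, so without the filter augment_graph could register '' as a tensor name to dump. A minimal sketch of the situation, using the standard onnx helper (the Clip node is illustrative, not part of the patch):

from onnx import helper

# Clip's 'min' input is optional; omitting it leaves an empty-string slot
# in node.input rather than shortening the list.
node = helper.make_node('Clip', inputs=['x', '', 'max_val'], outputs=['y'])

tensors_to_dump = set()
# Same filtering idea as the patched augment_graph: skip empty names.
tensors_to_dump.update([inp for inp in node.input if len(inp) != 0])
print(sorted(tensors_to_dump))  # ['max_val', 'x']
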
neural_compressor/adaptor/ox_utils/quantizer.py (4 changes: 2 additions & 2 deletions)
@@ -149,7 +149,7 @@ def quantize_model(self):
"""Quantize onnx model."""
# step 1: insert q-dq, cast-cast pairs
self.insert_qdq()

# step 2: remove redundant pairs -> qdq model
self.remove_redundant_pairs()

@@ -158,7 +158,7 @@ def quantize_model(self):

        self.merge_dedicated_qdq_pair()

-       self.model.remove_unused_constant()
+       self.model.remove_unused_nodes()

        self.model.model.producer_name = __producer__
        self.model.model.producer_version = __version__
neural_compressor/adaptor/ox_utils/smooth_quant.py (2 changes: 1 addition & 1 deletion)
@@ -171,7 +171,7 @@ def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm'
        if folding:
            self._fold_scale(scales)
        self.model.topological_sort()
-       self.model.remove_unused_constant()
+       self.model.remove_unused_nodes()
        return self.model

    def recover(self):
neural_compressor/model/onnx_model.py (72 changes: 51 additions & 21 deletions)
@@ -329,32 +329,45 @@ def get_scale_zero(self, tensor):
        if not tensor.endswith('_quantized'):
            logger.debug("Find {} in the quantized graph is not quantized.".format(tensor))
            return None, None

+       def _searcher(tensor_name):
+           """Search scale and zero point tensor recursively."""
+           node = self._input_name_to_nodes[tensor_name][0]
+           parent = self._output_name_to_node[tensor_name] if tensor_name in self._output_name_to_node else None
+           direct_int8 = ['Reshape', 'Transpose', 'Squeeze', 'Unsqueeze', 'MaxPool', 'Pad', 'Split']
+           if parent is not None and parent.op_type in direct_int8:
+               fp32_tensor_name = \
+                   parent.input[0].replace('_quantized', '')\
+                   .replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+           elif node.op_type in ['Gather']: # pragma: no cover
+               fp32_tensor_name = \
+                   node.output[0].replace('_quantized', '')\
+                   .replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+           else:
+               fp32_tensor_name = \
+                   tensor_name.replace('_quantized', '')\
+                   .replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+           scale = fp32_tensor_name + '_scale'
+           scale_tensor = self.get_initializer(scale)
+           zo = fp32_tensor_name + '_zero_point'
+           zo_tensor = self.get_initializer(zo)
+
+           if scale_tensor is None or zo_tensor is None:
+               if parent is not None:
+                   scale_tensor, zo_tensor = _searcher(parent.input[0])
+           return scale_tensor, zo_tensor
+
        node = self._input_name_to_nodes[tensor][0]
-       parent = self._output_name_to_node[tensor] if tensor in self._output_name_to_node else None
-       direct_int8 = ['Reshape', 'Transpose', 'Squeeze', 'Unsqueeze', 'MaxPool', 'Pad']
-       if parent is not None and parent.op_type in direct_int8:
-           fp32_tensor_name = \
-               parent.input[0].replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
-       elif node.op_type in ['Gather']:
-           fp32_tensor_name = \
-               node.output[0].replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
-       else:
-           fp32_tensor_name = \
-               tensor.replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
-       scale = fp32_tensor_name + '_scale'
-       scale_tensor = self.get_initializer(scale)
-       zo = fp32_tensor_name + '_zero_point'
-       zo_tensor = self.get_initializer(zo)
-
        #TODO check if scale_tensor and zero_point is needed
        # for bias of qlinearconv, scale and zero_point is not needed
        if (node.op_type == 'QLinearConv' and tensor == node.input[-1]) or \
            (node.op_type == 'QGemm' and tensor == node.input[-3]):
-           pass
+           return None, None
        else:
+           scale_tensor, zo_tensor = _searcher(tensor)
            assert scale_tensor, 'missing scale for tensor {}'.format(tensor)
            assert zo_tensor, 'missing zero point for tensor {}'.format(tensor)
+           return scale_tensor, zo_tensor
-       return scale_tensor, zo_tensor

    def save_model_to_file(self, output_path, use_external_data_format=False):
        """Save model to external data, which is needed for model size > 2GB."""
@@ -405,8 +418,8 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name,
            if node.op_type not in black_optype:
                ONNXModel.replace_node_output(node, old_output_name, new_output_name)

-   def remove_unused_constant(self):
-       """Remove unused constant."""
+   def remove_unused_nodes(self):
+       """Remove unused nodes."""
        unused_nodes = []
        nodes = self.nodes()
        for node in nodes:
@@ -419,6 +432,23 @@
                    self.get_children(node)[0].output[0] not in self._input_name_to_nodes:
                unused_nodes.append(node)
                unused_nodes.extend(self.get_children(node))
+           else:
+               # remove the node if it does not serve as the input or output of any other nodes
+               unused = True
+               for output in node.output:
+                   if output in self._input_name_to_nodes or \
+                       output in self.output():
+                       unused = False
+                       break
+               for input in node.input:
+                   if self.get_initializer(input) is not None:
+                       continue
+                   elif input in self._output_name_to_node or \
+                       input in self.input():
+                       unused = False
+                       break
+               if unused:
+                   unused_nodes.append(node)
        self.remove_nodes(unused_nodes)

        ununsed_weights = []
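
The new else branch generalizes the old constant-only cleanup: a node is now dropped when it is disconnected on both sides of the dataflow. Restated as a standalone predicate (hypothetical signature, for illustration only):

def is_disconnected(node, input_name_to_nodes, output_name_to_node,
                    graph_inputs, graph_outputs, initializer_names):
    # Connected downstream: some output feeds another node or a graph output.
    feeds_something = any(out in input_name_to_nodes or out in graph_outputs
                          for out in node.output)
    # Connected upstream: some non-initializer input comes from another node
    # or a graph input (initializers alone don't keep a node alive).
    fed_by_something = any(inp not in initializer_names and
                           (inp in output_name_to_node or inp in graph_inputs)
                           for inp in node.input)
    return not feeds_something and not fed_by_something

This is why the test added below can attach a Relu wired to names that exist nowhere in the graph ('output1'/'output2') and see it swept away, while every connected node survives.
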
@@ -615,7 +645,7 @@ def export(self, save_path, conf):
            self.remove_nodes(remove_nodes)
            self.add_initializers(inits)
            self.update()
-           self.remove_unused_constant()
+           self.remove_unused_nodes()
            self.topological_sort()
            self.save(save_path)
        else:
test/model/test_onnx_model.py (81 changes: 77 additions & 4 deletions)
@@ -5,8 +5,9 @@
import unittest
import numpy as np

-sys.path.append('..')
from neural_compressor.model.onnx_model import ONNXModel
+from neural_compressor.data import Datasets, DATALOADERS
+from neural_compressor import quantization, PostTrainingQuantConfig

def get_onnx_model():
    model = torchvision.models.resnet18()
@@ -109,6 +110,48 @@ def setUp(self):
        model = helper.make_model(graph)
        self.q_model = ONNXModel(model)

+       #  MatMul
+       #    |
+       #   Add
+       #    |
+       # Reshape
+       #    |
+       # Reshape
+       #    |
+       #  MatMul
+       #    |
+       #   Add
+
+       input = onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, [2, 4])
+
+       W1 = onnx.helper.make_tensor_value_info('W1', onnx.TensorProto.FLOAT, [4, 5])
+       w1 = generate_input_initializer([4, 5], np.float32, 'W1')
+       B1 = onnx.helper.make_tensor_value_info('b1', onnx.TensorProto.FLOAT, [5])
+       b1 = generate_input_initializer([5], np.float32, 'b1')
+       shape = numpy_helper.from_array(np.array((2, 5)).astype(np.int64), name='shape')
+       W2 = onnx.helper.make_tensor_value_info('W2', onnx.TensorProto.FLOAT, [5, 6])
+       w2 = generate_input_initializer([5, 6], np.float32, 'W2')
+       B2 = onnx.helper.make_tensor_value_info('b2', onnx.TensorProto.FLOAT, [6])
+       b2 = generate_input_initializer([6], np.float32, 'b2')
+       output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [2, 6])
+
+       node1 = onnx.helper.make_node('MatMul', inputs=['input', 'W1'], outputs=['y1'])
+       node2 = onnx.helper.make_node('Add', inputs=['y1', 'b1'], outputs=['y1_add_b1'])
+       node3 = onnx.helper.make_node('Reshape', inputs=['y1_add_b1', 'shape'], outputs=['y2'])
+       node4 = onnx.helper.make_node('Reshape', inputs=['y2', 'shape'], outputs=['y3'])
+       node5 = onnx.helper.make_node('MatMul', inputs=['y3', 'W2'], outputs=['y4'])
+       node6 = onnx.helper.make_node('Add', inputs=['y4', 'b2'], outputs=['output'])
+
+       graph = onnx.helper.make_graph([node1, node2, node3, node4, node5, node6], 'test_matmul_reshape_graph', [input, W1, B1, W2, B2], [output])
+       graph.initializer.add().CopyFrom(w1)
+       graph.initializer.add().CopyFrom(b1)
+       graph.initializer.add().CopyFrom(w2)
+       graph.initializer.add().CopyFrom(b2)
+       graph.initializer.add().CopyFrom(shape)
+
+       model = onnx.helper.make_model(graph, **{'opset_imports': [onnx.helper.make_opsetid('', 14)]})
+       self.matmul_reshape_model = model

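
For reference, the two back-to-back Reshapes in this fixture are shape no-ops; they force get_scale_zero to walk through chained pass-through ops before reaching y1_add_b1's quantization parameters. A quick numpy stand-in for the dataflow (a sketch only, not part of the test):

import numpy as np

x = np.zeros((2, 4), np.float32)
y1_add_b1 = x @ np.zeros((4, 5), np.float32) + np.zeros(5, np.float32)  # MatMul + Add -> (2, 5)
y3 = y1_add_b1.reshape(2, 5).reshape(2, 5)                              # two no-op Reshapes
output = y3 @ np.zeros((5, 6), np.float32) + np.zeros(6, np.float32)    # MatMul + Add -> (2, 6)
assert output.shape == (2, 6)
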
    def test_nodes(self):
        self.assertEqual(len(self.model.nodes()), 6)
        nodes_name = [node.name for node in self.model.nodes()]
@@ -254,9 +297,29 @@ def test_find_nodes_by_initializer(self):
        self.assertEqual(nodes[0].name, "Conv1")

    def test_get_scale_zero(self):
-       input_scale, input_zero = self.q_model.get_scale_zero('B_quantized')
-       weight_scale, weight_zero = self.q_model.get_scale_zero('C_quantized')
-       bias_scale, bias_zero = self.q_model.get_scale_zero('E')
+       import time
+       result = [0.1]
+       def sub_eval(model, result):
+           time.sleep(0.001 * len(result))
+           return result[0]
+
+       def eval(model):
+           return sub_eval(model, result)
+
+       dataset = Datasets("onnxrt_qdq")["dummy"]((4, 4), low=0., high=0., dtype='float32')
+       dataloader = DATALOADERS["onnxrt_qdq"](dataset, 2)
+       config = PostTrainingQuantConfig()
+       q_model = quantization.fit(self.matmul_reshape_model, config,
+                                  calib_dataloader=dataloader, eval_func=eval)
+       q_model.save('test.onnx')
+       scale, zp = q_model.get_scale_zero('y3_QuantizeInput_quantized')
+       self.assertEqual(scale.name, 'y1_add_b1_scale')
+       self.assertEqual(zp.name, 'y1_add_b1_zero_point')
+
+       scale, zp = q_model.get_scale_zero('input_quantized')
+       self.assertEqual(scale.name, 'input_scale')
+       self.assertEqual(zp.name, 'input_zero_point')

    def test_save(self):
        self.model.save_model_to_file('./test_model_6.onnx', use_external_data_format=True)
@@ -268,5 +331,15 @@ def test_find_by_name(self):
        initializer = find_by_name('X1', self.model.initializer())
        self.assertIsNone(initializer)

+   def test_remove_unused_nodes(self):
+       self.assertEqual(len(self.model.nodes()), 6)
+       node_to_add = onnx.helper.make_node('Relu', ['output1'], ['output2'], keepdims=0, name='added_relu')
+       self.model.add_node(node_to_add)
+       self.assertEqual(len(self.model.nodes()), 7)
+       self.model.remove_unused_nodes()
+       self.assertEqual(len(self.model.nodes()), 6)

if __name__ == "__main__":
    unittest.main()