55import unittest
66import numpy as np
77
8- sys .path .append ('..' )
98from neural_compressor .model .onnx_model import ONNXModel
9+ from neural_compressor .data import Datasets , DATALOADERS
10+ from neural_compressor import quantization , PostTrainingQuantConfig
1011
1112def get_onnx_model ():
1213 model = torchvision .models .resnet18 ()
@@ -109,6 +110,48 @@ def setUp(self):
109110 model = helper .make_model (graph )
110111 self .q_model = ONNXModel (model )
111112
113+ # MatMul
114+ # |
115+ # Add
116+ # |
117+ # Reshape
118+ # |
119+ # Reshape
120+ # |
121+ # MatMul
122+ # |
123+ # Add
124+
125+ input = onnx .helper .make_tensor_value_info ('input' , onnx .TensorProto .FLOAT , [2 , 4 ])
126+
127+ W1 = onnx .helper .make_tensor_value_info ('W1' , onnx .TensorProto .FLOAT , [4 , 5 ])
128+ w1 = generate_input_initializer ([4 , 5 ], np .float32 , 'W1' )
129+ B1 = onnx .helper .make_tensor_value_info ('b1' , onnx .TensorProto .FLOAT , [5 ])
130+ b1 = generate_input_initializer ([5 ], np .float32 , 'b1' )
131+ shape = numpy_helper .from_array (np .array ((2 , 5 )).astype (np .int64 ), name = 'shape' )
132+ W2 = onnx .helper .make_tensor_value_info ('W2' , onnx .TensorProto .FLOAT , [5 , 6 ])
133+ w2 = generate_input_initializer ([5 , 6 ], np .float32 , 'W2' )
134+ B2 = onnx .helper .make_tensor_value_info ('b2' , onnx .TensorProto .FLOAT , [6 ])
135+ b2 = generate_input_initializer ([6 ], np .float32 , 'b2' )
136+ output = onnx .helper .make_tensor_value_info ('output' , onnx .TensorProto .FLOAT , [2 , 6 ])
137+
138+ node1 = onnx .helper .make_node ('MatMul' , inputs = ['input' , 'W1' ], outputs = ['y1' ])
139+ node2 = onnx .helper .make_node ('Add' , inputs = ['y1' , 'b1' ], outputs = ['y1_add_b1' ])
140+ node3 = onnx .helper .make_node ('Reshape' , inputs = ['y1_add_b1' , 'shape' ], outputs = ['y2' ])
141+ node4 = onnx .helper .make_node ('Reshape' , inputs = ['y2' , 'shape' ], outputs = ['y3' ])
142+ node5 = onnx .helper .make_node ('MatMul' , inputs = ['y3' , 'W2' ], outputs = ['y4' ])
143+ node6 = onnx .helper .make_node ('Add' , inputs = ['y4' , 'b2' ], outputs = ['output' ])
144+
145+ graph = onnx .helper .make_graph ([node1 , node2 , node3 , node4 , node5 , node6 ], 'test_matmul_reshape_graph' , [input , W1 , B1 , W2 , B2 ], [output ])
146+ graph .initializer .add ().CopyFrom (w1 )
147+ graph .initializer .add ().CopyFrom (b1 )
148+ graph .initializer .add ().CopyFrom (w2 )
149+ graph .initializer .add ().CopyFrom (b2 )
150+ graph .initializer .add ().CopyFrom (shape )
151+
152+ model = onnx .helper .make_model (graph , ** {'opset_imports' : [onnx .helper .make_opsetid ('' , 14 )]})
153+ self .matmul_reshape_model = model
154+
112155 def test_nodes (self ):
113156 self .assertEqual (len (self .model .nodes ()), 6 )
114157 nodes_name = [node .name for node in self .model .nodes ()]
@@ -254,9 +297,29 @@ def test_find_nodes_by_initializer(self):
254297 self .assertEqual (nodes [0 ].name , "Conv1" )
255298
256299 def test_get_scale_zero (self ):
257- input_scale , input_zero = self .q_model .get_scale_zero ('B_quantized' )
258- weight_scale , weight_zero = self .q_model .get_scale_zero ('C_quantized' )
259- bias_scale , bias_zero = self .q_model .get_scale_zero ('E' )
300+ import time
301+ result = [0.1 ]
302+ def sub_eval (model , result ):
303+ time .sleep (0.001 * len (result ))
304+ return result [0 ]
305+
306+ def eval (model ):
307+ return sub_eval (model , result )
308+
309+ dataset = Datasets ("onnxrt_qdq" )["dummy" ]((4 , 4 ), low = 0. , high = 0. , dtype = 'float32' )
310+ dataloader = DATALOADERS ["onnxrt_qdq" ](dataset , 2 )
311+ config = PostTrainingQuantConfig ()
312+ q_model = quantization .fit (self .matmul_reshape_model , config ,
313+ calib_dataloader = dataloader , eval_func = eval )
314+ q_model .save ('test.onnx' )
315+ scale , zp = q_model .get_scale_zero ('y3_QuantizeInput_quantized' )
316+ self .assertEqual (scale .name , 'y1_add_b1_scale' )
317+ self .assertEqual (zp .name , 'y1_add_b1_zero_point' )
318+
319+ scale , zp = q_model .get_scale_zero ('input_quantized' )
320+ self .assertEqual (scale .name , 'input_scale' )
321+ self .assertEqual (zp .name , 'input_zero_point' )
322+
260323
261324 def test_save (self ):
262325 self .model .save_model_to_file ('./test_model_6.onnx' , use_external_data_format = True )
@@ -268,5 +331,15 @@ def test_find_by_name(self):
268331 initializer = find_by_name ('X1' , self .model .initializer ())
269332 self .assertIsNone (initializer )
270333
334+ def test_remove_unused_nodes (self ):
335+ self .assertEqual (len (self .model .nodes ()), 6 )
336+ node_to_add = onnx .helper .make_node ('Relu' , ['output1' ], ['output2' ], keepdims = 0 , name = 'added_relu' )
337+ self .model .add_node (node_to_add )
338+ self .assertEqual (len (self .model .nodes ()), 7 )
339+ self .model .remove_unused_nodes ()
340+ self .assertEqual (len (self .model .nodes ()), 6 )
341+
342+
343+
271344if __name__ == "__main__" :
272345 unittest .main ()
0 commit comments