@@ -157,70 +157,85 @@ def set_data_type(
     return activation_type, weight_type
 
 
-def get_quantizable_onnx_ops(
-    q_config,
+def get_node_mapping(
+    fp32_model,
     fp32_onnx_path,
 ):
164- """Get quantizable onnx ops .
164+ """Get PyTorch module and ONNX node mapping .
165165
166166 Args:
167- q_config (dict ): quantization configuration from PyTorch.
167+ fp32_model (torch.nn.Module ): quantization configuration from PyTorch.
168168 fp32_onnx_path (str): path to fp32 onnx model.
169169
170170 Returns:
171- quantize_nodes: all onnx node that should be quantized.
172171 module_node_mapping: op mapping from PyTorch to ONNX.
173172 linear_matmul_list: contains matmul that comes from linear.
174173 """
+    def check_data(op_type, data, module_dict):
+        for name, value in module_dict.items():
+            if value.shape == data.shape:
+                if (value == data).all():
+                    return name
+                # Convolution weight data mismatch.
+                elif op_type == 'Conv' and np.allclose(value, data):
+                    return name
+        return None
+
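+    # Collect the fp32 weights of every Conv/Embedding/Linear module for matching below.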
+    module_dict = {}
+    for name, module in fp32_model.named_modules():
+        if 'Conv' in str(module.__class__.__name__) or \
+           'Embedding' in str(module.__class__.__name__) or \
+           'Linear' in str(module.__class__.__name__):
+            if hasattr(module, 'weight'):
+                value = module.weight.detach().cpu().numpy()
+                module_dict[name] = value
+
+    module_node_mapping = {}
+    linear_matmul_list = []
     fp32_onnx_model = onnx.load(fp32_onnx_path)
-    # Clarify ONNX nodes that we can mapping from PyTorch
-    if 'dynamic' in q_config['approach']:
-        op_types_to_quantize = ['MatMul', 'Gather', "LSTM"]
-    else:
-        op_types_to_quantize = ['MatMul', 'Gather', 'Conv']
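+    # Index the ONNX initializers by name so each node's weight tensor can be looked up directly.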
+    initializer_data = {tensor.name: tensor for tensor in fp32_onnx_model.graph.initializer}
+    from onnx import numpy_helper
+    for node in fp32_onnx_model.graph.node:
+        if node.op_type in op_types_to_quantize:
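+            # nn.Linear weight is (out_features, in_features); the exported MatMul initializer
+            # typically stores its transpose, so transpose back before comparing.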
+            if node.op_type == 'MatMul' and node.input[1] in initializer_data:
+                data = numpy_helper.to_array(initializer_data[node.input[1]]).T
+            elif node.op_type == 'Gather' and node.input[0] in initializer_data:
+                data = numpy_helper.to_array(initializer_data[node.input[0]])
+            elif node.op_type in ['Conv', 'Gemm']:
+                data = numpy_helper.to_array(initializer_data[node.input[1]])
+            else:
+                continue
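+            # Record which PyTorch module this node's weight belongs to; MatMul nodes matched
+            # here come from Linear modules and are also collected in linear_matmul_list.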
+            pt_name = check_data(node.op_type, data, module_dict)
+            if pt_name:
+                module_node_mapping[pt_name] = node.name
+                if node.op_type == 'MatMul':
+                    linear_matmul_list.append(node.name)
+    return module_node_mapping, linear_matmul_list
 
-    from neural_compressor.adaptor.onnxrt import ONNXRTAdaptor
-    # pylint: disable=E1120
-    fp32_onnx_model = ONNXRTAdaptor._replace_gemm_with_matmul(fp32_onnx_model).model
-    onnx.save(fp32_onnx_model, fp32_onnx_path)
 
-    # Get weight name from onnx initializer
-    weight_name_list = []
-    for tensor in fp32_onnx_model.graph.initializer:
-        weight_name_list.append(tensor.name)
+def get_quantizable_onnx_ops(
+    int8_model,
+    module_node_mapping
+):
+    """Get quantizable onnx ops.
+
+    Args:
+        int8_model (torch.nn.Module): PyTorch int8 model.
+        module_node_mapping (dict): op mapping from PyTorch to ONNX.
 
-    # Match weight name with onnx node name
+    Returns:
+        quantize_nodes: all ONNX nodes that should be quantized.
+    """
     quantize_nodes = []
-    tmp_node_mapping = {}
-    module_node_mapping = {}
-    for node in fp32_onnx_model.graph.node:
-        if node.op_type not in op_types_to_quantize:
-            for inp in node.input:
-                if inp in weight_name_list and 'weight' in inp:
-                    tmp_node_mapping.update({node.output[0]: inp.split('.weight')[0]})
-                elif inp in tmp_node_mapping:
-                    tmp_node_mapping.update({node.output[0]: tmp_node_mapping[inp]})
-        else:
-            for inp in node.input:
-                if inp in weight_name_list and 'weight' in inp:
-                    module_node_mapping.update({inp.split('.weight')[0]: node.name})
-                elif inp in tmp_node_mapping:
-                    module_node_mapping.update({tmp_node_mapping[inp]: node.name})
-
-    quantize_nodes = list(module_node_mapping.values())
-    # Fetch all matmul in ONNX that comes from PyTorch Linear
-    # Match pytorch module name with onnx node name for fallbacked fp32 module
-    linear_matmul_list = []
-    for k, v in q_config['op'].items():  # pragma: no cover
-        if 'Linear' in k[1]:
-            k_0 = k[0].split('.module')[0] if k[0] not in module_node_mapping else k[0]
-            linear_matmul_list.append(module_node_mapping[k_0])
-        if not 'int8' in v['weight']['dtype']:
-            k_0 = k[0].split('.module')[0] if k[0] not in module_node_mapping else k[0]
-            if k[0] in module_node_mapping:
-                fallback_op = module_node_mapping[k_0]
-                quantize_nodes.remove(fallback_op)
-    return quantize_nodes, module_node_mapping, linear_matmul_list
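+    # Walk the int8 PyTorch model and keep only the ONNX nodes of modules that were actually quantized.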
+    for name, module in int8_model.named_modules():
+        if 'Conv' in str(module.__class__.__name__) or \
+           'Embedding' in str(module.__class__.__name__) or \
+           'Linear' in str(module.__class__.__name__):
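+            # Quantized PyTorch modules expose weight() as a method; a qint8 weight means
+            # the module was quantized rather than falling back to fp32.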
+            if hasattr(module, 'weight') and callable(module.weight):
+                if module.weight().dtype == torch.qint8:
+                    node = module_node_mapping[name.split('.module')[0]]
+                    quantize_nodes.append(node)
+    return quantize_nodes
 
 
 def get_scale_info(
@@ -257,7 +272,7 @@ def build_scale_mapping(
     module_node_mapping,
     int8_scale_info,
 ):
-    """_summary_
+    """Build scale mapping.
 
     Args:
         fp32_onnx_path (str): path to fp32 onnx model.
@@ -405,13 +420,17 @@ def torch_to_int8_onnx(
         quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'.
         dtype (str, optional): data types of activation and weight of ONNX model. Defaults to 'U8S8'.
     """
+    global op_types_to_quantize
+    if q_config['approach'] == 'post_training_dynamic_quant':
+        op_types_to_quantize = ['MatMul', 'Gemm', 'Gather']
+    else:
+        op_types_to_quantize = ['MatMul', 'Gemm', 'Gather', 'Conv']
+
     if quant_format == 'QDQ' and opset_version < 13:  # pragma: no cover
         opset_version = 13
         logger.warning("QDQ format requires opset_version >= 13, " +
                        "we reset opset_version={} here".format(opset_version))
 
-    activation_type, weight_type = set_data_type(dtype)
-
     # pylint: disable=E1101
     fp32_onnx_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
     torch_to_fp32_onnx(
@@ -422,13 +441,12 @@ def torch_to_int8_onnx(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        do_constant_folding=False,
         verbose=False,
     )
 
-    quantize_nodes, module_node_mapping, linear_matmul_list = get_quantizable_onnx_ops(
-        q_config, fp32_onnx_path
-    )
+    activation_type, weight_type = set_data_type(dtype)
+    module_node_mapping, linear_matmul_list = get_node_mapping(fp32_model, fp32_onnx_path)
+    quantize_nodes = get_quantizable_onnx_ops(int8_model, module_node_mapping)
 
     if q_config['approach'] == 'quant_aware_training':
         update_weight_bias(int8_model, fp32_onnx_path)
@@ -439,6 +457,7 @@ def torch_to_int8_onnx(
     quant_format = ortq.QuantFormat.QOperator if quant_format != 'QDQ' else ortq.QuantFormat.QDQ
 
     if q_config['approach'] == 'post_training_dynamic_quant':
+        logger.info("Quantization format is not available when executing dynamic quantization.")
         ortq.quantize_dynamic(
             fp32_onnx_path,
             save_path,