@@ -288,7 +288,277 @@ def set_scale_info(
         )
         int8_onnx_model.graph.initializer.remove(tensor)
         int8_onnx_model.graph.initializer.append(new_tensor)
-    onnx.save(int8_onnx_model, int8_onnx_path)
+    return int8_onnx_model
+
+
+def remove_nodes_by_name(int8_onnx_model, node_names):
+    while node_names:
+        for node in int8_onnx_model.graph.node:
+            if node.name in node_names:
+                int8_onnx_model.graph.node.remove(node)
+                node_names.remove(node.name)
+    return int8_onnx_model
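Since `remove_nodes_by_name` keeps sweeping the graph until every requested name has been deleted, it assumes all names in `node_names` actually occur in the graph (otherwise the `while` loop never exits). A minimal sanity-check sketch on a toy two-node model (names invented for illustration, not part of this PR):

```python
import onnx
from onnx import TensorProto, helper

# Toy graph: Identity('keep_me') -> Identity('drop_me').
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1])
n1 = helper.make_node('Identity', ['x'], ['t'], name='keep_me')
n2 = helper.make_node('Identity', ['t'], ['y'], name='drop_me')
model = helper.make_model(helper.make_graph([n1, n2], 'toy', [x], [y]))

model = remove_nodes_by_name(model, {'drop_me'})
assert [n.name for n in model.graph.node] == ['keep_me']
```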
+
+
+def generate_int32_bias_structure(
+    int8_onnx_model,
+    node,
+    a_info,
+    b_info,
+    bias_name,
+    output_name,
+):
+    from onnx import TensorProto
+    a, a_scale, a_zero_point = a_info
+    b, b_scale, b_zero_point = b_info
+    a_scale = ortq.onnx_model.ONNXModel(int8_onnx_model).get_initializer(a_scale)
+    a_scale = onnx.numpy_helper.to_array(a_scale)
+    b_scale = ortq.onnx_model.ONNXModel(int8_onnx_model).get_initializer(b_scale)
+    b_scale = onnx.numpy_helper.to_array(b_scale)
+    bias = ortq.onnx_model.ONNXModel(int8_onnx_model).get_initializer(bias_name)
+    bias_dims = bias.dims
+    bias = onnx.numpy_helper.to_array(bias)
+    bias_scale = a_scale * b_scale
+    quantized_bias = (bias / bias_scale).round().astype(np.int32)
+    quantized_bias = np.asarray(quantized_bias, dtype=np.int32).reshape(bias_dims)
+    packed_bias_initializer = onnx.numpy_helper.from_array(quantized_bias,
+                                                           bias_name + "_quantized")
+    int8_onnx_model.graph.initializer.extend([packed_bias_initializer])
+
+    matmul_node = onnx.helper.make_node("MatMulInteger",
+                                        inputs=[a, b, a_zero_point, b_zero_point],
+                                        outputs=[node.output[0] + '_matmulinteger'],
+                                        name=node.name + '_matmulinteger')
+    add_node = onnx.helper.make_node("Add",
+                                     inputs=[node.output[0] + '_matmulinteger', bias_name + '_quantized'],
+                                     outputs=[node.output[0] + '_add'],
+                                     name=node.name + '_add')
+    cast_node = onnx.helper.make_node("Cast",
+                                      inputs=[node.output[0] + '_add'],
+                                      outputs=[node.output[0] + '_cast'],
+                                      to=getattr(TensorProto, 'FLOAT'),
+                                      name=node.name + '_cast')
+
+    new_tensor = onnx.helper.make_tensor(
+        name=node.name + '_bias_scale',
+        data_type=TensorProto.FLOAT,
+        dims=list(bias_scale.shape),
+        vals=bias_scale,
+    )
+    int8_onnx_model.graph.initializer.append(new_tensor)
+
+    mul_node = onnx.helper.make_node("Mul",
+                                     inputs=[node.output[0] + '_cast', node.name + '_bias_scale'],
+                                     outputs=[output_name],
+                                     name=node.name + '_mul')
+
+    int8_onnx_model.graph.node.extend([matmul_node, add_node, cast_node, mul_node])
+    return int8_onnx_model
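The int32-bias trick above relies on the identity that a bias quantized with scale `a_scale * b_scale` can be added directly to the int32 MatMulInteger accumulator; the trailing Cast + Mul then rescales the sum back to fp32 in one step. A minimal numpy sketch of that arithmetic (all values invented for illustration):

```python
import numpy as np

a_scale, b_scale = np.float32(0.02), np.float32(0.005)
bias = np.array([0.013, -0.071], dtype=np.float32)

# Quantize the bias with the combined scale, as generate_int32_bias_structure does.
bias_scale = a_scale * b_scale
q_bias = np.round(bias / bias_scale).astype(np.int32)

# Pretend int32 accumulator produced by MatMulInteger.
acc = np.array([1200, -3400], dtype=np.int32)

# Add -> Cast -> Mul recovers the float result within one bias quantization step.
y = (acc + q_bias).astype(np.float32) * bias_scale
assert np.allclose(y, acc * bias_scale + bias, atol=bias_scale)
```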
+
+
+def qdq_model_use_int32_bias(
+    int8_onnx_model,
+    quantize_nodes,
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #     QuantizeLinear
+    #           |
+    #     MatMulInteger
+    #           |
+    #          Add
+    #           |
+    #          Cast
+    #           |
+    #          Mul
+    remove_nodes = set()
+    replace_input = {}
+    for node in int8_onnx_model.graph.node:
+        if node.name in quantize_nodes and node.op_type == 'MatMul':
+            parents = ortq.onnx_model.ONNXModel(int8_onnx_model).get_parents(node)
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            bias_name = None
+            for inp in add_node.input:
+                if inp.endswith('.bias'):
+                    bias_name = inp
+            if not bias_name:  # pragma: no cover
+                continue
+
+            for parent in parents:
+                grand_parent = ortq.onnx_model.ONNXModel(int8_onnx_model).get_parents(parent)
+                if grand_parent:
+                    replace_input[parent.output[0]] = grand_parent[0].input[0]
+
+            int8_onnx_model = generate_int32_bias_structure(int8_onnx_model,
+                                                            node,
+                                                            parents[0].input[:3],
+                                                            parents[1].input[:3],
+                                                            bias_name,
+                                                            add_node.output[0])
+            remove_nodes.add(node.name)
+            remove_nodes.add(parents[0].name)
+            remove_nodes.add(parents[1].name)
+            remove_nodes.add(add_node.name)
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    for node in int8_onnx_model.graph.node:  # pragma: no cover
+        for i in range(len(node.input)):
+            if node.input[i] in replace_input:
+                node.input[i] = replace_input[node.input[i]]
+    return int8_onnx_model
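Why this rewrite is numerically equivalent to the original QDQ pattern: MatMulInteger computes `(A - a_zp) @ (B - b_zp)` in int32, which is exactly what the two DequantizeLinear parents fed into the float MatMul up to the `a_scale * b_scale` factor that the Cast + Mul applies at the end. An end-to-end toy check in numpy (values invented):

```python
import numpy as np

a_scale, a_zp = np.float32(0.1), np.uint8(128)   # activation quant params
b_scale, b_zp = np.float32(0.05), np.int8(0)     # weight quant params
A = np.array([[130, 125]], dtype=np.uint8)       # quantized activation
B = np.array([[10], [-20]], dtype=np.int8)       # quantized weight
bias = np.array([0.25], dtype=np.float32)

# MatMulInteger: int32 accumulation with zero points subtracted.
acc = (A.astype(np.int32) - a_zp) @ (B.astype(np.int32) - b_zp)

# Add int32 bias, then Cast + Mul back to float.
bias_scale = a_scale * b_scale
y = (acc + np.round(bias / bias_scale).astype(np.int32)).astype(np.float32) * bias_scale

# Reference path: dequantize first, then float matmul + bias.
ref = ((A.astype(np.float32) - a_zp) * a_scale) @ ((B.astype(np.float32) - b_zp) * b_scale) + bias
assert np.allclose(y, ref, atol=bias_scale)
```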
+
+
+def qdq_model_use_output_scale_zp(
+    int8_onnx_model,
+    quantize_nodes,
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #       QuantizeLinear
+    #             |
+    #     DequantizeLinear   DequantizeLinear
+    #             |                  |
+    #             --------------------
+    #                      |
+    #                    MatMul
+    #                      |
+    #                     Add
+    #                      |
+    #               QuantizeLinear
+    #                      |
+    #              DequantizeLinear
+    for node in int8_onnx_model.graph.node:
+        if node.name in quantize_nodes and node.op_type == 'MatMul':
+            quantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(quantizelinear_node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+            dequantizelinear_node.output[0] = add_node.output[0]
+            add_node.output[0] = add_node.output[0] + '_add'
+            for i in range(len(add_node.input)):
+                if not add_node.input[i].endswith('.bias'):
+                    add_node.input[i] = node.output[0]
+            quantizelinear_node.input[0] = add_node.output[0]
+    return int8_onnx_model
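The rewiring swaps tensor names so that downstream consumers of the old Add output now read the DequantizeLinear output instead; the Add result is therefore snapped to the output scale/zero point exported from the PyTorch module. The numeric effect of that Q/DQ roundtrip in numpy (toy values):

```python
import numpy as np

y_scale, y_zp = np.float32(0.05), np.uint8(120)
y = np.array([0.64, -1.30], dtype=np.float32)   # raw Add output

# QuantizeLinear then DequantizeLinear: y is snapped to the uint8 grid.
q = np.clip(np.round(y / y_scale) + y_zp, 0, 255).astype(np.uint8)
y_dq = (q.astype(np.float32) - y_zp) * y_scale  # -> [0.65, -1.30]
```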
+
+
+def qop_model_default(
+    int8_onnx_model
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #      QuantizeLinear
+    #            |
+    #   MatMulIntegerToFloat
+    #            |
+    #           Add
+    remove_nodes = set()
+    for node in int8_onnx_model.graph.node:
+        if node.op_type == 'QLinearMatMul':
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+            a = node.input[0]
+            a_scale = node.input[1]
+            a_zero_point = node.input[2]
+            b = node.input[3]
+            b_scale = node.input[4]
+            b_zero_point = node.input[5]
+            matmulintegertofloat_node = onnx.helper.make_node("MatMulIntegerToFloat",
+                                                              inputs=[a, b, a_scale, b_scale, a_zero_point, b_zero_point],
+                                                              outputs=[node.output[0]],
+                                                              name=node.name + '_matmulintegertofloat',
+                                                              domain='com.microsoft')
+            for idx in range(len(add_node.input)):
+                if add_node.input[idx] == dequantizelinear_node.output[0]:
+                    add_node.input[idx] = node.output[0]
+            remove_nodes.add(node.name)
+            remove_nodes.add(dequantizelinear_node.name)
+            int8_onnx_model.graph.node.extend([matmulintegertofloat_node])
+
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    return int8_onnx_model
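MatMulIntegerToFloat is a com.microsoft contrib op that fuses integer matmul and dequantization, which is why the QLinearMatMul + DequantizeLinear pair can collapse into a single node while the fp32 bias Add stays untouched. Its semantics, emulated in numpy (toy values; the real kernel lives in onnxruntime):

```python
import numpy as np

a_scale, a_zp = np.float32(0.1), np.uint8(128)
b_scale, b_zp = np.float32(0.05), np.int8(0)
A = np.array([[130, 125]], dtype=np.uint8)
B = np.array([[10], [-20]], dtype=np.int8)

# MatMulIntegerToFloat(a, b, a_scale, b_scale, a_zp, b_zp): integer matmul
# with zero points removed, then a single fused rescale to float.
y = ((A.astype(np.int32) - a_zp) @ (B.astype(np.int32) - b_zp)).astype(np.float32) \
    * (a_scale * b_scale)
# y == [[0.4]]; the existing Add node then applies the fp32 '.bias'.
```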
+
+
+def qop_model_use_int32_bias(
+    int8_onnx_model
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #     QuantizeLinear
+    #           |
+    #     MatMulInteger
+    #           |
+    #          Add
+    #           |
+    #          Cast
+    #           |
+    #          Mul
+    remove_nodes = set()
+    for node in int8_onnx_model.graph.node:
+        if node.op_type == 'QLinearMatMul':
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+
+            bias_name = None
+            for inp in add_node.input:
+                if inp.endswith('.bias'):
+                    bias_name = inp
+            if not bias_name:  # pragma: no cover
+                continue
+
+            int8_onnx_model = generate_int32_bias_structure(int8_onnx_model,
+                                                            node,
+                                                            node.input[:3],
+                                                            node.input[3:6],
+                                                            bias_name,
+                                                            add_node.output[0])
+            remove_nodes.add(node.name)
+            remove_nodes.add(add_node.name)
+            remove_nodes.add(dequantizelinear_node.name)
+
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    return int8_onnx_model
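The `node.input[:3]` and `node.input[3:6]` slices lean on the input ordering the ONNX QLinearMatMul spec defines: (a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point). A quick self-contained check (tensor names are placeholders):

```python
import onnx

qlm = onnx.helper.make_node(
    "QLinearMatMul",
    inputs=['a', 'a_scale', 'a_zp', 'b', 'b_scale', 'b_zp', 'y_scale', 'y_zp'],
    outputs=['y'],
    name='demo_qlinearmatmul')
assert list(qlm.input[:3]) == ['a', 'a_scale', 'a_zp']    # a_info passed above
assert list(qlm.input[3:6]) == ['b', 'b_scale', 'b_zp']   # b_info passed above
```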
+
+
+def qop_model_use_output_scale_zp(
+    int8_onnx_model
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #      QuantizeLinear
+    #            |
+    #   MatMulIntegerToFloat
+    #            |
+    #           Add
+    #            |
+    #      QuantizeLinear
+    #            |
+    #     DequantizeLinear
+    import copy
+    remove_nodes = set()
+    for node in int8_onnx_model.graph.node:
+        if node.op_type == 'QLinearMatMul':
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+            a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point = node.input[:8]
+            matmulintegertofloat_node = onnx.helper.make_node("MatMulIntegerToFloat",
+                                                              inputs=[a, b, a_scale, b_scale, a_zero_point, b_zero_point],
+                                                              outputs=[node.output[0]],
+                                                              name=node.name + '_matmulintegertofloat',
+                                                              domain='com.microsoft')
+
+            for idx in range(len(add_node.input)):
+                if add_node.input[idx] == dequantizelinear_node.output[0]:
+                    add_node.input[idx] = node.output[0]
+
+            quantizelinear_node = onnx.helper.make_node("QuantizeLinear",
+                                                        inputs=[add_node.output[0] + '_add', y_scale, y_zero_point],
+                                                        outputs=[node.output[0] + '_quantizelinear'],
+                                                        name=node.name + '_quantizelinear')
+
+            dequantizelinear_node.input[0] = node.output[0] + '_quantizelinear'
+            dequantizelinear_node.output[0] = copy.deepcopy(add_node.output[0])
+            add_node.output[0] = add_node.output[0] + '_add'
+
+            remove_nodes.add(node.name)
+            int8_onnx_model.graph.node.extend([matmulintegertofloat_node, quantizelinear_node])
+
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    return int8_onnx_model
 
 
 def torch_to_fp32_onnx(
@@ -355,6 +625,8 @@ def torch_to_int8_onnx(
     output_names=None,
     quant_format: str = 'QDQ',
     dtype: str = 'U8S8',
+    linear_options: dict = {'use_int32_bias': False,
+                            'use_output_scale_zp': False},
 ):
     """Export INT8 PyTorch model into INT8 ONNX model.
 
@@ -371,6 +643,12 @@ def torch_to_int8_onnx(
         output_names (list, optional): output names. Defaults to None.
         quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'.
         dtype (str, optional): data types of activation and weight of ONNX model. Defaults to 'U8S8'.
+        linear_options (dict, optional): recipe with options for processing the nn.quantized.Linear module.
+            Recipe 1: use fp32 bias, map input scale and zero point from the PyTorch model.
+            Recipe 2: use int32 bias, map input scale and zero point from the PyTorch model.
+            Recipe 3: use fp32 bias, map input and output scale and zero point from the PyTorch model.
+            Defaults to recipe 1: {'use_int32_bias': False,
+                                   'use_output_scale_zp': False}
     """
     global op_types_to_quantize
     if q_config['approach'] == 'post_training_dynamic_quant':
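The three recipes map onto the two boolean flags as follows; since the flags are mutually exclusive, setting both triggers the reset-and-warn path added below. A sketch of the valid combinations (the other export arguments to torch_to_int8_onnx are elided here):

```python
recipe1 = {'use_int32_bias': False, 'use_output_scale_zp': False}  # fp32 bias, input scale/zp (default)
recipe2 = {'use_int32_bias': True,  'use_output_scale_zp': False}  # int32 bias, input scale/zp
recipe3 = {'use_int32_bias': False, 'use_output_scale_zp': True}   # fp32 bias, input + output scale/zp

# torch_to_int8_onnx(<model/config/export args>, linear_options=recipe2)
```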
@@ -382,6 +660,16 @@ def torch_to_int8_onnx(
         opset_version = 13
         logger.warning("QDQ format requires opset_version >= 13, " +
                        "we reset opset_version={} here".format(opset_version))
+
+    use_int32_bias = linear_options['use_int32_bias']
+    use_output_scale_zp = linear_options['use_output_scale_zp']
+    if use_int32_bias and use_output_scale_zp:  # pragma: no cover
+        use_output_scale_zp = False
+        linear_options = {'use_int32_bias': use_int32_bias,
+                          'use_output_scale_zp': use_output_scale_zp}
+        logger.warning(f"For linear_options, only one of 'use_int32_bias' "
+                       f"and 'use_output_scale_zp' can be set to True. "
+                       f"We reset linear_options = {linear_options} here")
 
     # pylint: disable=E1101
     fp32_onnx_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
@@ -408,6 +696,9 @@ def torch_to_int8_onnx(
 
     quant_format = ortq.QuantFormat.QOperator if quant_format != 'QDQ' else ortq.QuantFormat.QDQ
 
+    extra_options = {'OpTypesToExcludeOutputQuantizatioin': ['MatMul']} \
+        if (not use_output_scale_zp and quant_format == ortq.QuantFormat.QDQ) else {}
+
     if q_config['approach'] == 'post_training_dynamic_quant':
         logger.info("Quantization format is not available when executing dynamic quantization.")
         ortq.quantize_dynamic(
@@ -433,10 +724,24 @@ def torch_to_int8_onnx(
         activation_type=activation_type,
         nodes_to_quantize=quantize_nodes,
         nodes_to_exclude=[],
-        extra_options={'OpTypesToExcludeOutputQuantizatioin': ['MatMul']},
+        extra_options=extra_options,
     )
 
-    set_scale_info(save_path, scale_mapping, activation_type)
+    int8_onnx_model = set_scale_info(save_path, scale_mapping, activation_type)
+    if quant_format == ortq.QuantFormat.QDQ:
+        if use_int32_bias:
+            int8_onnx_model = qdq_model_use_int32_bias(int8_onnx_model, quantize_nodes)
+        if use_output_scale_zp:
+            int8_onnx_model = qdq_model_use_output_scale_zp(int8_onnx_model, quantize_nodes)
+    elif quant_format == ortq.QuantFormat.QOperator:
+        if not use_int32_bias and not use_output_scale_zp:
+            int8_onnx_model = qop_model_default(int8_onnx_model)
+        if use_int32_bias:
+            int8_onnx_model = qop_model_use_int32_bias(int8_onnx_model)
+        if use_output_scale_zp:
+            int8_onnx_model = qop_model_use_output_scale_zp(int8_onnx_model)
+
+    onnx.save(int8_onnx_model, save_path)
 
     os.remove(fp32_onnx_path)
     info = "The INT8 ONNX Model is exported to path: {0}".format(save_path)