@@ -275,6 +275,15 @@ def set_scale_info(
 
 
 def remove_nodes_by_name(int8_onnx_model, node_names):
+    """Remove nodes from model by names.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+        node_names (list): names of nodes to remove.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
     while node_names:
         for node in int8_onnx_model.graph.node:
             if node.name in node_names:
@@ -283,14 +292,27 @@ def remove_nodes_by_name(int8_onnx_model, node_names):
     return int8_onnx_model
 
 
-def generate_int32_bias_structure(
+def sub_graph_with_int32_bias(
     int8_onnx_model,
     node,
     a_info,
     b_info,
     bias_name,
     output_name,
 ):
+    """Generate a sub graph with int32 bias.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+        node (NodeProto): MatMul node belonging to nn.quantized.Linear module.
+        a_info (list): info of input a for nn.quantized.Linear module.
+        b_info (list): info of input b for nn.quantized.Linear module.
+        bias_name (str): name of bias.
+        output_name (str): output name of the sub graph.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
     from onnx import TensorProto
     a, a_scale, a_zero_point = a_info
     b, b_scale, b_zero_point = b_info
@@ -344,7 +366,16 @@ def qdq_model_use_int32_bias(
     int8_onnx_model,
     quantize_nodes,
 ):
-    # nn.quantized.Lienar module will be converted to the following format:
+    """Export a QDQ model with recalculated int32 bias and remapped input scale and zero point
+    for nn.quantized.Linear module.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+        quantize_nodes (list): nodes to be quantized.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
+    # nn.quantized.Linear module will be converted to the following format:
     # QuantizeLinear
     #       |
     # MatMulInteger
@@ -372,7 +403,7 @@ def qdq_model_use_int32_bias(
                 if grand_parent:
                     replace_input[parent.output[0]] = grand_parent[0].input[0]
 
-            int8_onnx_model = generate_int32_bias_structure(int8_onnx_model,
+            int8_onnx_model = sub_graph_with_int32_bias(int8_onnx_model,
                                                         node,
                                                         parents[0].input[:3],
                                                         parents[1].input[:3],
@@ -394,7 +425,16 @@ def qdq_model_use_output_scale_zp(
     int8_onnx_model,
     quantize_nodes,
 ):
-    # nn.quantized.Lienar module will be converted to the following format:
+    """Export a QDQ model with FP32 bias and remapped input/output scale and zero point
+    for nn.quantized.Linear module.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+        quantize_nodes (list): nodes to be quantized.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
+    # nn.quantized.Linear module will be converted to the following format:
     # QuantizeLinear
     #       |
     # DequantizeLinear    DequantizeLinear
@@ -425,6 +465,15 @@ def qdq_model_use_output_scale_zp(
 def qop_model_default(
     int8_onnx_model
 ):
+    """Export a QOperator model with FP32 bias and remapped input scale and zero point
+    for nn.quantized.Linear module.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
     # nn.quantized.Linear module will be converted to the following format:
     # QuantizeLinear
     #       |
@@ -461,7 +510,16 @@ def qop_model_default(
 def qop_model_use_int32_bias(
     int8_onnx_model
 ):
-    # nn.quantized.Lienar module will be converted to the following format:
+    """Export a QOperator model with recalculated int32 bias and remapped input scale and zero point
+    for nn.quantized.Linear module.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
+    # nn.quantized.Linear module will be converted to the following format:
     # QuantizeLinear
     #       |
     # MatMulInteger
@@ -484,7 +542,7 @@ def qop_model_use_int32_bias(
             if not bias_name: # pragma: no cover
                 continue
 
-            int8_onnx_model = generate_int32_bias_structure(int8_onnx_model,
+            int8_onnx_model = sub_graph_with_int32_bias(int8_onnx_model,
                                                         node,
                                                         node.input[:3],
                                                         node.input[3:6],
@@ -501,6 +559,15 @@ def qop_model_use_int32_bias(
 def qop_model_use_output_scale_zp(
     int8_onnx_model
 ):
+    """Export a QOperator model with FP32 bias and remapped input/output scale and zero point
+    for nn.quantized.Linear module.
+
+    Args:
+        int8_onnx_model (ModelProto): onnx int8 model to process.
+
+    Returns:
+        int8_onnx_model: processed onnx int8 model.
+    """
     # nn.quantized.Lienar module will be converted to the following format:
     # QuantizeLinear
     #       |
@@ -627,11 +694,11 @@ def torch_to_int8_onnx(
         quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'.
         dtype (str, optional): data types of activation and weight of ONNX model. Defaults to 'U8S8'.
         linear_options (dict, optionl): Recipe with options for processing nn.quantized.Linear module.
-        Recipe 1: use fp32 bias, map input scale and zero point from PyTorch model.
-        Recipe 2: use int32 bias, map input scale and zero point from PyTorch model.
-        Recipe 3: use fp32 bias, map input and otput scale and zero point from PyTorch model.
-        Defaults to recipe 1: {'use_int32_bias': False,
-                               'use_output_scale_zp': False}
+            Recipe 1: use fp32 bias, map input scale and zero point from PyTorch model.
+            Recipe 2: use int32 bias, map input scale and zero point from PyTorch model.
+            Recipe 3: use fp32 bias, map input and output scale and zero point from PyTorch model.
+            Defaults to recipe 1: {'use_int32_bias': False,
+                                   'use_output_scale_zp': False}
     """
     global op_types_to_quantize
     if q_config['approach'] == 'post_training_dynamic_quant':
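
For reference, the three linear_options recipes described in the torch_to_int8_onnx docstring above map onto the two boolean keys shown in its documented default. The snippet below is only an illustrative sketch derived from that docstring: the True/False combinations are inferred from the recipe descriptions, and the remaining arguments of torch_to_int8_onnx are intentionally omitted.

# Sketch (assumption): the three linear_options recipes expressed with the keys
# from the documented default; these dicts would be passed as the linear_options
# argument of torch_to_int8_onnx.
recipe_1 = {'use_int32_bias': False, 'use_output_scale_zp': False}  # fp32 bias, input scale/zp (default)
recipe_2 = {'use_int32_bias': True, 'use_output_scale_zp': False}   # int32 bias, input scale/zp
recipe_3 = {'use_int32_bias': False, 'use_output_scale_zp': True}   # fp32 bias, input and output scale/zp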