
Commit 089b7e1

update PT2ONNX export recipe
Signed-off-by: yuwenzho <[email protected]>
1 parent 317e7a8 commit 089b7e1

File tree: 3 files changed, +560 / -3 lines changed


neural_compressor/experimental/export/torch2onnx.py

Lines changed: 308 additions & 3 deletions
@@ -288,7 +288,277 @@ def set_scale_info(
             )
             int8_onnx_model.graph.initializer.remove(tensor)
             int8_onnx_model.graph.initializer.append(new_tensor)
-    onnx.save(int8_onnx_model, int8_onnx_path)
+    return int8_onnx_model
+
+
+def remove_nodes_by_name(int8_onnx_model, node_names):
+    while node_names:
+        for node in int8_onnx_model.graph.node:
+            if node.name in node_names:
+                int8_onnx_model.graph.node.remove(node)
+                node_names.remove(node.name)
+    return int8_onnx_model
+
+
+def generate_int32_bias_structure(
+    int8_onnx_model,
+    node,
+    a_info,
+    b_info,
+    bias_name,
+    output_name,
+):
+    from onnx import TensorProto
+    a, a_scale, a_zero_point = a_info
+    b, b_scale, b_zero_point = b_info
+    a_scale = ortq.onnx_model.ONNXModel(int8_onnx_model).get_initializer(a_scale)
+    a_scale = onnx.numpy_helper.to_array(a_scale)
+    b_scale = ortq.onnx_model.ONNXModel(int8_onnx_model).get_initializer(b_scale)
+    b_scale = onnx.numpy_helper.to_array(b_scale)
+    bias = ortq.onnx_model.ONNXModel(int8_onnx_model).get_initializer(bias_name)
+    bias_dims = bias.dims
+    bias = onnx.numpy_helper.to_array(bias)
+    bias_scale = a_scale * b_scale
+    quantized_bias = (bias / bias_scale).round().astype(np.int32)
+    quantized_bias = np.asarray(quantized_bias, dtype=np.int32).reshape(bias_dims)
+    packed_bias_initializer = onnx.numpy_helper.from_array(quantized_bias,
+                                                           bias_name + "_quantized")
+    int8_onnx_model.graph.initializer.extend([packed_bias_initializer])
+
+    matmul_node = onnx.helper.make_node("MatMulInteger",
+                                        inputs=[a, b, a_zero_point, b_zero_point],
+                                        outputs=[node.output[0] + '_matmulinteger'],
+                                        name=node.name + '_matmulinteger')
+    add_node = onnx.helper.make_node("Add",
+                                     inputs=[node.output[0] + '_matmulinteger', bias_name + '_quantized'],
+                                     outputs=[node.output[0] + '_add'],
+                                     name=node.name + '_add')
+    cast_node = onnx.helper.make_node("Cast",
+                                      inputs=[node.output[0] + '_add'],
+                                      outputs=[node.output[0] + '_cast'],
+                                      to=getattr(TensorProto, 'FLOAT'),
+                                      name=node.name + '_cast')
+
+    new_tensor = onnx.helper.make_tensor(
+        name=node.name + '_bias_scale',
+        data_type=TensorProto.FLOAT,
+        dims=list(bias_scale.shape),
+        vals=bias_scale,
+    )
+    int8_onnx_model.graph.initializer.append(new_tensor)
+
+    mul_node = onnx.helper.make_node("Mul",
+                                     inputs=[node.output[0] + '_cast', node.name + '_bias_scale'],
+                                     outputs=[output_name],
+                                     name=node.name + '_mul')
+
+    int8_onnx_model.graph.node.extend([matmul_node, add_node, cast_node, mul_node])
+    return int8_onnx_model
+
+
+def qdq_model_use_int32_bias(
+    int8_onnx_model,
+    quantize_nodes,
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #   QuantizeLinear
+    #         |
+    #   MatMulInteger
+    #         |
+    #        Add
+    #         |
+    #        Cast
+    #         |
+    #        Mul
+    remove_nodes = set()
+    replace_input = {}
+    for node in int8_onnx_model.graph.node:
+        if node.name in quantize_nodes and node.op_type == 'MatMul':
+            parents = ortq.onnx_model.ONNXModel(int8_onnx_model).get_parents(node)
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            bias_name = None
+            for inp in add_node.input:
+                if inp.endswith('.bias'):
+                    bias_name = inp
+            if not bias_name:  # pragma: no cover
+                continue
+
+            for parent in parents:
+                grand_parent = ortq.onnx_model.ONNXModel(int8_onnx_model).get_parents(parent)
+                if grand_parent:
+                    replace_input[parent.output[0]] = grand_parent[0].input[0]
+
+            int8_onnx_model = generate_int32_bias_structure(int8_onnx_model,
+                                                            node,
+                                                            parents[0].input[:3],
+                                                            parents[1].input[:3],
+                                                            bias_name,
+                                                            add_node.output[0])
+            remove_nodes.add(node.name)
+            remove_nodes.add(parents[0].name)
+            remove_nodes.add(parents[1].name)
+            remove_nodes.add(add_node.name)
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    for node in int8_onnx_model.graph.node:  # pragma: no cover
+        for i in range(len(node.input)):
+            if node.input[i] in replace_input:
+                node.input[i] = replace_input[node.input[i]]
+    return int8_onnx_model
+
+
+def qdq_model_use_output_scale_zp(
+    int8_onnx_model,
+    quantize_nodes,
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #      QuantizeLinear
+    #            |
+    #    DequantizeLinear   DequantizeLinear
+    #            |              |
+    #            ----------------
+    #                   |
+    #                 MatMul
+    #                   |
+    #                  Add
+    #                   |
+    #            QuantizeLinear
+    #                   |
+    #           DequantizeLinear
+    for node in int8_onnx_model.graph.node:
+        if node.name in quantize_nodes and node.op_type == 'MatMul':
+            quantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(quantizelinear_node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+            dequantizelinear_node.output[0] = add_node.output[0]
+            add_node.output[0] = add_node.output[0] + '_add'
+            for i in range(len(add_node.input)):
+                if not add_node.input[i].endswith('.bias'):
+                    add_node.input[i] = node.output[0]
+            quantizelinear_node.input[0] = add_node.output[0]
+    return int8_onnx_model
+
+
+def qop_model_default(
+    int8_onnx_model
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #      QuantizeLinear
+    #            |
+    #   MatMulIntegerToFloat
+    #            |
+    #           Add
+    remove_nodes = set()
+    for node in int8_onnx_model.graph.node:
+        if node.op_type == 'QLinearMatMul':
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+            a = node.input[0]
+            a_scale = node.input[1]
+            a_zero_point = node.input[2]
+            b = node.input[3]
+            b_scale = node.input[4]
+            b_zero_point = node.input[5]
+            matmulintegertofloat_node = onnx.helper.make_node("MatMulIntegerToFloat",
+                                                              inputs=[a, b, a_scale, b_scale, a_zero_point, b_zero_point],
+                                                              outputs=[node.output[0]],
+                                                              name=node.name + '_matmulintegertofloat',
+                                                              domain='com.microsoft')
+            for idx in range(len(add_node.input)):
+                if add_node.input[idx] == dequantizelinear_node.output[0]:
+                    add_node.input[idx] = node.output[0]
+            remove_nodes.add(node.name)
+            remove_nodes.add(dequantizelinear_node.name)
+            int8_onnx_model.graph.node.extend([matmulintegertofloat_node])
+
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    return int8_onnx_model
+
+
+def qop_model_use_int32_bias(
+    int8_onnx_model
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #   QuantizeLinear
+    #         |
+    #   MatMulInteger
+    #         |
+    #        Add
+    #         |
+    #        Cast
+    #         |
+    #        Mul
+    remove_nodes = set()
+    for node in int8_onnx_model.graph.node:
+        if node.op_type == 'QLinearMatMul':
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+
+            bias_name = None
+            for inp in add_node.input:
+                if inp.endswith('.bias'):
+                    bias_name = inp
+            if not bias_name:  # pragma: no cover
+                continue
+
+            int8_onnx_model = generate_int32_bias_structure(int8_onnx_model,
+                                                            node,
+                                                            node.input[:3],
+                                                            node.input[3:6],
+                                                            bias_name,
+                                                            add_node.output[0])
+            remove_nodes.add(node.name)
+            remove_nodes.add(add_node.name)
+            remove_nodes.add(dequantizelinear_node.name)
+
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    return int8_onnx_model
+
+
+def qop_model_use_output_scale_zp(
+    int8_onnx_model
+):
+    # nn.quantized.Linear module will be converted to the following format:
+    #      QuantizeLinear
+    #            |
+    #   MatMulIntegerToFloat
+    #            |
+    #           Add
+    #            |
+    #      QuantizeLinear
+    #            |
+    #     DequantizeLinear
+    import copy
+    remove_nodes = set()
+    for node in int8_onnx_model.graph.node:
+        if node.op_type == 'QLinearMatMul':
+            dequantizelinear_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(node)[0]
+            add_node = ortq.onnx_model.ONNXModel(int8_onnx_model).get_children(dequantizelinear_node)[0]
+            a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point = node.input[:8]
+            matmulintegertofloat_node = onnx.helper.make_node("MatMulIntegerToFloat",
+                                                              inputs=[a, b, a_scale, b_scale, a_zero_point, b_zero_point],
+                                                              outputs=[node.output[0]],
+                                                              name=node.name + '_matmulintegertofloat',
+                                                              domain='com.microsoft')
+
+            for idx in range(len(add_node.input)):
+                if add_node.input[idx] == dequantizelinear_node.output[0]:
+                    add_node.input[idx] = node.output[0]
+
+            quantizelinear_node = onnx.helper.make_node("QuantizeLinear",
+                                                        inputs=[add_node.output[0] + '_add', y_scale, y_zero_point],
+                                                        outputs=[node.output[0] + '_quantizelinear'],
+                                                        name=node.name + '_quantizelinear')
+
+            dequantizelinear_node.input[0] = node.output[0] + '_quantizelinear'
+            dequantizelinear_node.output[0] = copy.deepcopy(add_node.output[0])
+            add_node.output[0] = add_node.output[0] + '_add'
+
+            remove_nodes.add(node.name)
+            int8_onnx_model.graph.node.extend([matmulintegertofloat_node, quantizelinear_node])
+
+    int8_onnx_model = remove_nodes_by_name(int8_onnx_model, remove_nodes)
+    return int8_onnx_model
 
 
 def torch_to_fp32_onnx(
@@ -355,6 +625,8 @@ def torch_to_int8_onnx(
     output_names=None,
     quant_format: str = 'QDQ',
     dtype: str = 'U8S8',
+    linear_options: dict = {'use_int32_bias': False,
+                            'use_output_scale_zp': False},
 ):
     """Export INT8 PyTorch model into INT8 ONNX model.
 
@@ -371,6 +643,12 @@ def torch_to_int8_onnx(
         output_names (list, optional): output names. Defaults to None.
         quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'.
         dtype (str, optional): data types of activation and weight of ONNX model. Defaults to 'U8S8'.
+        linear_options (dict, optional): Recipe with options for processing the nn.quantized.Linear module.
+            Recipe 1: use fp32 bias, map input scale and zero point from the PyTorch model.
+            Recipe 2: use int32 bias, map input scale and zero point from the PyTorch model.
+            Recipe 3: use fp32 bias, map input and output scale and zero point from the PyTorch model.
+            Defaults to recipe 1: {'use_int32_bias': False,
+                                   'use_output_scale_zp': False}
     """
     global op_types_to_quantize
     if q_config['approach'] == 'post_training_dynamic_quant':
@@ -382,6 +660,16 @@ def torch_to_int8_onnx(
         opset_version = 13
         logger.warning("QDQ format requires opset_version >= 13, " +
                        "we reset opset_version={} here".format(opset_version))
+
+    use_int32_bias = linear_options['use_int32_bias']
+    use_output_scale_zp = linear_options['use_output_scale_zp']
+    if use_int32_bias and use_output_scale_zp:  # pragma: no cover
+        use_output_scale_zp = False
+        linear_options = {'use_int32_bias': use_int32_bias,
+                          'use_output_scale_zp': use_output_scale_zp}
+        logger.warning(f"For linear_options, only one of 'use_int32_bias' "
+                       f"and 'use_output_scale_zp' can be set to True. "
+                       f"We reset linear_options = {linear_options} here")
 
     # pylint: disable=E1101
     fp32_onnx_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
@@ -408,6 +696,9 @@ def torch_to_int8_onnx(
 
     quant_format = ortq.QuantFormat.QOperator if quant_format != 'QDQ' else ortq.QuantFormat.QDQ
 
+    extra_options = {'OpTypesToExcludeOutputQuantizatioin': ['MatMul']} \
+        if (not use_output_scale_zp and quant_format == ortq.QuantFormat.QDQ) else {}
+
     if q_config['approach'] == 'post_training_dynamic_quant':
         logger.info("Quantization format is not avalible when executing dynamic quantization.")
         ortq.quantize_dynamic(
@@ -433,10 +724,24 @@ def torch_to_int8_onnx(
             activation_type=activation_type,
             nodes_to_quantize=quantize_nodes,
             nodes_to_exclude=[],
-            extra_options={'OpTypesToExcludeOutputQuantizatioin': ['MatMul']},
+            extra_options=extra_options,
         )
 
-    set_scale_info(save_path, scale_mapping, activation_type)
+    int8_onnx_model = set_scale_info(save_path, scale_mapping, activation_type)
+    if quant_format == ortq.QuantFormat.QDQ:
+        if use_int32_bias:
+            int8_onnx_model = qdq_model_use_int32_bias(int8_onnx_model, quantize_nodes)
+        if use_output_scale_zp:
+            int8_onnx_model = qdq_model_use_output_scale_zp(int8_onnx_model, quantize_nodes)
+    elif quant_format == ortq.QuantFormat.QOperator:
+        if not use_int32_bias and not use_output_scale_zp:
+            int8_onnx_model = qop_model_default(int8_onnx_model)
+        if use_int32_bias:
+            int8_onnx_model = qop_model_use_int32_bias(int8_onnx_model)
+        if use_output_scale_zp:
+            int8_onnx_model = qop_model_use_output_scale_zp(int8_onnx_model)
+
+    onnx.save(int8_onnx_model, save_path)
 
     os.remove(fp32_onnx_path)
     info = "The INT8 ONNX Model is exported to path: {0}".format(save_path)
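
A minimal usage sketch of the new linear_options argument, for orientation: it assumes torch_to_int8_onnx is importable from neural_compressor.experimental.export, and fp32_model, int8_model, q_config, and example_inputs are placeholders from a normal Neural Compressor quantization run. Only quant_format, dtype, and linear_options below come from this diff; the other keyword names are assumptions.

# Hedged sketch: export with recipe 2 (int32 bias, input scale/zero point
# mapped from the PyTorch model). Placeholders/assumed names are noted inline.
from neural_compressor.experimental.export import torch_to_int8_onnx

torch_to_int8_onnx(
    fp32_model,                       # FP32 torch model (placeholder)
    int8_model,                       # quantized torch model (placeholder)
    q_config=q_config,                # q_config from the quantization flow (placeholder)
    example_inputs=example_inputs,    # example inputs for export (assumed keyword name)
    save_path='int8-model.onnx',
    quant_format='QDQ',
    dtype='U8S8',
    linear_options={'use_int32_bias': True,        # recipe 2
                    'use_output_scale_zp': False},
)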

neural_compressor/model/torch_model.py

Lines changed: 18 additions & 0 deletions
@@ -666,6 +666,23 @@ def export(
             torch_to_int8_onnx
         )
         if conf.dtype == 'int8':
+            linear_options = conf.kwargs.pop('linear_options', {'use_int32_bias': False,
+                                                                'use_output_scale_zp': False})
+            recipe = conf.kwargs.pop('recipe', 1)
+            assert recipe in [1, 2, 3], "`recipe` refers to how to process " \
+                "nn.quantized.Linear module, which can only be 1 or 2 or 3."
+            if recipe == 1:
+                use_int32_bias = False
+                use_output_scale_zp = False
+            elif recipe == 2:
+                use_int32_bias = True
+                use_output_scale_zp = False
+            elif recipe == 3:
+                use_int32_bias = False
+                use_output_scale_zp = True
+            linear_options = {'use_int32_bias': use_int32_bias,
+                              'use_output_scale_zp': use_output_scale_zp}
+
             torch_to_int8_onnx(
                 self.fp32_model,
                 self.model,
@@ -678,6 +695,7 @@ def export(
                 output_names=conf.output_names,
                 quant_format=conf.quant_format,
                 dtype='U8S8',
+                linear_options=linear_options
             )
         elif conf.dtype == 'fp32':
             torch_to_fp32_onnx(
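
The recipe keyword popped from conf.kwargs above maps 1/2/3 onto linear_options, so a recipe can be selected from the export config. A hedged end-to-end sketch follows; Torch2ONNXConfig and its fields other than recipe are assumed from the existing export workflow and may differ across versions, and inc_int8_model / example_inputs are placeholders.

# Hedged sketch: choose recipe 2 (int32 bias) via the model.export() path
# patched above. Config fields other than `recipe` are assumptions.
from neural_compressor.config import Torch2ONNXConfig

int8_onnx_config = Torch2ONNXConfig(
    dtype="int8",
    opset_version=14,
    quant_format="QDQ",
    example_inputs=example_inputs,   # placeholder
    input_names=["input"],
    output_names=["output"],
    recipe=2,  # 1: fp32 bias (default), 2: int32 bias, 3: fp32 bias + output scale/zp
)
inc_int8_model.export("int8-model.onnx", int8_onnx_config)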
