@@ -67,14 +67,6 @@ def __init__(self, framework_specific_info):
6767 "supported backends: {}" .format (ONNXRT_BACKENDS [self .backend ],
6868 [ONNXRT_BACKENDS [i ] for i in ort .get_all_providers ()]))
6969
70- if self .backend == 'TensorrtExecutionProvider' :
71- from neural_compressor import options
72- options .onnxrt .qdq_setting .AddQDQPairToWeight = True
73- options .onnxrt .qdq_setting .DedicatedQDQPair = True
74- options .onnxrt .graph_optimization .level = 'DISABLE_ALL'
75- self .static = True
76- self .dynamic = False
77-
7870 if (not self .dynamic and "format" in framework_specific_info and \
7971 framework_specific_info ["format" ].lower () == 'qdq' ) or \
8072 self .backend == 'TensorrtExecutionProvider' :
@@ -114,6 +106,16 @@ def __init__(self, framework_specific_info):
             self.quantizable_op_types += \
                 self.query_handler.get_op_types_by_precision(precision=precision)
 
+        if self.backend == 'TensorrtExecutionProvider':
+            from neural_compressor import options
+            options.onnxrt.qdq_setting.AddQDQPairToWeight = True
+            options.onnxrt.qdq_setting.DedicatedQDQPair = True
+            options.onnxrt.graph_optimization.level = 'DISABLE_ALL'
+            options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin = \
+                ['Conv', 'Gemm', 'Add', 'MatMul']
+            self.static = True
+            self.dynamic = False
+
         self.evaluate_nums = 0
 
         self.fp32_results = []
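These backend defaults shape how the QDQ graph is exported for TensorRT: `AddQDQPairToWeight` keeps each weight in fp32 behind an explicit QuantizeLinear/DequantizeLinear pair instead of a pre-quantized initializer, `DedicatedQDQPair` gives every consumer of a tensor its own Q/DQ pair, and disabling ORT graph optimization avoids fusions that could alter the Q/DQ placement TensorRT pattern-matches against. A minimal sketch of the weight-side pattern this produces, built with standard `onnx.helper` calls (tensor names and values are illustrative, not from the adaptor):

```python
import numpy as np
from onnx import helper, numpy_helper

# Weight stays fp32 as an initializer ...
W = numpy_helper.from_array(np.random.rand(8, 3, 3, 3).astype(np.float32), 'W')
scale = numpy_helper.from_array(np.array(0.02, dtype=np.float32), 'W_scale')
zp = numpy_helper.from_array(np.array(0, dtype=np.int8), 'W_zp')

# ... and flows through an explicit Q/DQ pair (AddQDQPairToWeight=True),
# so TensorRT sees the original weights plus their quantization scales.
q = helper.make_node('QuantizeLinear', ['W', 'W_scale', 'W_zp'], ['W_q'], 'W_q_node')
dq = helper.make_node('DequantizeLinear', ['W_q', 'W_scale', 'W_zp'], ['W_dq'], 'W_dq_node')
conv = helper.make_node('Conv', ['X', 'W_dq'], ['Y'], 'conv0')
```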
@@ -517,10 +519,44 @@ def _pre_optimize(self, model, level=1):
             if self.graph_optimization.gemm2matmul else tmp_model
         model.model = self._rename_node(model.model)
         model = self._revert_fusedconv(model)
+        if self.backend == 'TensorrtExecutionProvider':
+            model = self._revert_conv_add_fusion(model)
         model = split_shared_bias(model)
         model.topological_sort()
         self.pre_optimized_model = copy.deepcopy(model)
 
+    def _revert_conv_add_fusion(self, model):
+        from onnx import numpy_helper
+        from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg
+        add_nodes = []
+        remove_nodes = []
+        for node in model.model.graph.node:
+            if node.op_type == 'Conv' and len(node.input) == 3:
+                bias_tensor = model.get_initializer(node.input[2])
+                bias_array = numpy_helper.to_array(bias_tensor).reshape((-1, 1, 1))  # (C,) -> (C,1,1) for NCHW broadcast
+                model.remove_initializer(bias_tensor)
+                model.add_initializer(numpy_helper.from_array(bias_array, bias_tensor.name))
+                kwargs = {}
+                for attr in node.attribute:
+                    kwargs.update(attribute_to_kwarg(attr))
+                conv = onnx.helper.make_node(
+                    'Conv',
+                    node.input[0:2],
+                    [node.name + '_revert'],
+                    node.name, **kwargs)
+                add = onnx.helper.make_node(
+                    'Add',
+                    [conv.output[0], node.input[2]],
+                    node.output,
+                    node.name + '_add')
+                add_nodes.extend([conv, add])
+                remove_nodes.append(node)  # drop the original fused Conv it replaces
+
+        model.remove_nodes(remove_nodes)
+        model.add_nodes(add_nodes)
+        model.update()
+        return model
+
     def _revert_fusedconv(self, model):
         from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg
         from onnx import onnx_pb as onnx_proto
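`_revert_conv_add_fusion` splits every Conv that carries a fused bias (a third input) back into a bias-free Conv followed by an explicit Add. The reshape matters: per the ONNX spec a Conv bias is 1-D of shape (C,), but a standalone Add must broadcast across NCHW activations, hence `reshape((-1, 1, 1))` to (C, 1, 1). A small standalone sketch of the same rewrite on a toy graph (names and shapes are illustrative):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3, 8, 8])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4, 8, 8])
W = numpy_helper.from_array(np.ones((4, 3, 3, 3), np.float32), 'W')

# Bias reshaped from (4,) to (4, 1, 1) so the Add broadcasts over NCHW.
B = numpy_helper.from_array(np.ones(4, np.float32).reshape((-1, 1, 1)), 'B')

# Bias-free Conv writes to an intermediate tensor; Add re-applies the bias.
conv = helper.make_node('Conv', ['X', 'W'], ['conv0_revert'], 'conv0',
                        pads=[1, 1, 1, 1])
add = helper.make_node('Add', ['conv0_revert', 'B'], ['Y'], 'conv0_add')

graph = helper.make_graph([conv, add], 'g', [X], [Y], [W, B])
onnx.checker.check_model(helper.make_model(graph))
```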
@@ -684,6 +720,10 @@ def query_fw_capability(self, model):
             else:
                 continue
 
+            if self.backend == 'TensorrtExecutionProvider' and \
+                precision not in query.get_fallback_list():
+                optypes.append('Add')
+
             for op in optypes:
                 if op not in quantizable_optype:
                     continue
@@ -736,6 +776,14 @@ def query_fw_capability(self, model):
                 all_conv_matmul.append(node)
 
         for _, node in enumerate(self.pre_optimized_model.nodes()):
+            # for TRT EP, only insert Q/DQ to inputs of Add nodes followed by ReduceMean
+            if node.op_type == 'Add' and self.backend == 'TensorrtExecutionProvider':
+                children = self.pre_optimized_model.get_children(node)
+                if 'ReduceMean' not in [i.op_type for i in children]:
+                    op_wise.update({(node.name, node.op_type):
+                        [{'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}]})
+                    continue
+
             if node.op_type in optype_wise:
                 if (exclude_first_quantizable_op and node.name in first_quantizable_node) \
                     or (exclude_last_quantizable_op and node.name in last_quantizable_node):
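The effect of this branch: under the TensorRT EP, only Add nodes feeding a ReduceMean (the LayerNorm-style subgraph TensorRT recognizes) keep their normal quantization options; every other Add is pinned to fp32. A minimal standalone sketch of the child check, assuming a plain `GraphProto` and using a hypothetical `get_children` helper in place of the model wrapper's method:

```python
def get_children(graph, node):
    """Nodes that consume any output of `node` (hypothetical stand-in for
    pre_optimized_model.get_children)."""
    outputs = set(node.output)
    return [n for n in graph.node if outputs & set(n.input)]

def trt_keeps_add_quantizable(graph, node):
    # Mirrors the branch above: an Add stays quantizable only when one of
    # its consumers is a ReduceMean.
    return node.op_type == 'Add' and \
        'ReduceMean' in [child.op_type for child in get_children(graph, node)]
```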