Commit 764357b

Add recipe for TRT EP (#278)
* Add recipe for TRT EP
* remove codes

Signed-off-by: Mengni Wang <[email protected]>
1 parent 7d1e1f9 commit 764357b

File tree: 3 files changed (+83, -20 lines)

neural_compressor/adaptor/onnxrt.py

Lines changed: 56 additions & 8 deletions
@@ -67,14 +67,6 @@ def __init__(self, framework_specific_info):
                              "supported backends: {}".format(ONNXRT_BACKENDS[self.backend],
                              [ONNXRT_BACKENDS[i] for i in ort.get_all_providers()]))

-        if self.backend == 'TensorrtExecutionProvider':
-            from neural_compressor import options
-            options.onnxrt.qdq_setting.AddQDQPairToWeight = True
-            options.onnxrt.qdq_setting.DedicatedQDQPair = True
-            options.onnxrt.graph_optimization.level = 'DISABLE_ALL'
-            self.static = True
-            self.dynamic = False
-
         if (not self.dynamic and "format" in framework_specific_info and \
             framework_specific_info["format"].lower() == 'qdq') or \
             self.backend == 'TensorrtExecutionProvider':
@@ -114,6 +106,16 @@ def __init__(self, framework_specific_info):
             self.quantizable_op_types += \
                 self.query_handler.get_op_types_by_precision(precision=precision)

+        if self.backend == 'TensorrtExecutionProvider':
+            from neural_compressor import options
+            options.onnxrt.qdq_setting.AddQDQPairToWeight = True
+            options.onnxrt.qdq_setting.DedicatedQDQPair = True
+            options.onnxrt.graph_optimization.level = 'DISABLE_ALL'
+            options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin = \
+                ['Conv', 'Gemm', 'Add', 'MatMul']
+            self.static = True
+            self.dynamic = False
+
         self.evaluate_nums = 0

         self.fp32_results = []
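Note: the TRT-specific defaults now run after the quantizable op types are collected, and additionally exclude output quantization for Conv, Gemm, Add, and MatMul. For orientation, a minimal usage sketch with the neural_compressor 1.x experimental API follows; the file names are assumptions, and 'conf.yaml' is assumed to select the TensorRT EP and supply a calibration dataloader. The adaptor applies the QDQ settings above automatically.

    # Hypothetical end-to-end sketch (not part of this commit): quantize an
    # ONNX model for the TensorRT EP. No manual `options` tweaking is needed;
    # the adaptor sets the TRT QDQ defaults shown in the diff above.
    import onnx
    from neural_compressor.experimental import Quantization, common

    quantizer = Quantization('conf.yaml')            # assumed config: TRT EP backend
    quantizer.model = common.Model(onnx.load('model.onnx'))
    q_model = quantizer()                            # emits a QDQ-format model
    q_model.save('model_qdq.onnx')                   # ready to run with the TensorRT EP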
@@ -517,10 +519,44 @@ def _pre_optimize(self, model, level=1):
             if self.graph_optimization.gemm2matmul else tmp_model
         model.model = self._rename_node(model.model)
         model = self._revert_fusedconv(model)
+        if self.backend == 'TensorrtExecutionProvider':
+            model = self._revert_conv_add_fusion(model)
         model = split_shared_bias(model)
         model.topological_sort()
         self.pre_optimized_model = copy.deepcopy(model)

+    def _revert_conv_add_fusion(self, model):
+        from onnx import numpy_helper
+        from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg
+        add_nodes = []
+        remove_nodes = []
+        for node in model.model.graph.node:
+            if node.op_type == 'Conv' and len(node.input) == 3:
+                bias_tensor = model.get_initializer(node.input[2])
+                bias_array = numpy_helper.to_array(bias_tensor).reshape((-1, 1, 1))
+                model.remove_initializer(bias_tensor)
+                model.add_initializer(numpy_helper.from_array(bias_array, bias_tensor.name))
+                kwargs = {}
+                activation_params = None
+                for attr in node.attribute:
+                    kwargs.update(attribute_to_kwarg(attr))
+                conv = onnx.helper.make_node(
+                    'Conv',
+                    node.input[0:2],
+                    [node.name + '_revert'],
+                    node.name, **kwargs)
+                add = onnx.helper.make_node(
+                    'Add',
+                    [conv.output[0], node.input[2]],
+                    node.output,
+                    node.name + '_add')
+                add_nodes.extend([conv, add])
+                remove_nodes.append(node)
+        model.remove_nodes(remove_nodes)
+        model.add_nodes(add_nodes)
+        model.update()
+        return model
+
     def _revert_fusedconv(self, model):
         from neural_compressor.adaptor.ox_utils.util import attribute_to_kwarg
         from onnx import onnx_pb as onnx_proto
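The revert is sound because a biased Conv is mathematically identical to the bias-free Conv followed by a broadcast Add, which is the split TensorRT expects for Q/DQ placement. A NumPy-only check of that identity (toy shapes, not from the commit):

    # Conv-with-bias == Conv-without-bias + Add of the bias reshaped to
    # (C, 1, 1), which broadcasts over H and W of an NCHW output.
    import numpy as np

    conv_out = np.random.randn(1, 8, 16, 16).astype(np.float32)  # Conv output, no bias
    bias = np.random.randn(8).astype(np.float32)                 # per-channel bias

    fused = conv_out + bias[None, :, None, None]      # what Conv's bias input computes
    reverted = conv_out + bias.reshape((-1, 1, 1))    # what the inserted Add computes
    assert np.allclose(fused, reverted)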
@@ -684,6 +720,10 @@ def query_fw_capability(self, model):
             else:
                 continue

+            if self.backend == 'TensorrtExecutionProvider' and \
+                    precision not in query.get_fallback_list():
+                optypes.append('Add')
+
             for op in optypes:
                 if op not in quantizable_optype:
                     continue
@@ -736,6 +776,14 @@ def query_fw_capability(self, model):
                     all_conv_matmul.append(node)

         for _, node in enumerate(self.pre_optimized_model.nodes()):
+            # for TRT EP, only insert Q/DQ to inputs of Add nodes followed by ReduceMean
+            if node.op_type == 'Add' and self.backend == 'TensorrtExecutionProvider':
+                children = self.pre_optimized_model.get_children(node)
+                if 'ReduceMean' not in [i.op_type for i in children]:
+                    op_wise.update({(node.name, node.op_type):
+                        [{'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}]})
+                continue
+
             if node.op_type in optype_wise:
                 if (exclude_first_quantizable_op and node.name in first_quantizable_node) \
                     or (exclude_last_quantizable_op and node.name in last_quantizable_node):
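An Add whose output feeds a ReduceMean is typically the start of a LayerNorm decomposition, which TensorRT handles well with Q/DQ on its inputs; all other Add nodes are pinned to fp32. The same child lookup can be sketched with the plain onnx API, where a consumer map stands in for the wrapper's get_children (model file name assumed):

    # Sketch with the plain onnx API (assumed input file): find Add nodes whose
    # output feeds a ReduceMean, mirroring the get_children() filter above.
    import onnx

    model = onnx.load('model.onnx')                   # assumed model file
    consumers = {}                                    # tensor name -> consumer nodes
    for n in model.graph.node:
        for inp in n.input:
            consumers.setdefault(inp, []).append(n)

    for n in model.graph.node:
        if n.op_type != 'Add':
            continue
        children = [c for out in n.output for c in consumers.get(out, [])]
        if 'ReduceMean' not in [c.op_type for c in children]:
            print(n.name, '-> keep fp32 for the TensorRT EP')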

neural_compressor/adaptor/onnxrt_qdq.yaml

Lines changed: 26 additions & 10 deletions
@@ -75,7 +75,21 @@
       CPUExecutionProvider: *ref_1_7
       CUDAExecutionProvider: *ref_1_7
       TensorrtExecutionProvider: {
-        'Conv': &cap_s8_sym_pertensor_default {
+        'Conv': &cap_s8_sym_default {
+          'weight': {
+            'dtype': ['int8'],
+            'scheme': ['sym'],
+            'granularity': ['per_tensor', 'per_channel'],
+            'algorithm': ['minmax']
+          },
+          'activation': {
+            'dtype': ['int8'],
+            'scheme': ['sym'],
+            'granularity': ['per_tensor'],
+            'algorithm': ['minmax']
+          }
+        },
+        'MatMul': &cap_s8_sym_pertensor_default {
           'weight': {
             'dtype': ['int8'],
             'scheme': ['sym'],
@@ -89,16 +103,16 @@
             'algorithm': ['minmax']
           }
         },
-        'MatMul': *cap_s8_sym_pertensor_default,
         'Attention': *cap_s8_sym_pertensor_default,
         'LeakyRelu': *cap_s8_sym_pertensor_default,
-        'Gather': *cap_s8_sym_pertensor_default,
+        'Gather': *cap_s8_sym_default,
         'Sigmoid': *cap_s8_sym_pertensor_default,
         'MaxPool': *cap_s8_sym_pertensor_default,
         'EmbedLayerNormalization': *cap_s8_sym_pertensor_default,
         'GlobalAveragePool': *cap_s8_sym_pertensor_default,
         'Pad': *cap_s8_sym_pertensor_default,
         'Split': *cap_s8_sym_pertensor_default,
+        'Add': *cap_s8_sym_pertensor_default,
       }

   graph_optimization: &default_optimization # from onnxruntime graph_optimization_level
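The new &cap_s8_sym_default anchor is what unlocks per-channel weight quantization for these ops. A NumPy-only sketch (toy weight, symmetric int8 as in the 'sym' scheme) of the difference between the two granularities:

    # 'per_tensor' vs the newly allowed 'per_channel' weight granularity,
    # using symmetric int8 scales as in the capability table (toy data).
    import numpy as np

    w = np.random.randn(8, 3, 3, 3).astype(np.float32)        # OIHW Conv weight
    scale_pt = np.abs(w).max() / 127.0                         # one scale for the tensor
    scale_pc = np.abs(w).reshape(8, -1).max(axis=1) / 127.0    # one scale per output channel

    q_pt = np.clip(np.round(w / scale_pt), -127, 127).astype(np.int8)
    q_pc = np.clip(np.round(w / scale_pc[:, None, None, None]), -127, 127).astype(np.int8)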
@@ -123,11 +137,11 @@
       CPUExecutionProvider: *ref_1_7
       CUDAExecutionProvider: *ref_1_7
       TensorrtExecutionProvider: &ref_1_8 {
-        'Conv': *cap_s8_sym_pertensor_default,
+        'Conv': *cap_s8_sym_default,
         'MatMul': *cap_s8_sym_pertensor_default,
         'Attention': *cap_s8_sym_pertensor_default,
         'LeakyRelu': *cap_s8_sym_pertensor_default,
-        'Gather': *cap_s8_sym_pertensor_default,
+        'Gather': *cap_s8_sym_default,
         'Sigmoid': *cap_s8_sym_pertensor_default,
         'MaxPool': *cap_s8_sym_pertensor_default,
         'EmbedLayerNormalization': *cap_s8_sym_pertensor_default,
@@ -140,7 +154,8 @@
         'AveragePool': *cap_s8_sym_pertensor_default,
         'Unsqueeze': *cap_s8_sym_pertensor_default,
         'Transpose': *cap_s8_sym_pertensor_default,
-        'Resize': *cap_s8_sym_pertensor_default
+        'Resize': *cap_s8_sym_pertensor_default,
+        'Add': *cap_s8_sym_pertensor_default,
       }

   graph_optimization:
@@ -317,11 +332,11 @@
       CPUExecutionProvider: *ref_1_11
       CUDAExecutionProvider: *ref_1_11
       TensorrtExecutionProvider: {
-        'Conv': *cap_s8_sym_pertensor_default,
-        'MatMul': *cap_s8_sym_pertensor_default,
+        'Conv': *cap_s8_sym_default,
+        'MatMul': *cap_s8_sym_default,
         'Attention': *cap_s8_sym_pertensor_default,
         'LeakyRelu': *cap_s8_sym_pertensor_default,
-        'Gather': *cap_s8_sym_pertensor_default,
+        'Gather': *cap_s8_sym_default,
         'Sigmoid': *cap_s8_sym_pertensor_default,
         'MaxPool': *cap_s8_sym_pertensor_default,
         'EmbedLayerNormalization': *cap_s8_sym_pertensor_default,
@@ -335,7 +350,8 @@
         'Unsqueeze': *cap_s8_sym_pertensor_default,
         'Transpose': *cap_s8_sym_pertensor_default,
         'Resize': *cap_s8_sym_pertensor_default,
-        'Gemm': *cap_s8_sym_pertensor_default
+        'Gemm': *cap_s8_sym_default,
+        'Add': *cap_s8_sym_pertensor_default,
       }

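The capability table leans on YAML anchors and aliases, so the single &cap_s8_sym_default definition is reused wherever an op gains per-channel support. A toy pyyaml check of that mechanism (the YAML below is illustrative, not the real table):

    # Anchor/alias reuse: *cap_s8_sym_default resolves to the very same
    # mapping object that &cap_s8_sym_default anchored (illustrative YAML).
    import yaml

    doc = """
    'Conv': &cap_s8_sym_default
      'weight': {'granularity': ['per_tensor', 'per_channel']}
      'activation': {'granularity': ['per_tensor']}
    'Gather': *cap_s8_sym_default
    """
    caps = yaml.safe_load(doc)
    assert caps['Gather'] is caps['Conv']            # alias == anchored object
    print(caps['Gather']['weight']['granularity'])   # ['per_tensor', 'per_channel']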
neural_compressor/adaptor/ox_utils/calibration.py

Lines changed: 1 addition & 2 deletions
@@ -436,8 +436,7 @@ def calculate_quantization_params(self, q_config, quantization_thresholds):
             if parent and parent.name in q_config and q_config[parent.name] not in ['fp32']:
                 scheme = q_config[parent.name]['activation']['scheme']
                 qType = q_config[parent.name]['activation']['dtype']
-            elif tensor_name in self.model_wrapper.input() and \
-                self.backend in ['TensorrtExecutionProvider']:
+            elif self.backend in ['TensorrtExecutionProvider']:
                 scheme = 'sym'
                 qType = 3
             node_thresholds = quantization_thresholds[tensor_name]
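Dropping the model-input check means every tensor without a quantized parent now defaults to symmetric int8 under the TensorRT EP, not only graph inputs. The literal 3 assigned to qType is the ONNX dtype enum for INT8, which a one-liner confirms:

    # qType = 3 above is the ONNX tensor dtype enum for INT8
    # (FLOAT = 1, UINT8 = 2, INT8 = 3).
    from onnx import TensorProto

    assert TensorProto.INT8 == 3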
