
Commit 958d3b0

enhance node mapping
Signed-off-by: Xin He <[email protected]>
Parent: 2299aed

2 files changed: 89 additions, 68 deletions


neural_compressor/experimental/export/torch2onnx.py

Lines changed: 75 additions & 56 deletions
@@ -157,70 +157,85 @@ def set_data_type(
     return activation_type, weight_type
 
 
-def get_quantizable_onnx_ops(
-    q_config,
+def get_node_mapping(
+    fp32_model,
     fp32_onnx_path,
 ):
-    """Get quantizable onnx ops.
+    """Get PyTorch module and ONNX node mapping.
 
     Args:
-        q_config (dict): quantization configuration from PyTorch.
+        fp32_model (torch.nn.Module): PyTorch FP32 model.
         fp32_onnx_path (str): path to fp32 onnx model.
 
     Returns:
-        quantize_nodes: all onnx node that should be quantized.
         module_node_mapping: op mapping from PyTorch to ONNX.
         linear_matmul_list: contains matmul that comes from linear.
     """
+    def check_data(op_type, data, module_dict):
+        for name, value in module_dict.items():
+            if value.shape == data.shape:
+                if (value == data).all():
+                    return name
+                # Convolution weight data mismatch.
+                elif op_type == 'Conv' and np.allclose(value, data):
+                    return name
+        return None
+
+    module_dict = {}
+    for name, module in fp32_model.named_modules():
+        if 'Conv' in str(module.__class__.__name__) or \
+           'Embedding' in str(module.__class__.__name__) or \
+           'Linear' in str(module.__class__.__name__):
+            if hasattr(module, 'weight'):
+                value = module.weight.detach().cpu().numpy()
+                module_dict[name] = value
+
+    module_node_mapping = {}
+    linear_matmul_list = []
     fp32_onnx_model = onnx.load(fp32_onnx_path)
-    # Clarify ONNX nodes that we can mapping from PyTorch
-    if 'dynamic' in q_config['approach']:
-        op_types_to_quantize=['MatMul', 'Gather', "LSTM"]
-    else:
-        op_types_to_quantize=['MatMul', 'Gather', 'Conv']
+    initializer_data = {tensor.name: tensor for tensor in fp32_onnx_model.graph.initializer}
+    from onnx import numpy_helper
+    for node in fp32_onnx_model.graph.node:
+        if node.op_type in op_types_to_quantize:
+            if node.op_type == 'MatMul' and node.input[1] in initializer_data:
+                data = numpy_helper.to_array(initializer_data[node.input[1]]).T
+            elif node.op_type == 'Gather' and node.input[0] in initializer_data:
+                data = numpy_helper.to_array(initializer_data[node.input[0]])
+            elif node.op_type in ['Conv', 'Gemm']:
+                data = numpy_helper.to_array(initializer_data[node.input[1]])
+            else:
+                continue
+            pt_name = check_data(node.op_type, data, module_dict)
+            if pt_name:
+                module_node_mapping[pt_name] = node.name
+                if node.op_type == 'MatMul':
+                    linear_matmul_list.append(node.name)
+    return module_node_mapping, linear_matmul_list
 
-    from neural_compressor.adaptor.onnxrt import ONNXRTAdaptor
-    # pylint: disable=E1120
-    fp32_onnx_model = ONNXRTAdaptor._replace_gemm_with_matmul(fp32_onnx_model).model
-    onnx.save(fp32_onnx_model, fp32_onnx_path)
 
-    # Get weight name from onnx initializer
-    weight_name_list = []
-    for tensor in fp32_onnx_model.graph.initializer:
-        weight_name_list.append(tensor.name)
+def get_quantizable_onnx_ops(
+    int8_model,
+    module_node_mapping
+):
+    """Get quantizable onnx ops.
+
+    Args:
+        int8_model (torch.nn.Module): PyTorch int8 model.
+        module_node_mapping (dict): op mapping from PyTorch to ONNX.
 
-    # Match weight name with onnx node name
+    Returns:
+        quantize_nodes: all onnx nodes that should be quantized.
+    """
     quantize_nodes = []
-    tmp_node_mapping = {}
-    module_node_mapping = {}
-    for node in fp32_onnx_model.graph.node:
-        if node.op_type not in op_types_to_quantize:
-            for inp in node.input:
-                if inp in weight_name_list and 'weight' in inp:
-                    tmp_node_mapping.update({node.output[0] : inp.split('.weight')[0]})
-                elif inp in tmp_node_mapping:
-                    tmp_node_mapping.update({node.output[0] : tmp_node_mapping[inp]})
-        else:
-            for inp in node.input:
-                if inp in weight_name_list and 'weight' in inp:
-                    module_node_mapping.update({inp.split('.weight')[0] : node.name})
-                elif inp in tmp_node_mapping:
-                    module_node_mapping.update({tmp_node_mapping[inp]: node.name})
-
-    quantize_nodes = list(module_node_mapping.values())
-    # Fetch all matmul in ONNX that comes from PyTorch Linear
-    # Match pytorch module name with onnx node name for fallbacked fp32 module
-    linear_matmul_list = []
-    for k, v in q_config['op'].items(): # pragma: no cover
-        if 'Linear' in k[1]:
-            k_0 = k[0].split('.module')[0] if k[0] not in module_node_mapping else k[0]
-            linear_matmul_list.append(module_node_mapping[k_0])
-            if not 'int8' in v['weight']['dtype']:
-                k_0 = k[0].split('.module')[0] if k[0] not in module_node_mapping else k[0]
-                if k[0] in module_node_mapping:
-                    fallback_op = module_node_mapping[k_0]
-                    quantize_nodes.remove(fallback_op)
-    return quantize_nodes, module_node_mapping, linear_matmul_list
+    for name, module in int8_model.named_modules():
+        if 'Conv' in str(module.__class__.__name__) or \
+           'Embedding' in str(module.__class__.__name__) or \
+           'Linear' in str(module.__class__.__name__):
+            if hasattr(module, 'weight') and callable(module.weight):
+                if module.weight().dtype == torch.qint8:
+                    node = module_node_mapping[name.split('.module')[0]]
+                    quantize_nodes.append(node)
+    return quantize_nodes
 
 
 def get_scale_info(
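
The rewritten helpers split the old single pass into two steps: get_node_mapping matches each PyTorch Conv/Embedding/Linear module to an ONNX node by comparing its weight tensor against the graph initializers, and get_quantizable_onnx_ops then keeps only the nodes whose module in the int8 model actually carries qint8 weights. A minimal sketch of how the two helpers chain together, assuming they are importable from neural_compressor.experimental.export.torch2onnx; the TinyNet model, the file name, and the dynamic-quantization shortcut used to obtain qint8 weights are illustrative only:

import torch
from neural_compressor.experimental.export import torch2onnx

# Illustrative FP32 model with a single Linear layer.
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

fp32_model = TinyNet().eval()
torch.onnx.export(fp32_model, torch.randn(1, 8), 'tiny_fp32.onnx')

# Eager-mode dynamic quantization is a quick way to obtain qint8 Linear weights.
int8_model = torch.quantization.quantize_dynamic(
    fp32_model, {torch.nn.Linear}, dtype=torch.qint8)

# op_types_to_quantize is a module-level global that torch_to_int8_onnx now sets
# up front; when calling the helpers directly it has to be assigned by hand.
torch2onnx.op_types_to_quantize = ['MatMul', 'Gemm', 'Gather', 'Conv']

# 1) PyTorch module name -> ONNX node name, matched by comparing weight values.
module_node_mapping, linear_matmul_list = torch2onnx.get_node_mapping(
    fp32_model, 'tiny_fp32.onnx')

# 2) Keep only the ONNX nodes whose quantized module really holds qint8 weights.
quantize_nodes = torch2onnx.get_quantizable_onnx_ops(int8_model, module_node_mapping)
print(module_node_mapping, quantize_nodes)

Matching by weight values rather than by initializer names makes the mapping robust to the renaming that torch.onnx.export applies to graph tensors, which is what the old name-based matching relied on.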
@@ -257,7 +272,7 @@ def build_scale_mapping(
     module_node_mapping,
     int8_scale_info,
 ):
-    """_summary_
+    """Build scale mapping.
 
     Args:
         fp32_onnx_path (str): path to fp32 onnx model.
@@ -405,13 +420,17 @@ def torch_to_int8_onnx(
         quant_format (str, optional): quantization format of ONNX model. Defaults to 'QDQ'.
         dtype (str, optional): data types of activation and weight of ONNX model. Defaults to 'U8S8'.
     """
+    global op_types_to_quantize
+    if q_config['approach'] == 'post_training_dynamic_quant':
+        op_types_to_quantize=['MatMul', 'Gemm', 'Gather']
+    else:
+        op_types_to_quantize=['MatMul', 'Gemm', 'Gather', 'Conv']
+
     if quant_format == 'QDQ' and opset_version < 13:  # pragma: no cover
         opset_version = 13
         logger.warning("QDQ format requires opset_version >= 13, " +
                        "we reset opset_version={} here".format(opset_version))
 
-    activation_type, weight_type = set_data_type(dtype)
-
     # pylint: disable=E1101
     fp32_onnx_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
     torch_to_fp32_onnx(
@@ -422,13 +441,12 @@ def torch_to_int8_onnx(
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        do_constant_folding=False,
         verbose=False,
     )
 
-    quantize_nodes, module_node_mapping, linear_matmul_list = get_quantizable_onnx_ops(
-        q_config, fp32_onnx_path
-    )
+    activation_type, weight_type = set_data_type(dtype)
+    module_node_mapping, linear_matmul_list = get_node_mapping(fp32_model, fp32_onnx_path)
+    quantize_nodes = get_quantizable_onnx_ops(int8_model, module_node_mapping)
 
     if q_config['approach'] == 'quant_aware_training':
         update_weight_bias(int8_model, fp32_onnx_path)
@@ -439,6 +457,7 @@ def torch_to_int8_onnx(
     quant_format = ortq.QuantFormat.QOperator if quant_format != 'QDQ' else ortq.QuantFormat.QDQ
 
     if q_config['approach'] == 'post_training_dynamic_quant':
+        logger.info("Quantization format is not available when executing dynamic quantization.")
         ortq.quantize_dynamic(
             fp32_onnx_path,
             save_path,
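
On the dynamic path the QDQ/QOperator choice has no effect, which is what the new log line points out; the temporary FP32 export is handed straight to ONNX Runtime's dynamic quantizer. A rough sketch of that hand-off, reusing fp32_onnx_path and quantize_nodes from the code above; the output file name is illustrative, and the keyword arguments assume an onnxruntime.quantization version that accepts them:

import onnxruntime.quantization as ortq

# Dynamic branch sketched from the surrounding code: quant_format plays no
# role here, only the node list and the weight type are passed through.
ortq.quantize_dynamic(
    fp32_onnx_path,                    # temporary FP32 export, e.g. 'int8-model.onnx.tmp'
    'int8-model.onnx',                 # final save_path (illustrative)
    weight_type=ortq.QuantType.QInt8,  # S8 weights, matching dtype='U8S8'
    nodes_to_quantize=quantize_nodes,
    nodes_to_exclude=[],
)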

test/export/test_torch2onnx.py

Lines changed: 14 additions & 12 deletions
@@ -94,10 +94,12 @@ def setUpClass(self):
     @classmethod
     def tearDownClass(self):
         shutil.rmtree('nc_workspace', ignore_errors=True)
-        os.remove('fp32-cv-model.onnx')
-        os.remove('int8-cv-model.onnx')
-        os.remove('fp32-nlp-model.onnx')
-        os.remove('int8-nlp-model.onnx')
+        # os.remove('fp32-cv-model.onnx')
+        # os.remove('int8-cv-qdq-model.onnx')
+        # os.remove('int8-cv-qlinear-model.onnx')
+        # os.remove('fp32-nlp-model.onnx')
+        # os.remove('int8-nlp-qdq-model.onnx')
+        # os.remove('int8-nlp-qlinear-model.onnx')
 
     def test_fp32_CV_models(self):
         model = self.cv_model
@@ -151,8 +153,8 @@ def test_int8_CV_models(self):
             dynamic_axes={"input": {0: "batch_size"},
                           "output": {0: "batch_size"}},
         )
-        q_model.export('int8-cv-model.onnx', int8_onnx_config)
-        check_CV_onnx('int8-cv-model.onnx', self.cv_dataloader)
+        q_model.export('int8-cv-qdq-model.onnx', int8_onnx_config)
+        check_CV_onnx('int8-cv-qdq-model.onnx', self.cv_dataloader)
 
         int8_onnx_config = Torch2ONNXConfig(
             dtype="int8",
@@ -164,8 +166,8 @@ def test_int8_CV_models(self):
             dynamic_axes={"input": {0: "batch_size"},
                           "output": {0: "batch_size"}},
         )
-        q_model.export('int8-cv-model.onnx', int8_onnx_config)
-        check_CV_onnx('int8-cv-model.onnx', self.cv_dataloader)
+        q_model.export('int8-cv-qlinear-model.onnx', int8_onnx_config)
+        check_CV_onnx('int8-cv-qlinear-model.onnx', self.cv_dataloader)
 
     def test_fp32_NLP_models(self):
         symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
@@ -227,8 +229,8 @@ def test_int8_NLP_models(self):
             output_names=['labels'],
             dynamic_axes=dynamic_axes,
         )
-        q_model.export('int8-nlp-model.onnx', int8_onnx_config)
-        check_NLP_onnx('int8-nlp-model.onnx', self.nlp_input)
+        q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
+        check_NLP_onnx('int8-nlp-qdq-model.onnx', self.nlp_input)
 
         int8_onnx_config = Torch2ONNXConfig(
             dtype="int8",
@@ -239,8 +241,8 @@ def test_int8_NLP_models(self):
             output_names=['labels'],
             dynamic_axes=dynamic_axes,
         )
-        q_model.export('int8-nlp-model.onnx', int8_onnx_config)
-        check_NLP_onnx('int8-nlp-model.onnx', self.nlp_input)
+        q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
+        check_NLP_onnx('int8-nlp-qlinear-model.onnx', self.nlp_input)
 
 if __name__ == "__main__":
     unittest.main()
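
The renamed ONNX files reflect that each test now exports the same quantized model once per quantization format instead of overwriting a single file. A condensed sketch of that pattern, assuming Torch2ONNXConfig is importable from neural_compressor.config and q_model is the quantized model returned by a prior neural_compressor fit; the file names mirror the test, and any quant_format other than 'QDQ' falls back to QOperator per the export code above:

import torch
from neural_compressor.config import Torch2ONNXConfig

# Export the same q_model twice, once per ONNX quantization format, mirroring
# the per-format file names introduced in the tests above.
for quant_format, onnx_path in [('QDQ', 'int8-cv-qdq-model.onnx'),
                                ('QLinear', 'int8-cv-qlinear-model.onnx')]:
    config = Torch2ONNXConfig(
        dtype="int8",
        opset_version=14,
        quant_format=quant_format,
        example_inputs=torch.randn(1, 3, 224, 224),
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={"input": {0: "batch_size"},
                      "output": {0: "batch_size"}},
    )
    q_model.export(onnx_path, config)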
