
Commit 255eb0c

enable scale mapping for fused module

Signed-off-by: Xin He <[email protected]>
1 parent eab46f7 commit 255eb0c

4 files changed: +54 −54 lines changed

neural_compressor/adaptor/pytorch.py

Lines changed: 5 additions & 6 deletions

@@ -2701,6 +2701,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                              repr(e)))
             q_model = model
         q_model._model.eval()
+        hook_list = torch_utils.util._set_input_scale_hook(q_model._model, op_cfgs)
         if q_model.kwargs is not None:
             self.prepare_custom_config_dict = q_model.kwargs.get('prepare_custom_config_dict',
                                                                  None)
@@ -2738,7 +2739,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
             # q_func can be created by neural_compressor internal or passed by user. It's critical to
             # distinguish how q_func is passed since neural_compressor built-in functions accept
             # neural_compressor model and user defined func should accept framework model.
-            hook_list = torch_utils.util._set_input_scale_hook(q_model._model, op_cfgs)
             q_model._model = q_func(
                 q_model if getattr(q_func, 'builtin', None) else q_model._model)
             assert q_model._model is not None, "Please return a trained model in train function!"
@@ -2767,7 +2767,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 prefix='',
                 example_inputs=example_inputs)
         if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
-            hook_list = torch_utils.util._set_input_scale_hook(q_model._model, op_cfgs)
             iterations = tune_cfg.get('calib_iteration', 1)
             if q_func is not None:
                 q_func(q_model._model)
@@ -2778,7 +2777,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                     calib_sampling_size=tune_cfg.get('calib_sampling_size', 1))

         if self.approach != 'post_training_dynamic_quant':
-            input_scale_info = torch_utils.util._get_input_scale(q_model._model, hook_list)
+            scale_info = torch_utils.util._get_input_scale(q_model._model, hook_list)

         if self.sub_module_list is None:
             if self.version > Version("1.12.1"):  # pragma: no cover
@@ -2802,7 +2801,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
         q_model.q_config = copy.deepcopy(self.tune_cfg)
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(q_model._model, q_model.q_config)
-            q_model.q_config['input_scale_info'] = input_scale_info
+            q_model.q_config['scale_info'] = scale_info

         self._dump_model_op_stats(q_model._model, q_model.q_config, self.approach)
         torch_utils.util.get_embedding_contiguous(q_model._model)
@@ -2940,8 +2939,8 @@ def _pre_hook_for_qat(self, dataloader=None):
         hook_list = torch_utils.util._set_input_scale_hook(self.model._model, quantized_ops)

     def _post_hook_for_qat(self):
-        input_scale_info = torch_utils.util._get_input_scale(self.model._model, hook_list)
-        self.model.q_config['input_scale_info'] = input_scale_info
+        scale_info = torch_utils.util._get_input_scale(self.model._model, hook_list)
+        self.model.q_config['scale_info'] = scale_info
         from torch.quantization.quantize_fx import convert_fx
         if self.sub_module_list is None:
             if self.version > Version("1.12.1"):  # pragma: no cover
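After these changes, calibration populates q_model.q_config['scale_info'] with one record per observed Conv/Linear module, holding both input and output qparams (see the util.py changes below). A minimal sketch of what a consumer of that dict sees; the module name and values here are hypothetical, not taken from the commit:

# Illustrative only: shape of the per-module record stored by quantize()
# under q_config['scale_info'] (module name and numbers are made up).
scale_info = {
    'backbone.conv1': {
        'input_scale': 0.0236,
        'input_zeropoint': 128,
        'output_scale': 0.0151,
        'output_zeropoint': 0,
    },
}

for name, record in scale_info.items():
    # e.g. an exporter can look up qparams by module name and attach them
    # to the matching ONNX node
    print(name, record['input_scale'], record['output_zeropoint'])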

neural_compressor/adaptor/torch_utils/util.py

Lines changed: 37 additions & 14 deletions

@@ -44,6 +44,23 @@ def contiguous_hook(module, input):
         child.register_forward_pre_hook(contiguous_hook)


+def is_fused_module(module):
+    """This is a helper function for `_propagate_qconfig_helper` to detect
+    if this module is fused.
+
+    Args:
+        module (object): input module
+
+    Returns:
+        (bool): is fused or not
+    """
+    op_type = str(type(module))
+    if 'fused' in op_type:
+        return True
+    else:
+        return False
+
+
 def _set_input_scale_hook(model, op_cfgs):
     """Insert hooks to observe input scale and zeropoint.

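As a quick standalone check of the heuristic just added (a sketch, not part of the commit): modules produced by torch.quantization.fuse_modules are instances of the torch.nn.intrinsic fused types, whose qualified class name contains 'fused', so the string test catches them while plain Conv/Linear modules pass through:

import torch
import torch.nn as nn

def is_fused_module(module):
    # same heuristic as the helper added above
    return 'fused' in str(type(module))

m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
fused = torch.quantization.fuse_modules(m, [['0', '1']])

print(type(fused[0]))                        # <class '...intrinsic.modules.fused.ConvReLU2d'>
print(is_fused_module(fused[0]))             # True
print(is_fused_module(nn.Conv2d(3, 8, 3)))   # False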
@@ -55,19 +72,24 @@ def _set_input_scale_hook(model, op_cfgs):
         hook_list (list): input observer hooks
     """
     def input_scale_hook(module, input):
-        module.input_observer = module.input_config.activation()
+        module.input_observer = module.qconfig.activation()
         module.input_observer(input[0])
         return input

+    def output_scale_hook(module, input, output):
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return output
+
     hook_list = []
     for name, module in model.named_modules():
         if 'Conv' in str(module.__class__.__name__) or \
           'Linear' in str(module.__class__.__name__):
-            if name not in op_cfgs or op_cfgs[name] is None:
+            if is_fused_module(module):
                 continue
-            module.input_config = op_cfgs[name]
-            handle = module.register_forward_pre_hook(input_scale_hook)
-            hook_list.append(handle)
+            handle_in = module.register_forward_pre_hook(input_scale_hook)
+            handle_out = module.register_forward_hook(output_scale_hook)
+            hook_list.extend([handle_in, handle_out])
     return hook_list


@@ -81,19 +103,20 @@ def _get_input_scale(model, hook_list):
     Returns:
         input_scale_info (dict): input scale and zero_point of each modules
     """
-    input_scale_info = {}
+    scale_info = {}
     for name, module in model.named_modules():
-        if hasattr(module, "input_observer"):
-            scale, zero_point = module.input_observer.calculate_qparams()
-            input_scale_info[name] = {
-                'scale': float(scale),
-                'zero_point': int(zero_point)
+        if hasattr(module, "input_observer") and hasattr(module, "output_observer"):
+            scale_in, zero_point_in = module.input_observer.calculate_qparams()
+            scale_out, zero_point_out = module.output_observer.calculate_qparams()
+            scale_info[name] = {
+                'input_scale': float(scale_in),
+                'input_zeropoint': int(zero_point_in),
+                'output_scale': float(scale_out),
+                'output_zeropoint': int(zero_point_out)
             }
-            if hasattr(module, "input_config"):
-                del module.input_config
     for h in hook_list:
         h.remove()
-    return input_scale_info
+    return scale_info


 def collate_torch_preds(results):
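The hook pair above relies on each eligible module carrying a qconfig whose activation field is an observer factory. A self-contained sketch of the same pattern on a bare nn.Linear; the fbgemm default qconfig and the random calibration tensor are assumptions made for illustration, not part of the commit:

import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig

module = nn.Linear(4, 4)
module.qconfig = get_default_qconfig('fbgemm')

def input_scale_hook(mod, inp):
    # a fresh observer per forward call, as in _set_input_scale_hook above
    mod.input_observer = mod.qconfig.activation()
    mod.input_observer(inp[0])
    return inp

def output_scale_hook(mod, inp, out):
    mod.output_observer = mod.qconfig.activation()
    mod.output_observer(out)
    return out

h_in = module.register_forward_pre_hook(input_scale_hook)
h_out = module.register_forward_hook(output_scale_hook)

module(torch.randn(8, 4))   # one calibration pass
scale_in, zp_in = module.input_observer.calculate_qparams()
scale_out, zp_out = module.output_observer.calculate_qparams()
print(float(scale_in), int(zp_in), float(scale_out), int(zp_out))

h_in.remove()
h_out.remove()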

neural_compressor/experimental/export/torch2onnx.py

Lines changed: 11 additions & 33 deletions

@@ -195,35 +195,6 @@ def get_quantizable_onnx_ops(
     return quantize_nodes


-def get_scale_info(
-    int8_model,
-    q_config,
-):
-    """Fetch scale information from q_config.
-
-    Args:
-        int8_model (torch.nn.Module): PyTorch int8 model.
-        q_config (dict): quantization configuration.
-
-    Returns:
-        int8_scale_info: int8 scale infomation.
-    """
-    # get output scale and zp from module
-    int8_scale_info = {}
-    for name, scale_info in q_config['input_scale_info'].items():
-        int8_scale_info[name] = {
-            'input_scale': scale_info['scale'],
-            'input_zeropoint': scale_info['zero_point'],
-        }
-    for name, module in int8_model.named_modules():
-        if name in int8_scale_info:
-            int8_scale_info[name].update({
-                'output_scale': module.scale,
-                'output_zeropoint': module.zero_point,
-            })
-    return int8_scale_info
-
-
 def build_scale_mapping(
     fp32_onnx_path,
     module_node_mapping,
@@ -242,14 +213,21 @@
     node_module_mapping = {}
     for module_name, node_name in module_node_mapping.items():
         node_module_mapping[node_name] = module_name
-    # match scale and zeropoint from PyTorch to ONNX node
+    # Match scale and zeropoint from PyTorch to ONNX node
     scale_zp_dict = {}
     fp32_onnx_model = onnx.load(fp32_onnx_path)
     for node in fp32_onnx_model.graph.node:
         if node.name in node_module_mapping:
             module_name = node_module_mapping[node.name]
-            if module_name not in int8_scale_info:
+
+            # For fine-grained fx and fuse pattern
+            if module_name + '.module' in int8_scale_info:
                 module_name = module_name + '.module'
+            elif module_name + '.0' in int8_scale_info:
+                module_name = module_name + '.0'
+            elif module_name + '.module.0' in int8_scale_info:
+                module_name = module_name + '.module.0'
+
             if module_name in int8_scale_info:
                 recoder = int8_scale_info[module_name]
                 input_scale_args = node.input[0] + '_scale'
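The suffix fallbacks are needed because fx graph-mode quantization and module fusing rename things: the module an ONNX node maps back to as foo may appear in scale_info as foo.module (wrapped), foo.0 (first element of a fused block), or foo.module.0. A small sketch of the same lookup order against a hypothetical scale_info; all names below are made up for illustration:

# Hypothetical post-fusion module names, as _get_input_scale() would record them.
int8_scale_info = {
    'features.conv1.module': {'input_scale': 0.02, 'input_zeropoint': 0},
    'classifier.fc.0': {'input_scale': 0.05, 'input_zeropoint': 128},
}

def resolve(module_name, scale_info):
    # Try the wrapped / fused variants first, then fall back to the raw name,
    # mirroring the lookup in build_scale_mapping.
    for candidate in (module_name + '.module',
                      module_name + '.0',
                      module_name + '.module.0',
                      module_name):
        if candidate in scale_info:
            return candidate
    return None

print(resolve('features.conv1', int8_scale_info))   # features.conv1.module
print(resolve('classifier.fc', int8_scale_info))    # classifier.fc.0
print(resolve('missing.layer', int8_scale_info))    # None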
@@ -447,7 +425,7 @@ def qdq_model_use_output_scale_zp(
 def qop_model_default(
     int8_onnx_model
 ):
-    # nn.quantized.Lienar module will be converted to the following format:
+    # nn.quantized.Linear module will be converted to the following format:
     #   QuantizeLinear
     #        |
     #   MatMulIntegerToFloat
@@ -696,7 +674,7 @@ def torch_to_int8_onnx(
     if q_config['approach'] == 'quant_aware_training':
         update_weight_bias(int8_model, fp32_onnx_path)
     if q_config['approach'] != 'post_training_dynamic_quant':
-        int8_scale_info = get_scale_info(int8_model, q_config)
+        int8_scale_info = q_config['scale_info']
         scale_mapping = build_scale_mapping(fp32_onnx_path, module_node_mapping, int8_scale_info)

     quant_format = ortq.QuantFormat.QOperator if quant_format != 'QDQ' else ortq.QuantFormat.QDQ

neural_compressor/experimental/export/utils.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ def __init__(self, fp32_onnx_path):
             for dim in node.shape:
                 shape.append(dim if isinstance(dim, int) else 1)
             dtype = ONNX2Numpy_dtype(node.type)
-            input[node.name] = np.ones(shape).astype(dtype)
+            input[node.name] = np.zeros(shape).astype(dtype)
         self.data = [input]
         self.data = iter(self.data)
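For context, this reader feeds a single dummy sample per graph input to the FP32 ONNX model; the change swaps all-ones for all-zeros as the fill value. A minimal sketch of the same idea written against onnxruntime's CalibrationDataReader interface; the float32 dtype and shape handling here are simplifying assumptions, not the class defined in this file:

import numpy as np
import onnx
from onnxruntime.quantization import CalibrationDataReader

class ZeroDataReader(CalibrationDataReader):
    """Feeds one all-zero sample per graph input (illustrative only)."""
    def __init__(self, fp32_onnx_path):
        model = onnx.load(fp32_onnx_path)
        feed = {}
        for inp in model.graph.input:
            # unknown / symbolic dims fall back to 1, as in the original reader
            shape = [d.dim_value if d.dim_value > 0 else 1
                     for d in inp.type.tensor_type.shape.dim]
            feed[inp.name] = np.zeros(shape, dtype=np.float32)
        self.data = iter([feed])

    def get_next(self):
        return next(self.data, None)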
