
Commit bd70199

enhance export API and add 3 recipes
Signed-off-by: Xin He <[email protected]>
1 parent 09c64db commit bd70199

8 files changed: +1446 -65 lines changed


neural_compressor/adaptor/pytorch.py

Lines changed: 18 additions & 1 deletion
@@ -2738,6 +2738,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
             # q_func can be created by neural_compressor internal or passed by user. It's critical to
             # distinguish how q_func is passed since neural_compressor built-in functions accept
             # neural_compressor model and user defined func should accept framework model.
+            # For export API
+            hook_list = torch_utils.util._set_input_scale_hook(q_model._model, op_cfgs)
             q_model._model = q_func(
                 q_model if getattr(q_func, 'builtin', None) else q_model._model)
             assert q_model._model is not None, "Please return a trained model in train function!"
@@ -2766,6 +2768,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 prefix='',
                 example_inputs=example_inputs)
             if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
+                # For export API
+                hook_list = torch_utils.util._set_input_scale_hook(q_model._model, op_cfgs)
                 iterations = tune_cfg.get('calib_iteration', 1)
                 if q_func is not None:
                     q_func(q_model._model)
@@ -2774,6 +2778,11 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                         dataloader,
                         iterations,
                         calib_sampling_size=tune_cfg.get('calib_sampling_size', 1))
+
+        if self.approach != 'post_training_dynamic_quant':
+            # For export API
+            scale_info = torch_utils.util._get_input_scale(q_model._model, hook_list)
+
         if self.sub_module_list is None:
             if self.version > Version("1.12.1"):  # pragma: no cover
                 # pylint: disable=E1123
@@ -2796,6 +2805,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
         q_model.q_config = copy.deepcopy(self.tune_cfg)
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(q_model._model, q_model.q_config)
+            q_model.q_config['scale_info'] = scale_info

         self._dump_model_op_stats(q_model._model, q_model.q_config, self.approach)
         torch_utils.util.get_embedding_contiguous(q_model._model)
@@ -2873,10 +2883,11 @@ def _pre_hook_for_qat(self, dataloader=None):
                 quantized_ops[op[0]] = torch.quantization.default_dynamic_qconfig
             else:
                 quantized_ops[op[0]] = q_cfgs
-        # build for fetching scale and zeropoint
+        # build op_config_dict to save module scale and zeropoint
        op_config_dict = {}
        for op in quantizable_ops:
            op_config_dict[op] = {'weight': {'dtype': 'int8'}, 'activation': {'dtype': 'uint8'}}
+
        if self.version.release < Version("1.11.0").release:
            quantized_ops["default_qconfig"] = None
        else:
@@ -2928,8 +2939,14 @@ def _pre_hook_for_qat(self, dataloader=None):
             'sub_module_list': self.sub_module_list,
             'approach': 'quant_aware_training'
         }
+        # For export API
+        global hook_list
+        hook_list = torch_utils.util._set_input_scale_hook(self.model._model, quantized_ops)

     def _post_hook_for_qat(self):
+        # For export API
+        scale_info = torch_utils.util._get_input_scale(self.model._model, hook_list)
+        self.model.q_config['scale_info'] = scale_info
         from torch.quantization.quantize_fx import convert_fx
         if self.sub_module_list is None:
             if self.version > Version("1.12.1"):  # pragma: no cover
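Taken together, the three touch points in this file follow one pattern: attach input/output observers before calibration (or QAT training) runs forward passes, then harvest the observed qparams and stash them in q_config for the export API. A minimal sketch of that flow — the calib_dataloader loop and the bare op_cfgs/q_config names are illustrative placeholders, not code from this commit:

    from neural_compressor.adaptor.torch_utils import util

    hook_list = util._set_input_scale_hook(model, op_cfgs)   # attach observers to Conv/Linear ops
    for inputs, _ in calib_dataloader:                       # any forward passes feed the observers
        model(inputs)
    scale_info = util._get_input_scale(model, hook_list)     # read qparams, remove the hooks
    q_config['scale_info'] = scale_info                      # consumed later when exporting to ONNX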

neural_compressor/adaptor/torch_utils/util.py

Lines changed: 119 additions & 0 deletions
@@ -44,6 +44,125 @@ def contiguous_hook(module, input):
            child.register_forward_pre_hook(contiguous_hook)


+def is_fused_module(module):
+    """This is a helper function for `_propagate_qconfig_helper` to detect
+    if this module is fused.
+
+    Args:
+        module (object): input module
+
+    Returns:
+        (bool): is fused or not
+    """
+    op_type = str(type(module))
+    if 'fused' in op_type:
+        return True
+    else:
+        return False
+
+
+def _set_input_scale_hook(model, op_cfgs):
+    """Insert hooks to observe the input scale and zero point of each module.
+
+    Args:
+        model (object): input model
+        op_cfgs (dict): dictionary of quantization configuration for each op
+
+    Returns:
+        hook_list (list): input observer hooks
+    """
+    def input_scale_hook(module, input):
+        module.input_observer = module.qconfig.activation()
+        module.input_observer(input[0])
+        return input
+
+    def output_scale_hook(module, input, output):
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return output
+
+    def ConvReLU2d_scale_hook(module, input):
+        module.input_observer = module.qconfig.activation()
+        module.input_observer(input[0])
+        output = module._conv_forward(input[0], module.weight_fake_quant(module.weight), module.bias)
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return input
+
+    def LinearReLU_scale_hook(module, input):
+        import torch.nn.functional as F
+        module.input_observer = module.qconfig.activation()
+        module.input_observer(input[0])
+        output = F.linear(input[0], module.weight_fake_quant(module.weight), module.bias)
+        module.output_observer = module.qconfig.activation()
+        module.output_observer(output)
+        return input
+
+    hook_list = []
+    for name, module in model.named_modules():
+        if 'Conv' in str(module.__class__.__name__) or \
+                'Linear' in str(module.__class__.__name__):
+            if not hasattr(module, 'qconfig') or not module.qconfig:
+                continue
+            from torch.nn.intrinsic.qat import ConvBn2d, ConvReLU2d, ConvBnReLU2d, LinearReLU
+            if type(module) in [ConvBn2d, ConvBnReLU2d]:
+                handle_in = module.register_forward_pre_hook(input_scale_hook)
+                # module[0] == torch.nn.BatchNorm2d
+                module[0].qconfig = module.qconfig
+                handle_out = module[0].register_forward_hook(output_scale_hook)
+                hook_list.extend([handle_in, handle_out])
+            elif type(module) in [ConvReLU2d]:
+                handle_in_out = module.register_forward_pre_hook(ConvReLU2d_scale_hook)
+                hook_list.extend([handle_in_out])
+            elif type(module) in [LinearReLU]:
+                handle_in_out = module.register_forward_pre_hook(LinearReLU_scale_hook)
+                hook_list.extend([handle_in_out])
+            else:
+                if is_fused_module(module):
+                    continue
+                handle_in = module.register_forward_pre_hook(input_scale_hook)
+                handle_out = module.register_forward_hook(output_scale_hook)
+                hook_list.extend([handle_in, handle_out])
+    return hook_list
+
+
+def _get_input_scale(model, hook_list):
+    """Fetch input scale and zero point from the observers, then remove the hooks.
+
+    Args:
+        model (object): input model
+        hook_list (list): input observer hooks
+
+    Returns:
+        scale_info (dict): input scale and zero point of each module
+    """
+    scale_info = {}
+    for name, module in model.named_modules():
+        from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
+        if type(module) in [ConvBn2d, ConvBnReLU2d]:
+            if hasattr(module, "input_observer") and hasattr(module[0], "output_observer"):
+                scale_in, zero_point_in = module.input_observer.calculate_qparams()
+                scale_out, zero_point_out = module[0].output_observer.calculate_qparams()
+                scale_info[name] = {
+                    'input_scale': float(scale_in),
+                    'input_zeropoint': int(zero_point_in),
+                    'output_scale': float(scale_out),
+                    'output_zeropoint': int(zero_point_out)
+                }
+        elif hasattr(module, "input_observer") and hasattr(module, "output_observer"):
+            scale_in, zero_point_in = module.input_observer.calculate_qparams()
+            scale_out, zero_point_out = module.output_observer.calculate_qparams()
+            scale_info[name] = {
+                'input_scale': float(scale_in),
+                'input_zeropoint': int(zero_point_in),
+                'output_scale': float(scale_out),
+                'output_zeropoint': int(zero_point_out)
+            }
+    for h in hook_list:
+        h.remove()
+    return scale_info
+
+
 def collate_torch_preds(results):
     batch = results[0]
     if isinstance(batch, list):
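The helpers above lean on a small PyTorch idiom: instantiate the qconfig's activation observer inside a forward hook, feed it the live tensor, and later call calculate_qparams(). A self-contained toy version of that mechanism, using only stock torch.quantization observers (standalone illustration, not code from this commit):

    import torch
    from torch.quantization import MinMaxObserver, QConfig, default_weight_observer

    lin = torch.nn.Linear(4, 2)
    lin.qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8),
                          weight=default_weight_observer)

    def input_scale_hook(module, input):
        # same shape as the hook used in _set_input_scale_hook above
        module.input_observer = module.qconfig.activation()
        module.input_observer(input[0])
        return input

    handle = lin.register_forward_pre_hook(input_scale_hook)
    lin(torch.randn(8, 4))          # one forward pass records the input min/max
    scale, zero_point = lin.input_observer.calculate_qparams()
    handle.remove()
    print(float(scale), int(zero_point))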

neural_compressor/experimental/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,8 @@
 from .model_conversion import ModelConversion
 from .distillation import Distillation
 from .nas import NAS
+from . import export

 __all__ = ['Component', 'Quantization', 'Pruning', 'Benchmark', 'Graph_Optimization', \
-           'GraphOptimization', 'ModelConversion', 'Distillation', 'NAS', 'MixedPrecision']
+           'GraphOptimization', 'ModelConversion', 'Distillation', 'NAS', 'MixedPrecision', \
+           'export']
neural_compressor/experimental/export/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Intel Neural Compressor Export."""
+
+from .torch2onnx import torch_to_fp32_onnx, torch_to_int8_onnx
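The new package surfaces two entry points. torch2onnx.py itself is not shown in this commit page, so the argument names below are assumptions for illustration only; the import line is the one the package actually exposes:

    from neural_compressor.experimental.export import torch_to_fp32_onnx, torch_to_int8_onnx

    # Hypothetical call shapes -- check torch2onnx.py for the real signatures.
    torch_to_fp32_onnx(fp32_model, 'model_fp32.onnx', example_inputs=example_inputs)
    torch_to_int8_onnx(int8_model, 'model_int8.onnx', example_inputs=example_inputs,
                       q_config=q_config)  # q_config carries the scale_info recorded above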
