
Commit 6a39f64

xin3he and wenhuach21 authored
update SmoothQuant algorithm with folding choice (#799)
Signed-off-by: Xin He <[email protected]> Co-authored-by: wenhuach21 <[email protected]>
1 parent 7a7cfe5 commit 6a39f64

File tree

9 files changed: +745 -176 lines changed


examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/ipex/smooth_quant/eval_lambada.py

Lines changed: 3 additions & 3 deletions
@@ -42,10 +42,10 @@ def evaluate(self, model):
             total += label.size(0)
             hit += (pred == label).sum().item()
             if index % args.log_frequency == 0:
-                print(hit / total)
+                print(hit / total, flush=True)
             index += 1
         acc = hit / total
-        print(acc)
+        print(acc, flush=True)
         return acc

@@ -145,7 +145,7 @@ def eval_func(model):
    recipes = {}
    if args.sq:
        recipes = {"smooth_quant": True, "smooth_quant_args": {'alpha': args.alpha}}
-    op_type_dict = None
+    op_type_dict = {}
    if args.kl:
        op_type_dict = {'linear': {'activation': {'algorithm': ['kl']}}}
    if args.fallback_add:
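For context, a rough sketch (not part of this diff) of how recipes and op_type_dict are typically handed to the quantizer in an example like this, assuming the 2.x PostTrainingQuantConfig / quantization.fit API and that model, calib_dataloader, and eval_func are defined elsewhere in the script:

from neural_compressor import PostTrainingQuantConfig, quantization

# Hypothetical wiring: recipes / op_type_dict come from the snippet above;
# model, calib_dataloader and eval_func are assumed to exist in the script.
conf = PostTrainingQuantConfig(
    backend="ipex",             # this example lives under the IPEX smooth_quant folder
    recipes=recipes,            # {"smooth_quant": True, "smooth_quant_args": {"alpha": ...}}
    op_type_dict=op_type_dict,  # {} by default, or the KL override shown above
)
q_model = quantization.fit(model, conf,
                           calib_dataloader=calib_dataloader,
                           eval_func=eval_func)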

neural_compressor/adaptor/onnxrt.py

Lines changed: 2 additions & 1 deletion
@@ -152,7 +152,7 @@ def __init__(self, framework_specific_info):

         self.optype_statistics = None

-    def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5,
+    def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, folding=False,
                      percentile=99.999, op_types=['MatMul', 'Linear', 'Conv'], scales_per_op=True):
         """Get augmented model with smooth quant.

@@ -162,6 +162,7 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5,
             iterations: iterations
             tune_cfg: quantization config
             alpha: smooth alpha in SmoothQuant, 1.0 will fall back to SPIQ
+            folding: whether to insert a mul op (False) or only smooth foldable layers (True) for SmoothQuant
             percentile: percentile of calibration used to remove outliers
             op_types: the op types whose input tensor will be dumped
             scales_per_op: if True, each op gets an individual scale, mainly for accuracy
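To make the new folding flag concrete, here is a toy sketch of the two modes' assumed semantics (illustrative only, not code from this commit): with folding=True the per-channel smoothing multiplier is absorbed into a preceding foldable layer's weights, while with folding=False an explicit mul is inserted in front of the op, as the SQLinearWrapper added below does.

import torch

# Chain: x -> prev (Linear 4->4) -> cur (Linear 4->2); we smooth cur's input with a
# per-channel multiplier input_scale (the 1/s of the SmoothQuant formulation).
prev = torch.nn.Linear(4, 4)
cur = torch.nn.Linear(4, 2, bias=False)
input_scale = torch.rand(4) + 0.5

with torch.no_grad():
    # Both modes rescale cur's input channels by 1/input_scale so that
    # cur_s(x * input_scale) == cur(x).
    cur_s = torch.nn.Linear(4, 2, bias=False)
    cur_s.weight.copy_(cur.weight / input_scale.view(1, -1))

    # folding=True: no extra op; absorb the multiplier into prev's output channels.
    prev_folded = torch.nn.Linear(4, 4)
    prev_folded.weight.copy_(prev.weight * input_scale.view(-1, 1))
    prev_folded.bias.copy_(prev.bias * input_scale)

def run_folding_true(x):
    return cur_s(prev_folded(x))         # scales live inside prev_folded's weights

def run_folding_false(x):
    return cur_s(prev(x) * input_scale)  # explicit mul, as SQLinearWrapper inserts

x = torch.randn(3, 4)
assert torch.allclose(run_folding_true(x), run_folding_false(x), atol=1e-5)
assert torch.allclose(run_folding_false(x), cur(prev(x)), atol=1e-5)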

neural_compressor/adaptor/pytorch.py

Lines changed: 236 additions & 36 deletions
Large diffs are not rendered by default.

(new file)

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Torch.nn.Module Class Definition."""
+# Note: Do not import this file unless you have already imported torch,
+# since the model classes inherit torch.nn.Module.
+import torch
+from packaging.version import Version
+
+
+def get_torch_version():
+    try:
+        torch_version = torch.__version__.split('+')[0]
+    except ValueError as e:  # pragma: no cover
+        assert False, 'Got an unknown version of torch: {}'.format(e)
+    version = Version(torch_version)
+    return version
+
+PT_VERSION = get_torch_version().release
+
+
+class QDQLinear(torch.nn.Module):
+    def __init__(self, module, scale, zero_point, dtype):
+        super().__init__()
+        if PT_VERSION < Version("1.13.0").release:
+            import torch.nn.quantized as nnq
+        else:
+            import torch.ao.nn.quantized as nnq
+        self.add_module('quant', nnq.Quantize(scale, zero_point, dtype))
+        self.add_module('dequant', nnq.DeQuantize())
+        self.add_module('module', module)
+        self.qdq_weight()
+
+    def forward(self, X):
+        X = self.quant(X)
+        X = self.dequant(X)
+        X = self.module(X)
+        return X
+
+    def qdq_weight(self):
+        # update weight with QDQ (quantize then dequantize)
+        from .smooth_quant import quant_dequant_w
+        weight_qdq = quant_dequant_w(self.module)
+        self.module.weight = torch.nn.Parameter(weight_qdq)
+
+
+class SQLinearWrapper(torch.nn.Module):
+    def __init__(self, module, input_scale, input_minmax, dtype=torch.quint8):
+        super().__init__()
+        self.input_scale = input_scale
+        self.dtype = dtype
+        # calculate and only save scale, zero_point to reduce memory usage
+        self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype)
+        self.add_module('sq_linear', module)
+        self.ipex = False  # a flag used for ipex inference
+
+    def forward(self, X):
+        if self.ipex:
+            X = self.sq_linear(X)
+        else:
+            X = torch.mul(X, self.input_scale)
+            X = self.sq_linear(X)
+        return X
+
+    def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8):
+        # calculate scale and zero_point
+        if dtype == torch.quint8:
+            quant_min, quant_max = 0, 255
+        min_val = torch.min(input_minmax[0] * input_scale)
+        max_val = torch.max(input_minmax[1] * input_scale)
+        # works even when min_val is bigger than zero
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps]))
+        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        return scale, zero_point
+
+    def _get_weight_scale(self):
+        # get weight scale via a per-channel observer
+        from torch.ao.quantization.observer import default_per_channel_weight_observer
+        obs = default_per_channel_weight_observer()
+        obs(self.sq_linear.weight)
+        scale, _ = obs.calculate_qparams()
+        return scale
+
+    def _recover_sq_linear(self):
+        # remove the explicit mul and fold input_scale into sq_linear for ipex inference
+        scale = self.input_scale.view(1, self.input_scale.shape[0])
+        with torch.no_grad():
+            self.sq_linear.weight *= scale
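
For orientation, a small usage sketch of the SQLinearWrapper added above. The toy Linear, random input_scale, and hand-picked input_minmax are illustrative assumptions; in the real SmoothQuant flow the wrapped Linear's weight has already been smoothed before wrapping.

import torch
# Assumes SQLinearWrapper from the new module above is importable in the current scope.

lin = torch.nn.Linear(4, 2, bias=False)
input_scale = torch.rand(4) + 0.5                        # per-input-channel multiplier
input_minmax = (torch.tensor(-1.0), torch.tensor(1.0))   # calibration min/max of the raw input

sq = SQLinearWrapper(lin, input_scale, input_minmax)     # also precomputes scale / zero_point

x = torch.randn(3, 4)
y = sq(x)                    # mul by input_scale, then the wrapped Linear

# Before IPEX inference the explicit mul can be removed and folded into the weight:
sq._recover_sq_linear()
sq.ipex = True
y_ipex = sq(x)               # plain Linear call, now with weight scaled by input_scale
assert torch.allclose(y, y_ipex, atol=1e-5)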
