Commit 0349b9a

fix bug in TorchSmoothQuant (#1149)

* [bug fix] Fix the case where folding=False and QKV is not fully converted to SQLinear.

Signed-off-by: Xin He <[email protected]>

1 parent aa4770d commit 0349b9a

File tree

3 files changed: +56 -33 lines

  neural_compressor/adaptor/pytorch.py
  neural_compressor/adaptor/torch_utils/smooth_quant.py
  test/algorithm/test_smooth_quant.py

neural_compressor/adaptor/pytorch.py

Lines changed: 6 additions & 4 deletions

@@ -2605,7 +2605,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
         if self.version.release >= Version("1.12.0").release:
             # Check save_qconf_summary part is a workaroud for IPEX bug.
             # Sometimes the prepared model from get_op_capablitiy loss this attribute
-            if not hasattr(model._model, "save_qconf_summary"):
+            if not hasattr(model._model, "save_qconf_summary") or \
+                    not hasattr(model._model, "load_qconf_summary"):
                 from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
                 if self.version.release >= Version("2.1").release:
                     static_qconfig = ipex.quantization.default_static_qconfig_mapping
@@ -2950,7 +2951,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             ipex_conf.save(self.ipex_config_path)
         else:
             if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
-                assert self.q_dataloader or self.example_inputs, \
+                assert self.q_dataloader is not None or self.example_inputs is not None, \
                     "IPEX need q_dataloader or example_inputs to prepare the model"
                 from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
                 if self.version.release >= Version("2.1").release:
@@ -2983,7 +2984,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                 model = ipex.quantization.prepare(model, static_qconfig,
                                                   example_inputs=self.example_inputs, inplace=True)

-            if self.q_dataloader or self.example_inputs:
+            if self.q_dataloader is not None or self.example_inputs is not None:
                 self._simple_inference(model, self.q_dataloader, iterations=1)
             else:
                 try:
@@ -3141,7 +3142,8 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):

         # Check save_qconf_summary part is a workaroud for IPEX bug.
         # Sometimes the prepared model from get_op_capablitiy loss this attribute
-        if not hasattr(model._model, "save_qconf_summary"):
+        if not hasattr(model._model, "save_qconf_summary") or \
+                not hasattr(model._model, "load_qconf_summary"):
             static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
             if isinstance(self.example_inputs, dict):
                 model._model = ipex.quantization.prepare(model._model, static_qconfig,
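
Note on the "is not None" changes above: relying on truthiness here is fragile because the
objects involved do not always have a useful boolean value. A DataLoader takes its truth
value from len(), so a loader over an empty dataset is falsy even though one was supplied,
and a multi-element tensor passed as example_inputs raises when evaluated as a bool. A small
illustration of the pitfall (standalone sketch, not code from this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

# A DataLoader over an empty dataset is falsy, although it is not None.
empty_loader = DataLoader(TensorDataset(torch.empty(0, 3)))
print(bool(empty_loader))          # False
print(empty_loader is not None)    # True

# A multi-element tensor used as example_inputs cannot be truth-tested at all.
example_inputs = torch.randn(1, 3)
try:
    bool(example_inputs)
except RuntimeError as err:
    print("truth test raises:", err)  # "Boolean value of Tensor ... is ambiguous"
print(example_inputs is not None)     # True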

neural_compressor/adaptor/torch_utils/smooth_quant.py

Lines changed: 35 additions & 27 deletions

@@ -501,44 +501,43 @@ def _reshape_scale_for_input(self, layer, scale):

         return scale

-    def _scale_layer_weight(self, layer_name, scale): ##input channel
+    def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel
         """
         Scale the layer weights at input channel, depthwise conv output channel
         :param layer_name: The layer name
         :param scale: The scale to be multiplied
+        :param alpha: alpha for SQLinearWrapper
+        :param input_minmax: input_minmax for SQLinearWrapper
         :return:
         """
         layer = get_module(self.model, layer_name)
-        if layer.__class__.__name__ == "SQLinearWrapper":
-            return scale # weigth update is done in SQLinearWrapper initialization
-        scale = self._reshape_scale_for_weight(layer, scale)
-        layer.weight = torch.nn.Parameter(layer.weight * scale)
+        if self.insert_mul:
+            from .model_wrapper import SQLinearWrapper
+            layer = get_module(self.model, layer_name)
+            if isinstance(layer, SQLinearWrapper):
+                layer._recover_sq_linear()
+                set_module(self.model, layer_name, layer.sq_linear) ##recover
+            else:
+                new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha)
+                set_module(self.model, layer_name, new_module)
+        elif self.allow_absorb:
+            scale = self._reshape_scale_for_weight(layer, scale)
+            layer.weight = torch.nn.Parameter(layer.weight * scale)
         return scale

-    def _absorb_scales(self, layer_name, scale, alpha=0.5): ##output channel
+    def _absorb_scales(self, layer_name, scale): ##output channel
         """
         Absorb the scale to the layer at output channel
         :param layer_name: The module name
         :param scale: The scale to be absorbed
         :param alpha_key: The alpha passed to SQLinearWrapper
         :return:
         """
-        layer = get_module(self.model, layer_name)
-        if self.insert_mul:
-            if layer.__class__.__name__ == "SQLinearWrapper":
-                layer._recover_sq_linear()
-                set_module(self.model, layer_name, layer.sq_linear) ##recover
-            else:
-                from .model_wrapper import SQLinearWrapper
-                input_minmax = [self.input_mins[layer_name], self.input_maxes[layer_name]]
-                new_module = SQLinearWrapper(layer, scale, input_minmax, alpha)
-                set_module(self.model, layer_name, new_module)
-            return
-
-        if not self.allow_absorb:
-            return ## change the code style due to too many if/else statements in the following
+        if self.insert_mul or not self.allow_absorb:
+            return # absorb is updated in SQLinearWrapper in def _scale_layer_weight

         ##if self.allow absorb
+        layer = get_module(self.model, layer_name)
         if layer.__class__.__name__ == 'WrapperLayer':
             layer = layer.orig_layer
         if isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.GroupNorm) or \
@@ -650,7 +649,9 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
         :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict
         :return:
         """
-        absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha, tuning)
+        absorb_scales_info, weight_scales_info = self._cal_scales(
+            absorb_to_layer, input_maxes, alpha, tuning
+        )
         if not absorb_scales_info or not weight_scales_info:
             return weight_scales_info, absorb_scales_info
         for index, key in enumerate(absorb_to_layer.keys()):
@@ -659,10 +660,13 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
             elif isinstance(alpha, dict):
                 alpha_tmp = alpha[key]
             absorb_scale = absorb_scales_info[key]
-            self._absorb_scales(key, absorb_scale, alpha_tmp)
+            self._absorb_scales(key, absorb_scale)
             layer_names = absorb_to_layer[key]
             for layer_name in layer_names:
-                self._scale_layer_weight(layer_name, weight_scales_info[layer_name])
+                input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]]
+                self._scale_layer_weight(
+                    layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax
+                )
         return weight_scales_info, absorb_scales_info

     def _check_need_calibration(self, alpha, percentile, op_types,
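
The hunks above move the SQLinearWrapper insertion out of _absorb_scales and into
_scale_layer_weight, which _adjust_parameters now calls for every layer in
absorb_to_layer[key] with the shared input_minmax and alpha. That is what lets every layer
fed by one absorb parent (for example the Q, K and V projections) be wrapped when
folding=False. The sketch below is a simplified, hypothetical stand-in for the "insert mul"
idea (it is not the library's SQLinearWrapper), showing why dividing the activation by a
per-channel scale and multiplying the weight's input channels by the same scale leaves the
layer's output unchanged:

import torch

class MulInsertLinear(torch.nn.Module):
    """Hypothetical, simplified stand-in for SQLinearWrapper (illustration only)."""

    def __init__(self, linear: torch.nn.Linear, scale: torch.Tensor):
        super().__init__()
        self.linear = linear
        # Keep 1/scale to rescale the activation, mirroring SQLinearWrapper(layer, 1.0 / scale, ...).
        self.register_buffer("input_scale", 1.0 / scale)
        with torch.no_grad():
            # Compensate in the weight: scale each input channel by `scale`.
            self.linear.weight.mul_(scale)

    def forward(self, x):
        return self.linear(x * self.input_scale)

# Equivalence check: wrapping does not change the fp32 output.
lin = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
ref = lin(x)

act_max = x.abs().amax(dim=0)         # per-input-channel activation max
w_max = lin.weight.abs().amax(dim=0)  # per-input-channel weight max
scale = (act_max.clamp(min=1e-5) ** 0.5) / (w_max.clamp(min=1e-5) ** 0.5)  # alpha = 0.5

wrapped = MulInsertLinear(lin, scale)
assert torch.allclose(ref, wrapped(x), atol=1e-5)
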
@@ -1110,10 +1114,14 @@ def _get_example_input(self):
         if self.dataloader == None and self.example_inputs == None:
             return None
         if self.example_inputs is None:
-            ##assert self.dataloader, "Please provide dataloader or example_inputs"
-            for idx, input in enumerate(self.dataloader):
-                self.example_inputs = input
-                break
+            try:
+                for idx, (input, label) in enumerate(self.dataloader):
+                    self.example_inputs = input
+                    break
+            except:
+                for idx, input in enumerate(self.dataloader):
+                    self.example_inputs = input
+                    break

         return self.example_inputs
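
The try/except added to _get_example_input above exists because calibration dataloaders come
in two common shapes: some yield (input, label) pairs, others yield the input alone. A small
illustration of the two cases (hypothetical loaders and helper, not code from this commit):

import torch

# Case 1: a loader yielding (input, label) pairs; the first branch unpacks it and keeps
# only the input.
pair_loader = [(torch.randn(1, 3), torch.tensor(0)) for _ in range(2)]

# Case 2: a loader yielding the input alone (e.g. token-id batches for an LLM); the
# unpacking raises, and the fallback branch keeps the batch as-is.
plain_loader = [torch.randint(0, 100, (1, 16)) for _ in range(2)]

def first_input(dataloader):
    """Mimics the commit's fallback logic for grabbing one example input."""
    try:
        for idx, (input, label) in enumerate(dataloader):
            return input
    except (ValueError, TypeError):
        for idx, input in enumerate(dataloader):
            return input

print(first_input(pair_loader).shape)   # torch.Size([1, 3])
print(first_input(plain_loader).shape)  # torch.Size([1, 16])

The commit's version catches everything with a bare except; the sketch narrows it to the
unpack errors it expects, which is the safer pattern when the fallback should not hide
unrelated failures.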

test/algorithm/test_smooth_quant.py

Lines changed: 15 additions & 2 deletions

@@ -273,19 +273,22 @@ def __init__(self):
                 self.norm = torch.nn.GroupNorm(num_channels=4, num_groups=2)
                 self.act = torch.nn.ReLU()
                 self.conv2 = torch.nn.Conv2d(4, 3, 1, 1)
+                self.conv3 = torch.nn.Conv2d(4, 3, 1, 1)

             def forward(self, x):
                 out = self.conv1(x)
                 out = self.norm(out)
                 out = self.act(out)
-                out = self.conv2(out)
+                tmp1 = self.conv2(out)
+                tmp2 = self.conv3(out)
+                out = tmp1 + tmp2
                 return out

         model = Model()

         sq = TorchSmoothQuant(model, self.conv_dl)
         sq.transform(alpha=0.6, calib_iter=2, folding=True)
-        assert len(sq.absorb_to_layer) == 1
+        assert len(sq.absorb_to_layer['norm']) == 2

     def test_sq_add(self):
         class Model(torch.nn.Module):
@@ -626,6 +629,15 @@ def forward(self, x):
         sq.transform(alpha=0.5, calib_iter=1) # By default, folding=False
         assert isinstance(sq.model.fc1, SQLinearWrapper)

+    def test_sq_qkv(self):
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            'facebook/opt-125m', torchscript=True,)
+        sq = TorchSmoothQuant(model, LLMCalibDataloader())
+        sq.transform(alpha=0.5, calib_iter=-1, folding=False)
+        assert isinstance(
+            sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper
+        )
+
     def test_sq_quant(self):
         from neural_compressor import PostTrainingQuantConfig, quantization
         class Model(torch.nn.Module):
@@ -734,6 +746,7 @@ def calib_func(model):
             calib_func=calib_func,
         )

+        fp32_model = Model()
         conf = PostTrainingQuantConfig(
             backend="ipex",
             calibration_sampling_size=8,