10 changes: 6 additions & 4 deletions neural_compressor/adaptor/pytorch.py
@@ -2605,7 +2605,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
if self.version.release >= Version("1.12.0").release:
# Checking save_qconf_summary is a workaround for an IPEX bug:
# the prepared model from get_op_capability sometimes loses these attributes
if not hasattr(model._model, "save_qconf_summary"):
if not hasattr(model._model, "save_qconf_summary") or \
not hasattr(model._model, "load_qconf_summary"):
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
if self.version.release >= Version("2.1").release:
static_qconfig = ipex.quantization.default_static_qconfig_mapping
@@ -2950,7 +2951,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
ipex_conf.save(self.ipex_config_path)
else:
if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
assert self.q_dataloader or self.example_inputs, \
assert self.q_dataloader is not None or self.example_inputs is not None, \
"IPEX need q_dataloader or example_inputs to prepare the model"
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
if self.version.release >= Version("2.1").release:
@@ -2983,7 +2984,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
model = ipex.quantization.prepare(model, static_qconfig,
example_inputs=self.example_inputs, inplace=True)

if self.q_dataloader or self.example_inputs:
if self.q_dataloader is not None or self.example_inputs is not None:
self._simple_inference(model, self.q_dataloader, iterations=1)
else:
try:
@@ -3141,7 +3142,8 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):

# Checking save_qconf_summary is a workaround for an IPEX bug:
# the prepared model from get_op_capability sometimes loses these attributes
if not hasattr(model._model, "save_qconf_summary"):
if not hasattr(model._model, "save_qconf_summary") or \
not hasattr(model._model, "load_qconf_summary"):
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if isinstance(self.example_inputs, dict):
model._model = ipex.quantization.prepare(model._model, static_qconfig,
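A note on the q_dataloader checks above: the switch from truthiness to "is not None" matters because DataLoader delegates bool() to __len__(), so a provided-but-empty loader is falsy, and a loader over an IterableDataset can raise TypeError on len(). A minimal sketch of the pitfall, using plain PyTorch and nothing INC-specific:

import torch
from torch.utils.data import DataLoader, TensorDataset

# A loader over an empty dataset is a valid object, but it is falsy.
empty_dl = DataLoader(TensorDataset(torch.empty(0, 3)))

assert empty_dl is not None  # the argument was provided
assert not empty_dl          # bool() falls through to __len__() == 0

if empty_dl:                 # old-style truthiness check: silently False
    print("would prepare with calibration data")
if empty_dl is not None:     # new-style identity check: True, as intended
    print("loader was provided")

The same reasoning applies to example_inputs: when it is a tensor with more than one element, bool() raises "Boolean value of Tensor with more than one element is ambiguous", so the old truthiness test could crash outright rather than merely misbehave.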
62 changes: 35 additions & 27 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -501,44 +501,43 @@ def _reshape_scale_for_input(self, layer, scale):

return scale

def _scale_layer_weight(self, layer_name, scale): ##input channel
def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel
"""
Scale the layer weights at input channel, depthwise conv output channel
:param layer_name: The layer name
:param scale: The scale to be multiplied
:param alpha: alpha for SQLinearWrapper
:param input_minmax: input_minmax for SQLinearWrapper
:return:
"""
layer = get_module(self.model, layer_name)
if layer.__class__.__name__ == "SQLinearWrapper":
return scale # weight update is done in SQLinearWrapper initialization
scale = self._reshape_scale_for_weight(layer, scale)
layer.weight = torch.nn.Parameter(layer.weight * scale)
if self.insert_mul:
from .model_wrapper import SQLinearWrapper
layer = get_module(self.model, layer_name)
if isinstance(layer, SQLinearWrapper):
layer._recover_sq_linear()
set_module(self.model, layer_name, layer.sq_linear) ##recover
else:
new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha)
set_module(self.model, layer_name, new_module)
elif self.allow_absorb:
scale = self._reshape_scale_for_weight(layer, scale)
layer.weight = torch.nn.Parameter(layer.weight * scale)
return scale

def _absorb_scales(self, layer_name, scale, alpha=0.5): ##output channel
def _absorb_scales(self, layer_name, scale): ##output channel
"""
Absorb the scale to the layer at output channel
:param layer_name: The module name
:param scale: The scale to be absorbed
:param alpha_key: The alpha passed to SQLinearWrapper
:return:
"""
layer = get_module(self.model, layer_name)
if self.insert_mul:
if layer.__class__.__name__ == "SQLinearWrapper":
layer._recover_sq_linear()
set_module(self.model, layer_name, layer.sq_linear) ##recover
else:
from .model_wrapper import SQLinearWrapper
input_minmax = [self.input_mins[layer_name], self.input_maxes[layer_name]]
new_module = SQLinearWrapper(layer, scale, input_minmax, alpha)
set_module(self.model, layer_name, new_module)
return

if not self.allow_absorb:
return ## change the code style due to too many if/else statements in the following
if self.insert_mul or not self.allow_absorb:
return # with insert_mul, the absorb is handled by SQLinearWrapper in _scale_layer_weight

##if self.allow_absorb
layer = get_module(self.model, layer_name)
if layer.__class__.__name__ == 'WrapperLayer':
layer = layer.orig_layer
if isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.GroupNorm) or \
@@ -650,7 +649,9 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
:param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float or a dict
:return:
"""
absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha, tuning)
absorb_scales_info, weight_scales_info = self._cal_scales(
absorb_to_layer, input_maxes, alpha, tuning
)
if not absorb_scales_info or not weight_scales_info:
return weight_scales_info, absorb_scales_info
for index, key in enumerate(absorb_to_layer.keys()):
Expand All @@ -659,10 +660,13 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
elif isinstance(alpha, dict):
alpha_tmp = alpha[key]
absorb_scale = absorb_scales_info[key]
self._absorb_scales(key, absorb_scale, alpha_tmp)
self._absorb_scales(key, absorb_scale)
layer_names = absorb_to_layer[key]
for layer_name in layer_names:
self._scale_layer_weight(layer_name, weight_scales_info[layer_name])
input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]]
self._scale_layer_weight(
layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax
)
return weight_scales_info, absorb_scales_info

def _check_need_calibration(self, alpha, percentile, op_types,
@@ -1110,10 +1114,14 @@ def _get_example_input(self):
if self.dataloader is None and self.example_inputs is None:
return None
if self.example_inputs is None:
##assert self.dataloader, "Please provide dataloader or example_inputs"
for idx, input in enumerate(self.dataloader):
self.example_inputs = input
break
try:
for idx, (input, label) in enumerate(self.dataloader):
self.example_inputs = input
break
except (ValueError, TypeError):  ##batch may not be an (input, label) pair
for idx, input in enumerate(self.dataloader):
self.example_inputs = input
break

return self.example_inputs

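For orientation, the refactor above moves the insert_mul path out of _absorb_scales and into _scale_layer_weight, so the SQLinearWrapper is created at the point where the per-layer weight scale and the calibration min/max are both known. A toy sketch of what such a wrapper does; this is a hypothetical simplification, and the real class in .model_wrapper also tracks qconfig state for the later quantization step:

import torch

class ToySQLinearWrapper(torch.nn.Module):
    """Simplified SQLinearWrapper: x @ W.T == (x * s) @ (W / s).T, so FP32 output is unchanged."""

    def __init__(self, linear, input_scale, input_minmax, alpha=0.5):
        super().__init__()
        self.sq_linear = linear
        self.input_scale = input_scale    # the diff passes 1.0 / weight_scale here
        self.input_minmax = input_minmax  # calibration stats kept for qconfig tuning
        self.alpha = alpha
        with torch.no_grad():             # fold the inverse scale into the weight columns
            self.sq_linear.weight.div_(input_scale)

    def forward(self, x):
        return self.sq_linear(x * self.input_scale)

    def _recover_sq_linear(self):
        with torch.no_grad():             # undo the folding, mirroring the diff's recover path
            self.sq_linear.weight.mul_(self.input_scale)

Because the multiply is folded into the weight, the wrapped module stays numerically identical in FP32 while the activation entering the Linear gets a flatter per-channel range, which is the point of the insert_mul mode.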
17 changes: 15 additions & 2 deletions test/algorithm/test_smooth_quant.py
@@ -273,19 +273,22 @@ def __init__(self):
self.norm = torch.nn.GroupNorm(num_channels=4, num_groups=2)
self.act = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(4, 3, 1, 1)
self.conv3 = torch.nn.Conv2d(4, 3, 1, 1)

def forward(self, x):
out = self.conv1(x)
out = self.norm(out)
out = self.act(out)
out = self.conv2(out)
tmp1 = self.conv2(out)
tmp2 = self.conv3(out)
out = tmp1 + tmp2
return out

model = Model()

sq = TorchSmoothQuant(model, self.conv_dl)
sq.transform(alpha=0.6, calib_iter=2, folding=True)
assert len(sq.absorb_to_layer) == 1
assert len(sq.absorb_to_layer['norm']) == 2

def test_sq_add(self):
class Model(torch.nn.Module):
@@ -626,6 +629,15 @@ def forward(self, x):
sq.transform(alpha=0.5, calib_iter=1) # By default, folding=False
assert isinstance(sq.model.fc1, SQLinearWrapper)

def test_sq_qkv(self):
model = transformers.AutoModelForCausalLM.from_pretrained(
'facebook/opt-125m', torchscript=True,)
sq = TorchSmoothQuant(model, LLMCalibDataloader())
sq.transform(alpha=0.5, calib_iter=-1, folding=False)
assert isinstance(
sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper
)

def test_sq_quant(self):
from neural_compressor import PostTrainingQuantConfig, quantization
class Model(torch.nn.Module):
@@ -734,6 +746,7 @@ def calib_func(model):
calib_func=calib_func,
)

fp32_model = Model()
conf = PostTrainingQuantConfig(
backend="ipex",
calibration_sampling_size=8,
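For reference, the alpha values exercised by these tests (0.5, 0.6) control the standard SmoothQuant trade-off between activation and weight quantization difficulty. A minimal sketch of the per-input-channel scale that trade-off implies; INC's _cal_scales adds reshaping, absorption, and tuning logic not shown here:

import torch

def smooth_quant_scale(act_absmax, weight_absmax, alpha=0.5):
    """s_j = max|X_j|**alpha / max|W_j|**(1 - alpha), per input channel j.

    alpha=0.5 balances the two sides; a larger alpha shifts more of the
    quantization difficulty from activations onto weights.
    """
    return (act_absmax.clamp(min=1e-5) ** alpha
            / weight_absmax.clamp(min=1e-5) ** (1 - alpha))

# Activations are divided by s (SQLinearWrapper multiplies by 1/s) and the
# matching weight columns are multiplied by s, so FP32 outputs are unchanged.
act = torch.tensor([10.0, 0.1, 3.0])    # per-channel |max| from calibration
wgt = torch.tensor([0.5, 0.4, 0.6])     # per-channel |max| of the weight
print(smooth_quant_scale(act, wgt, alpha=0.5))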