From c226e51305f35c070a51d486df4a95804aee090c Mon Sep 17 00:00:00 2001
From: Xin He
Date: Wed, 16 Aug 2023 11:58:50 +0800
Subject: [PATCH 1/5] fix bug for example_inputs

Signed-off-by: Xin He
---
 neural_compressor/adaptor/pytorch.py        |  4 ++--
 .../adaptor/torch_utils/smooth_quant.py     | 12 ++++++++----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index ddd2b4b67db..96ac2e445da 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -2950,7 +2950,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             ipex_conf.save(self.ipex_config_path)
         else:
             if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
-                assert self.q_dataloader or self.example_inputs, \
+                assert self.q_dataloader is not None or self.example_inputs is not None, \
                     "IPEX needs q_dataloader or example_inputs to prepare the model"
                 from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
                 if self.version.release >= Version("2.1").release:
@@ -2983,7 +2983,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                 model = ipex.quantization.prepare(model, static_qconfig,
                                                   example_inputs=self.example_inputs,
                                                   inplace=True)
-            if self.q_dataloader or self.example_inputs:
+            if self.q_dataloader is not None or self.example_inputs is not None:
                 self._simple_inference(model, self.q_dataloader, iterations=1)
             else:
                 try:
diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py
index 85746b41e3d..49c66f98904 100644
--- a/neural_compressor/adaptor/torch_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -1110,10 +1110,14 @@ def _get_example_input(self):
         if self.dataloader == None and self.example_inputs == None:
             return None
         if self.example_inputs is None:
-            ##assert self.dataloader, "Please provide dataloader or example_inputs"
-            for idx, input in enumerate(self.dataloader):
-                self.example_inputs = input
-                break
+            try:
+                for idx, (input, label) in enumerate(self.dataloader):
+                    self.example_inputs = input
+                    break
+            except:
+                for idx, input in enumerate(self.dataloader):
+                    self.example_inputs = input
+                    break

         return self.example_inputs
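Note: the try/except added above distinguishes dataloaders that yield (input, label) pairs from those that yield bare batches. A minimal standalone sketch of the two cases the fallback has to cover (toy tensors and a hypothetical first_input helper, not the project's own dataloaders):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    paired = DataLoader(TensorDataset(torch.randn(8, 3), torch.zeros(8)), batch_size=4)
    bare = DataLoader(torch.randn(8, 3), batch_size=4)  # batches are plain tensors

    def first_input(dataloader):
        try:
            for inp, label in dataloader:    # mirrors the patched tuple-unpacking loop
                return inp
        except (TypeError, ValueError):      # bare tensor batches fail to unpack
            for inp in dataloader:
                return inp

    assert first_input(paired).shape == first_input(bare).shape == (4, 3)

One caveat the bare except in the patch inherits: a batch that happens to unpack into exactly two pieces (for example, a bare tensor batch of size 2) is silently treated as an (input, label) pair.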
From dd832ff0539666f8fad0c1fad9406735cec3a50d Mon Sep 17 00:00:00 2001
From: Xin He
Date: Wed, 16 Aug 2023 13:47:57 +0800
Subject: [PATCH 2/5] fix QKV is not fully converted to SQLinear

Signed-off-by: Xin He
---
 neural_compressor/adaptor/pytorch.py        |  3 +-
 .../adaptor/torch_utils/smooth_quant.py     | 50 ++++++++++---------
 test/algorithm/test_smooth_quant.py         |  9 ++++
 3 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 96ac2e445da..03cccaa7522 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -3141,7 +3141,8 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
         # Check save_qconf_summary part is a workaround for IPEX bug.
         # Sometimes the prepared model from get_op_capability loses this attribute
-        if not hasattr(model._model, "save_qconf_summary"):
+        if not hasattr(model._model, "save_qconf_summary") or \
+            not hasattr(model._model, "load_qconf_summary"):
             static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
             if isinstance(self.example_inputs, dict):
                 model._model = ipex.quantization.prepare(model._model, static_qconfig,
diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py
index 49c66f98904..69e69d47e2f 100644
--- a/neural_compressor/adaptor/torch_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -501,21 +501,31 @@ def _reshape_scale_for_input(self, layer, scale):
         return scale

-    def _scale_layer_weight(self, layer_name, scale): ##input channel
+    def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel
         """
         Scale the layer weights at input channel, depthwise conv output channel
         :param layer_name: The layer name
         :param scale: The scale to be multiplied
+        :param alpha: alpha for SQLinearWrapper
+        :param input_minmax: input_minmax for SQLinearWrapper
         :return:
         """
         layer = get_module(self.model, layer_name)
-        if layer.__class__.__name__ == "SQLinearWrapper":
-            return scale  # weight update is done in SQLinearWrapper initialization
-        scale = self._reshape_scale_for_weight(layer, scale)
-        layer.weight = torch.nn.Parameter(layer.weight * scale)
+        if self.insert_mul:
+            from .model_wrapper import SQLinearWrapper
+            layer = get_module(self.model, layer_name)
+            if isinstance(layer, SQLinearWrapper):
+                layer._recover_sq_linear()
+                set_module(self.model, layer_name, layer.sq_linear) ##recover
+            else:
+                new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha)
+                set_module(self.model, layer_name, new_module)
+        elif self.allow_absorb:
+            scale = self._reshape_scale_for_weight(layer, scale)
+            layer.weight = torch.nn.Parameter(layer.weight * scale)
         return scale

-    def _absorb_scales(self, layer_name, scale, alpha=0.5): ##output channel
+    def _absorb_scales(self, layer_name, scale): ##output channel
         """
         Absorb the scale to the layer at output channel
         :param layer_name: The module name
         :param scale:
@@ -523,22 +533,11 @@ def _absorb_scales(self, layer_name, scale, alpha=0.5): ##output channel
         :param alpha_key: The alpha passed to SQLinearWrapper
         :return:
         """
-        layer = get_module(self.model, layer_name)
-        if self.insert_mul:
-            if layer.__class__.__name__ == "SQLinearWrapper":
-                layer._recover_sq_linear()
-                set_module(self.model, layer_name, layer.sq_linear) ##recover
-            else:
-                from .model_wrapper import SQLinearWrapper
-                input_minmax = [self.input_mins[layer_name], self.input_maxes[layer_name]]
-                new_module = SQLinearWrapper(layer, scale, input_minmax, alpha)
-                set_module(self.model, layer_name, new_module)
-            return
-
-        if not self.allow_absorb:
-            return ## change the code style due to too many if/else statements in the following
+        if self.insert_mul or not self.allow_absorb:
+            return  # absorption is handled by SQLinearWrapper in _scale_layer_weight
+        layer = get_module(self.model, layer_name)  ##if self.allow_absorb
         if layer.__class__.__name__ == 'WrapperLayer':
             layer = layer.orig_layer
         if isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.GroupNorm) or \
@@ -650,7 +649,9 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
         :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float or a dict
         :return:
         """
-        absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha, tuning)
+        absorb_scales_info, weight_scales_info = self._cal_scales(
+            absorb_to_layer, input_maxes, alpha, tuning
+        )
         if not absorb_scales_info or not weight_scales_info:
             return weight_scales_info, absorb_scales_info
         for index, key in enumerate(absorb_to_layer.keys()):
@@ -659,10 +660,13 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
             elif isinstance(alpha, dict):
                 alpha_tmp = alpha[key]
             absorb_scale = absorb_scales_info[key]
-            self._absorb_scales(key, absorb_scale, alpha_tmp)
+            self._absorb_scales(key, absorb_scale)
             layer_names = absorb_to_layer[key]
             for layer_name in layer_names:
-                self._scale_layer_weight(layer_name, weight_scales_info[layer_name])
+                input_minmax = [self.input_mins[key], self.input_maxes[key]]
+                self._scale_layer_weight(
+                    layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax
+                )
         return weight_scales_info, absorb_scales_info

     def _check_need_calibration(self, alpha, percentile, op_types,
diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index 579b520a8f2..d875e7b65d3 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -626,6 +626,15 @@ def forward(self, x):
         sq.transform(alpha=0.5, calib_iter=1)  # By default, folding=False
         assert isinstance(sq.model.fc1, SQLinearWrapper)

+    def test_sq_qkv(self):
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            'facebook/opt-125m', torchscript=True,)
+        sq = TorchSmoothQuant(model, LLMCalibDataloader())
+        sq.transform(alpha=0.5, calib_iter=-1, folding=False)
+        assert isinstance(
+            sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper
+        )
+
     def test_sq_quant(self):
         from neural_compressor import PostTrainingQuantConfig, quantization
         class Model(torch.nn.Module):
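Note: patch 2 moves SQLinearWrapper insertion from _absorb_scales into _scale_layer_weight because several Linear layers can hang off one absorb key: the Q, K and V projections of an attention block all consume the same LayerNorm output. The old per-key insertion wrapped only one of them, which is exactly what the new test_sq_qkv guards against. For intuition, a deliberately simplified toy wrapper showing the math SQLinearWrapper relies on (the real class lives in torch_utils/model_wrapper.py and also carries alpha and the input min/max for later quantization):

    import torch

    class ToySQLinearWrapper(torch.nn.Module):
        # y = (x / s) @ (W * s).T + b  ==  x @ W.T + b, so outputs are unchanged
        def __init__(self, linear, input_scale):
            super().__init__()
            self.sq_linear = linear
            self.register_buffer("input_scale", input_scale)
            with torch.no_grad():
                self.sq_linear.weight.mul_(input_scale)  # fold scale into weight

        def forward(self, x):
            return self.sq_linear(x / self.input_scale)

    lin = torch.nn.Linear(4, 2)
    x = torch.randn(3, 4)
    ref = lin(x)  # reference output before wrapping
    wrapped = ToySQLinearWrapper(lin, torch.rand(4) + 0.5)
    assert torch.allclose(wrapped(x), ref, atol=1e-5)

Because the division by the scale happens inside the wrapper, no preceding layer needs to absorb anything, which is why _absorb_scales can return early whenever insert_mul is set.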
From 07dac655c1f930e533821a3de3f1a2ab4a8dc632 Mon Sep 17 00:00:00 2001
From: Xin He
Date: Wed, 16 Aug 2023 13:56:29 +0800
Subject: [PATCH 3/5] fix bug

Signed-off-by: Xin He
---
 neural_compressor/adaptor/pytorch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 03cccaa7522..e9eb0428474 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -2605,7 +2605,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
         if self.version.release >= Version("1.12.0").release:
             # Check save_qconf_summary part is a workaround for IPEX bug.
             # Sometimes the prepared model from get_op_capability loses this attribute
-            if not hasattr(model._model, "save_qconf_summary"):
+            if not hasattr(model._model, "save_qconf_summary") or \
+                not hasattr(model._model, "load_qconf_summary"):
                 from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
                 if self.version.release >= Version("2.1").release:
                     static_qconfig = ipex.quantization.default_static_qconfig_mapping

From 7258fecb566106d4a1f73f6e0fabd9488b30e58e Mon Sep 17 00:00:00 2001
From: Xin He
Date: Wed, 16 Aug 2023 13:59:31 +0800
Subject: [PATCH 4/5] fix bug

Signed-off-by: Xin He
---
 test/algorithm/test_smooth_quant.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index d875e7b65d3..82624b40fe1 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -743,6 +743,7 @@ def calib_func(model):
             calib_func=calib_func,
         )

+        fp32_model = Model()
         conf = PostTrainingQuantConfig(
             backend="ipex",
             calibration_sampling_size=8,
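Note: patch 3 repeats patch 2's attribute guard on the quantize path, and patch 4 rebuilds the fp32 model between the two quantization runs in the test. The save_qconf_summary/load_qconf_summary attributes the guard checks are added by IPEX's prepare step; a rough sketch of the round-trip the adaptor depends on (assuming intel_extension_for_pytorch 2.1+ per the version check above; the file path is illustrative):

    import torch
    import intel_extension_for_pytorch as ipex

    model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
    example_inputs = torch.randn(1, 4)
    static_qconfig = ipex.quantization.default_static_qconfig_mapping
    prepared = ipex.quantization.prepare(model, static_qconfig,
                                         example_inputs=example_inputs, inplace=False)
    prepared(example_inputs)                                   # calibration pass
    prepared.save_qconf_summary(qconf_summary="./qconf.json")  # exists only on a prepared model
    prepared.load_qconf_summary(qconf_summary="./qconf.json")
    converted = ipex.quantization.convert(prepared)

A model that lost either attribute cannot complete this round-trip, hence checking both before skipping the re-prepare.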
From 567c24d9f2d3f2b31789bb86c140139e5ec74d85 Mon Sep 17 00:00:00 2001
From: Xin He
Date: Thu, 17 Aug 2023 09:59:49 +0800
Subject: [PATCH 5/5] enhance ut

Signed-off-by: Xin He
---
 neural_compressor/adaptor/torch_utils/smooth_quant.py | 2 +-
 test/algorithm/test_smooth_quant.py                   | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py
index 69e69d47e2f..3438ee6a9b8 100644
--- a/neural_compressor/adaptor/torch_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -663,7 +663,7 @@ def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=Fal
             self._absorb_scales(key, absorb_scale)
             layer_names = absorb_to_layer[key]
             for layer_name in layer_names:
-                input_minmax = [self.input_mins[key], self.input_maxes[key]]
+                input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]]
                 self._scale_layer_weight(
                     layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax
                 )
         return weight_scales_info, absorb_scales_info
diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index 82624b40fe1..eabc3431b46 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -273,19 +273,22 @@ def __init__(self):
                 self.norm = torch.nn.GroupNorm(num_channels=4, num_groups=2)
                 self.act = torch.nn.ReLU()
                 self.conv2 = torch.nn.Conv2d(4, 3, 1, 1)
+                self.conv3 = torch.nn.Conv2d(4, 3, 1, 1)

             def forward(self, x):
                 out = self.conv1(x)
                 out = self.norm(out)
                 out = self.act(out)
-                out = self.conv2(out)
+                tmp1 = self.conv2(out)
+                tmp2 = self.conv3(out)
+                out = tmp1 + tmp2
                 return out

         model = Model()
         sq = TorchSmoothQuant(model, self.conv_dl)
         sq.transform(alpha=0.6, calib_iter=2, folding=True)
-        assert len(sq.absorb_to_layer) == 1
+        assert len(sq.absorb_to_layer['norm']) == 2
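Note: patch 5's one-line fix in _adjust_parameters matters because the calibration statistics (input_mins/input_maxes) are keyed by layer name while the loop iterates absorb keys; every layer under one absorb key consumes the same input, so the first layer's min/max can stand in for the whole group. The enhanced test builds exactly that shape: two convs fed by one GroupNorm. A toy rendering of the lookup (hypothetical literal values, mirroring the test model):

    # one absorbing module -> all layers that consume its output
    absorb_to_layer = {"norm": ["conv2", "conv3"]}
    # calibration stats are keyed by *layer* name, not by absorb key
    input_mins = {"conv2": -1.2, "conv3": -1.2}   # same input => identical stats
    input_maxes = {"conv2": 0.9, "conv3": 0.9}

    for key, layer_names in absorb_to_layer.items():
        # input_mins[key] would raise KeyError here ("norm" is not a stats key);
        # indexing with layer_names[0] is safe because the group shares one input
        input_minmax = [input_mins[layer_names[0]], input_maxes[layer_names[0]]]
        print(key, "->", layer_names, input_minmax)

The same shape explains the updated assertion: len(sq.absorb_to_layer['norm']) == 2.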