diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py
index 940e9826785..b8c4329de3b 100644
--- a/neural_compressor/torch/algorithms/weight_only/awq.py
+++ b/neural_compressor/torch/algorithms/weight_only/awq.py
@@ -89,6 +89,36 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}
     return block_absorb_dict, absorb_layer_dict
 
 
+def _get_absorb_dict(model, absorb_layer_dict):
+    """Get absorbed layer per block from absorbed layer dict.
+
+    Args:
+        model (torch.nn.Module): input model
+        absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
+
+    Returns:
+        block_absorb_dict: dict of absorbed layer per block. eg. {0, [[absorbed_1, xx], [xx]], ...}
+    """
+    block_absorb_dict = {}
+    block_prefix, block_num = get_block_prefix(model)
+    new_absorb_layer_dict = {}
+    for i in range(block_num):
+        block_absorb_dict[i] = []
+        block_name = block_prefix + "." + str(i) + "."
+
+        for k, v in absorb_layer_dict.items():
+
+            if isinstance(v, str):
+                name_list = (block_name + v,)
+            else:
+                name_list = tuple(block_name + vv for vv in v)
+            block_absorb_dict[i].append(name_list)
+            new_absorb_layer_dict[name_list] = block_name + k
+    logger.debug(f"The absorbed layers per block: {block_absorb_dict}")
+    logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}")
+    return block_absorb_dict, new_absorb_layer_dict
+
+
 @torch.no_grad()
 def _get_weight_scale(weight, q_group_size=-1):
     org_shape = weight.shape
@@ -123,6 +153,7 @@ def __init__(
         total_block_args=[],
         total_block_kwargs=[],
         device="auto",
+        absorb_layer_dict={},
     ):
 
         self.example_inputs = example_inputs
@@ -140,6 +171,7 @@ def __init__(
         self.scheme = scheme
         self.use_full_range = use_full_range
         self.weight_config = weight_config
+        self.absorb_layer_dict = absorb_layer_dict
 
     def _move_model_and_data_to_device(self):
         # Put the model and example_inputs into target device
@@ -164,13 +196,16 @@ def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, retu
         # Step 1: get absorbed module list per block, includes self-absorption
         # block_absorb_dict is split per block, includes all absorb relationship.
         # absorb_layer_dict is the inverse of block_absorb_dict for all blocks
-        self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
-            self.model,
-            self.example_inputs,
-            # for only use_mse_search, folding is useless.
-            folding=folding if use_auto_scale else False,
-            weight_config=self.weight_config,
-        )
+        if not self.absorb_layer_dict:
+            self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
+                self.model,
+                self.example_inputs,
+                # for only use_mse_search, folding is useless.
+                folding=folding if use_auto_scale else False,
+                weight_config=self.weight_config,
+            )
+        else:
+            self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_dict(self.model, self.absorb_layer_dict)
         # process per block
         for i, module_list in self.block_absorb_dict.items():
             logger.info(f"Processing block: {i+1}/{self.block_num}")
@@ -491,13 +526,15 @@ def module_inference(self, model, inputs):
 
 
 class AWQQuantizer(Quantizer):
-    def __init__(self, quant_config: OrderedDict = {}):
+    def __init__(self, quant_config: OrderedDict = {}, absorb_layer_dict: dict = {}):
         """Init an AWQQuantizer object.
 
         Args:
             quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+            absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
""" super().__init__(quant_config) + self.absorb_layer_dict = absorb_layer_dict @torch.no_grad() def prepare(self, model, *args, **kwargs): @@ -566,6 +603,7 @@ def convert( weight_config=self.quant_config, total_block_args=total_block_args, total_block_kwargs=total_block_kwargs, + absorb_layer_dict=self.absorb_layer_dict, ) qdq_model = awq.quantize( use_auto_scale=use_auto_scale, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index b8a1e3b9202..ea2d53e7353 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -317,10 +317,10 @@ def awq_quantize_entry( from neural_compressor.torch.algorithms.weight_only.save_load import save weight_config = {} - for (op_name, op_type), op_config in configs_mapping.items(): - if op_config.name != AWQ: + for (op_name, op_type), quant_config in configs_mapping.items(): + if quant_config.name != AWQ: continue - if op_config.dtype == "fp32": + if quant_config.dtype == "fp32": weight_config[op_name] = { "bits": -1, "dtype": "fp32", # skip quantization @@ -329,31 +329,34 @@ def awq_quantize_entry( } else: weight_config[op_name] = { - "dtype": op_config.dtype, - "bits": op_config.bits, - "group_size": op_config.group_size, - "group_dim": op_config.group_dim, - "scheme": "sym" if op_config.use_sym else "asym", - "use_full_range": op_config.use_full_range, - "use_mse_search": op_config.use_mse_search, - "use_layer_wise": op_config.use_layer_wise, - "use_double_quant": op_config.use_double_quant, - "double_quant_dtype": op_config.double_quant_dtype, - "double_quant_bits": op_config.double_quant_bits, - "double_quant_scheme": op_config.double_quant_use_sym, - "double_quant_group_size": op_config.double_quant_group_size, + "dtype": quant_config.dtype, + "bits": quant_config.bits, + "group_size": quant_config.group_size, + "group_dim": quant_config.group_dim, + "scheme": "sym" if quant_config.use_sym else "asym", + "use_full_range": quant_config.use_full_range, + "use_mse_search": quant_config.use_mse_search, + "use_layer_wise": quant_config.use_layer_wise, + "use_double_quant": quant_config.use_double_quant, + "double_quant_dtype": quant_config.double_quant_dtype, + "double_quant_bits": quant_config.double_quant_bits, + "double_quant_scheme": quant_config.double_quant_use_sym, + "double_quant_group_size": quant_config.double_quant_group_size, } - use_auto_scale = op_config.use_auto_scale - use_mse_search = op_config.use_auto_clip # for awq clip - folding = op_config.folding - use_full_range = op_config.use_full_range + use_auto_scale = quant_config.use_auto_scale + use_mse_search = quant_config.use_auto_clip # for awq clip + folding = quant_config.folding + use_full_range = quant_config.use_full_range + absorb_layer_dict = quant_config.absorb_layer_dict run_fn = kwargs.get("run_fn", None) run_args = kwargs.get("run_args", None) example_inputs = kwargs.get("example_inputs", None) assert example_inputs is not None, "Please provide example_inputs for AWQ quantization." 
 
-    quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
+    quantizer = get_quantizer(
+        model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_layer_dict=absorb_layer_dict
+    )
     model = quantizer.execute(
         model,
         mode=mode,
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 9014f1576a3..f2b12f89b5f 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -425,6 +425,7 @@ class AWQConfig(BaseConfig):
         "use_auto_scale",
         "use_auto_clip",
         "folding",
+        "absorb_layer_dict",
     ]
     name = AWQ
 
@@ -451,6 +452,7 @@ def __init__(
         use_auto_clip: bool = True,
         folding: bool = False,
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+        absorb_layer_dict: dict = {},
     ):
         """Init AWQ weight-only quantization config.
 
@@ -473,6 +475,7 @@ def __init__(
             use_auto_clip (bool): Enables clip range search. Defaults to True.
             folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
                 default is False.
+            absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
         """
         super().__init__(white_list=white_list)
         self.dtype = dtype
@@ -493,6 +496,7 @@ def __init__(
         self.use_auto_scale = use_auto_scale
         self.use_auto_clip = use_auto_clip
         self.folding = folding
+        self.absorb_layer_dict = absorb_layer_dict
         self._post_init()
 
     @classmethod
@@ -609,7 +613,7 @@ def __init__(
             double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
             double_quant_group_size (int): Size of double_quant groups, default is 32.
             quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False.
-            absorb_to_layer (bool): The layer dict that scale can be absorbed, default is {}.
+            absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}.
             folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
                 default is False.
         """
diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py
index 6d33eb1a913..c877288f7dc 100644
--- a/test/3x/torch/quantization/weight_only/test_awq.py
+++ b/test/3x/torch/quantization/weight_only/test_awq.py
@@ -157,3 +157,41 @@ def test_quant_lm_head(self):
         assert (
             id(model.model.decoder.embed_tokens.weight) == lm_head_id
         ), "The tied lm_head weight is not deep copied, please check!"
+
+    def test_awq_absorb_to_layer(self):
+        absorb_layer_dict = {
+            "ln_1": (
+                "attn.q_proj",
+                "attn.k_proj",
+                "attn.v_proj",
+                "mlp.fc_in",
+            ),
+            "attn.out_proj": "attn.out_proj",
+            "mlp.fc_out": ("mlp.fc_out"),
+        }
+
+        quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
+        logger.info(f"Test AWQ with config {quant_config}")
+        # prepare + convert API
+        model = prepare(
+            model=copy.deepcopy(self.tiny_gptj),
+            quant_config=quant_config,
+            example_inputs=self.example_inputs,
+        )
+        calib_func(model)
+        model = convert(model)
+        out1 = model(self.example_inputs)
+        quant_config = AWQConfig()
+        logger.info(f"Test AWQ with config {quant_config}")
+
+        # prepare + convert API
+        model = prepare(
+            model=copy.deepcopy(self.tiny_gptj),
+            quant_config=quant_config,
+            example_inputs=self.example_inputs,
+        )
+        calib_func(model)
+        model = convert(model)
+        out2 = model(self.example_inputs)
+
+        assert torch.all(out1[0].eq(out2[0])), "The results should be equal."
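
For reviewers, a minimal standalone sketch of how a user would exercise the new `absorb_layer_dict` option through the public `prepare`/`convert` flow is below. Only `AWQConfig(absorb_layer_dict=...)`, `prepare`, and `convert` come from this patch; the tiny GPT-J checkpoint name, tokenizer usage, and single-batch calibration pass are illustrative assumptions, not part of the change.

```python
# Sketch only: checkpoint name and calibration data are assumptions for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import AWQConfig, convert, prepare

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # hypothetical tiny GPT-J model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
example_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]

# Map each absorbing module to the layer(s) whose input scale it absorbs,
# with names relative to one transformer block (mirrors the new unit test).
absorb_layer_dict = {
    "ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj", "mlp.fc_in"),
    "attn.out_proj": "attn.out_proj",
    "mlp.fc_out": "mlp.fc_out",
}

quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
model = prepare(model=model, quant_config=quant_config, example_inputs=example_inputs)
model(example_inputs)  # calibration pass; a real run would loop over a calibration set
model = convert(model)
```

When `absorb_layer_dict` is provided, the quantizer skips tracing via `_get_absorb_per_block` and expands the per-block relative names with `_get_absorb_dict` instead, which the new `test_awq_absorb_to_layer` test checks against the traced result.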