From ae231b24846c0d3d9b1080ac30f669d807080904 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 11 Jul 2024 11:46:13 +0800
Subject: [PATCH 1/4] export teq model

Signed-off-by: yiliu30
---
 .../torch/algorithms/weight_only/teq.py | 70 ++++++++++++++++---
 1 file changed, 59 insertions(+), 11 deletions(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py
index 9783d913070..6513553800d 100644
--- a/neural_compressor/torch/algorithms/weight_only/teq.py
+++ b/neural_compressor/torch/algorithms/weight_only/teq.py
@@ -21,7 +21,7 @@
 import torch

 from neural_compressor.torch.algorithms.base_algorithm import Quantizer
-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger
+from neural_compressor.torch.utils import get_accelerator, get_model_device, is_transformers_imported, logger

 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
@@ -265,18 +265,66 @@ def transform(self):
         set_module(self.model, n, m.orig_layer)

     @torch.no_grad()
-    def quantize(self):
+    def quantize(self, **kwargs):
         """quantization."""
-
-        for n, m in self.model.named_modules():
-            if self.weight_config.get(n) is None:  # pragma: no cover
-                logger.info(f"quantize layer {n} not in weight config, skip.")
+        use_optimum_format = kwargs.get("use_optimum_format", True)
+        device = get_accelerator().current_device_name()
+        model_device = get_model_device(self.model)  # return model on the same device
+        model = self.model
+        for name, m in model.named_modules():
+            if self.weight_config.get(name) is None:  # pragma: no cover
+                logger.info(f"quantize layer {name} not in weight config, skip.")
                 continue
-            num_bits = self.weight_config[n]["bits"]
-            group_size = self.weight_config[n]["group_size"]
-            scheme = self.weight_config[n]["scheme"]
+            num_bits = self.weight_config[name]["bits"]
+            group_size = self.weight_config[name]["group_size"]
+            scheme = self.weight_config[name]["scheme"]
+            group_dim = self.weight_config[name].get("group_dim", 1)
+            # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
+            if is_transformers_imported():
+                transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
+            else:
+                transpose = group_dim == 0
+            if transpose:
+                weight = m.weight.detach().T.contiguous()
+            else:
+                weight = m.weight.detach()
             if isinstance(m, torch.nn.Linear):  # pragma: no cover
-                quant_tensor(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)
+                int_weight, scale, zp = quant_tensor(
+                    weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme
+                )
+                int_weight = int_weight.t_().contiguous() if transpose else int_weight
+                scale = scale.t_().contiguous() if transpose else scale
+                zp = zp.t_().contiguous() if transpose and zp is not None else zp
+                if isinstance(m, torch.nn.Linear):
+                    in_features = m.in_features
+                    out_features = m.out_features
+                elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
+                    in_features = m.weight.shape[0]
+                    out_features = m.weight.shape[1]
+                    int_weight = int_weight.t_().contiguous()
+                    scale = scale.t_().contiguous()
+                    zp = zp.t_().contiguous() if zp is not None else zp
+                from .modules import WeightOnlyLinear
+
+                new_module = WeightOnlyLinear(
+                    in_features,
+                    out_features,
+                    bits=num_bits,
+                    group_size=group_size,
+                    zp=zp is not None,
+                    bias=m.bias is not None,
+                    use_optimum_format=use_optimum_format,
+                    device=device,
+                )
+                new_module.pack(int_weight, scale, zp, m.bias)
+                if name == "":
+                    return new_module
+                else:
+                    set_module(model, name, new_module)
+                # Move modules back to the model device layer-by-layer
+                m.to(model_device)
+                new_module.to(model_device)
+        self.model = model

     def save(self, save_scale_file="", save_state_dict_file=""):
         """
@@ -328,6 +376,6 @@ def convert(self, model, *args: Any, **kwargs: Any):
             setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
         self._quantizer.model = model
         self._quantizer.transform()
-        self._quantizer.quantize()
+        self._quantizer.quantize(**kwargs)
         logger.info("TEQ quantizing done.")
         return self._quantizer.model

From e508c942a63bd323be0d0ff1f120c408da9af83e Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 11 Jul 2024 14:55:26 +0800
Subject: [PATCH 2/4] fixed export

Signed-off-by: yiliu30
---
 neural_compressor/torch/algorithms/weight_only/teq.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py
index 6513553800d..f84b86bcd1d 100644
--- a/neural_compressor/torch/algorithms/weight_only/teq.py
+++ b/neural_compressor/torch/algorithms/weight_only/teq.py
@@ -290,7 +290,11 @@ def quantize(self, **kwargs):
                 weight = m.weight.detach()
             if isinstance(m, torch.nn.Linear):  # pragma: no cover
                 int_weight, scale, zp = quant_tensor(
-                    weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme
+                    weight.data,
+                    num_bits=num_bits,
+                    group_size=group_size,
+                    scheme=scheme,
+                    return_int=True,
                 )
                 int_weight = int_weight.t_().contiguous() if transpose else int_weight
                 scale = scale.t_().contiguous() if transpose else scale

From 21c5e28e68108cf60135a1239d24e1a6d6298f6c Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 11 Jul 2024 03:09:51 -0400
Subject: [PATCH 3/4] fixed the quant tensor

Signed-off-by: yiliu30
---
 neural_compressor/torch/algorithms/weight_only/teq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py
index 6513553800d..905141a8d35 100644
--- a/neural_compressor/torch/algorithms/weight_only/teq.py
+++ b/neural_compressor/torch/algorithms/weight_only/teq.py
@@ -290,7 +290,7 @@ def quantize(self, **kwargs):
                 weight = m.weight.detach()
             if isinstance(m, torch.nn.Linear):  # pragma: no cover
                 int_weight, scale, zp = quant_tensor(
-                    weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme
+                    weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme, return_int=True
                 )
                 int_weight = int_weight.t_().contiguous() if transpose else int_weight
                 scale = scale.t_().contiguous() if transpose else scale

From 0a185d790cd687bf35202877cddb5802a7b7973d Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 11 Jul 2024 16:21:20 +0800
Subject: [PATCH 4/4] disable some check

Signed-off-by: yiliu30
---
 neural_compressor/torch/algorithms/weight_only/teq.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py
index f84b86bcd1d..595a2e8479f 100644
--- a/neural_compressor/torch/algorithms/weight_only/teq.py
+++ b/neural_compressor/torch/algorithms/weight_only/teq.py
@@ -282,9 +282,9 @@ def quantize(self, **kwargs):
             # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
             if is_transformers_imported():
                 transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
-            else:
+            else:  # pragma: no cover
                 transpose = group_dim == 0
-            if transpose:
+            if transpose:  # pragma: no cover
                 weight = m.weight.detach().T.contiguous()
             else:
                 weight = m.weight.detach()
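
Taken together, the four patches change TEQ's quantize() from fake-quantizing weights in place to a real export path: each configured layer is quantized to integers via quant_tensor(..., return_int=True), the weight is transposed when exactly one of "group_dim == 0" and "the module is transformers.Conv1D" holds, the result is packed into a WeightOnlyLinear module with new_module.pack(int_weight, scale, zp, bias), and modules are moved back to the original model device. The snippet below is a minimal, self-contained sketch of that quantize-then-transpose flow for a single layer; should_transpose and toy_group_quant are hypothetical helpers that only stand in for the diff's XOR rule and for quant_tensor(..., return_int=True), and the symmetric per-group rounding here is an approximation, not the library's actual scheme.

import torch

def should_transpose(group_dim: int, is_conv1d: bool) -> bool:
    # Mirrors the patch's rule: transpose when exactly one of
    # "group_dim == 0" and "module is transformers.Conv1D" is true.
    return (group_dim == 0) ^ is_conv1d

def toy_group_quant(weight: torch.Tensor, num_bits: int = 4, group_size: int = 32):
    # Toy symmetric per-group quantization along the last dimension.
    # Illustrative stand-in for quant_tensor(..., return_int=True).
    rows, cols = weight.shape
    assert cols % group_size == 0, "toy example assumes evenly divisible groups"
    grouped = weight.reshape(rows, cols // group_size, group_size)
    max_int = 2 ** (num_bits - 1) - 1
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / max_int
    int_weight = torch.clamp(torch.round(grouped / scale), -max_int - 1, max_int)
    return int_weight.reshape(rows, cols).to(torch.int8), scale.squeeze(-1), None

linear = torch.nn.Linear(64, 16, bias=True)
group_dim = 1  # group along the input dimension, the default in the patch
transpose = should_transpose(group_dim, is_conv1d=False)
weight = linear.weight.detach().T.contiguous() if transpose else linear.weight.detach()

int_weight, scale, zp = toy_group_quant(weight, num_bits=4, group_size=32)
# In the real flow these tensors are handed to WeightOnlyLinear(...).pack(int_weight, scale, zp, bias).
print(int_weight.shape, int_weight.dtype, scale.shape, zp)

Transposing before quantization keeps the grouping axis consistent across nn.Linear (out_features x in_features) and transformers.Conv1D (in_features x out_features) weight layouts, which is why the Conv1D branch in the patch reads in_features/out_features from weight.shape and transposes int_weight and scale again before pack().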