
Commit 8b81146

xin3he, xinhe3, and Sylwester Fraczek authored
[SW-214269] support g_idx for uint4 (#246)
* support g_idx for uint4

Signed-off-by: Xin He <[email protected]>
Co-authored-by: Xin He <[email protected]>
Co-authored-by: Sylwester Fraczek <[email protected]>
1 parent e886cd9 commit 8b81146

File tree

2 files changed: +11 −7 lines changed


neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 10 additions & 6 deletions
@@ -365,11 +365,6 @@ def unpack(self):
         qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight
 
         device = scales.device
-        if self.g_idx is None:
-            # used for recovering fp32_weight
-            self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32).to(
-                device
-            )
         # unpack weight
         if not self.use_optimum_format and self.compression_dim == 0:
             qweight = qweight.T.contiguous()
@@ -413,6 +408,11 @@ def recover(self):
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
 
         # recover fp32 weight
+        if self.g_idx is None:
+            # used for recovering fp32_weight
+            self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32).to(
+                device
+            )
         if zp is not None:
             # recover fp32 weight with int_weight, scale, and zero_point
             for idx in range(self.in_features):
@@ -729,7 +729,8 @@ def forward(self, input):
         scales = self.scales
         qweight = self.qweight
         zeros = self.qzeros
-        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
+        g_idx = self.g_idx
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype, g_idx)
         output = self.matmul_internal(input, weight)
         output = output.to(dtype=input_dtype).reshape(
             output_shape
@@ -760,6 +761,9 @@ def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=
         if bias is not None:
             self.bias = bias.to("hpu").to(torch.bfloat16)
 
+        if g_idx is not None:
+            self.g_idx = g_idx.to("hpu").to(torch.int32)
+
     def unpack(self):
         """Unpack weight and zero point."""
         logger.debug("Unpacking from HPU")
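
Note on the moved block above: when no g_idx is stored, a default mapping is built so that input channel i belongs to group i // group_size, and recover() uses that mapping to pick each channel's scale and zero point. The following is a minimal standalone sketch of that idea; shapes, names, and the 4-bit value range are illustrative assumptions, not the library's exact code.

import torch

# Illustrative sizes (assumptions for the sketch).
in_features, out_features, group_size = 8, 4, 4

# Fake quantized data: per-group scales and zero points, 4-bit integer weights.
int_weight = torch.randint(0, 16, (out_features, in_features), dtype=torch.int32)
scales = torch.rand(out_features, in_features // group_size)
zp = torch.randint(0, 16, (out_features, in_features // group_size), dtype=torch.int32)

# Default g_idx when none is provided: channel i -> group i // group_size,
# mirroring the expression moved into recover() in this commit.
g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)

# Recover an fp32 weight column by column using the group mapping.
fp32_weight = torch.zeros(out_features, in_features)
for idx in range(in_features):
    g = g_idx[idx]
    fp32_weight[:, idx] = (int_weight[:, idx] - zp[:, g]) * scales[:, g]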

neural_compressor/torch/algorithms/weight_only/save_load.py

Lines changed: 1 addition & 1 deletion
@@ -470,7 +470,7 @@ def _replace_woqlinear_modules(self, name, linear_module, module_quantization_co
         module_kwargs["group_size"] = module_quantization_config.get("group_size", 32)
 
         # spceific initialization kwargs
-        module_kwargs["g_idx"] = True if name + ".g_idx" in self.loaded_state_dict_keys else False
+        module_kwargs["g_idx"] = module_quantization_config.get("desc_act", False)
         module_kwargs["zp"] = True if name + ".qzeros" in self.loaded_state_dict_keys else False
         module_kwargs["use_optimum_format"] = True
         module_kwargs["bias"] = linear_module.bias is not None
