diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md
index 2786dd9d1ec..b1ab86d1fd5 100644
--- a/docs/source/quantization_weight_only.md
+++ b/docs/source/quantization_weight_only.md
@@ -96,6 +96,14 @@ To support low memory inference, Neural Compressor implemented WeightOnlyLinear,
 | compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64] |
 | compression_dim | 1 | 0 means output channel while 1 means input channel |
 | scale_dtype | torch.float32 | Data type for scale and bias |
+| use_hf_format | False | Whether to use the popular format used on the HuggingFace hub |
+
+**Note:** The HuggingFace format differs from the default format in the following ways:
+
+> 1: Compression dimension: weight = 1, zero = 0, and both tensors are stored transposed.
+> 2: Zero point: zero_point -= 1 before compression; zero_point is always stored, even for the sym scheme.
+> 3: Group index: the group number of each channel is stored instead of the channel order.
+
 
 ### **User Code Example**
 ```python
@@ -119,12 +127,14 @@ conf = PostTrainingQuantConfig(
 )
 q_model = quantization.fit(model, conf, eval_func=eval_func)
 q_model.save("saved_results")
-compressed_model = q_model.export_compressed_model(
-    compression_dtype=torch.int32,
-    compression_dim=1,
-    scale_dtype=torch.float16,
-)
+compressed_model = q_model.export_compressed_model()
 torch.save(compressed_model.state_dict(), "compressed_model.pt")
+# or, with neural_compressor.utils.load_huggingface.export_compressed_model:
+model = Model()
+compressed_model = export_compressed_model(
+    model,
+    saved_dir="saved_results",
+)
 ```
 
 The saved_results folder contains two files: `best_model.pt` and `qconfig.json`, and the generated q_model is a fake quantized model.
diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py
index d30182a7b9e..57103566d9d 100644
--- a/neural_compressor/adaptor/torch_utils/model_wrapper.py
+++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -215,10 +215,12 @@ def __init__(
         scale_dtype=torch.float32,
         compression_dtype=torch.int32,
         compression_dim=1,
-        gptq_perm=False,
+        g_idx=False,
         device="cpu",
+        use_hf_format=False,
     ):
         super().__init__()
+        self.use_hf_format = use_hf_format
         self.dtype = dtype
         if "int" not in self.dtype:  # for nf4, fp4
             from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
@@ -249,69 +251,105 @@ def __init__(
         assert compression_dim in [0, 1], (
             "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel."
) - self.register_buffer( - "scale", - torch.zeros( - (out_features, math.ceil(in_features / self.groupsize)), - dtype=self.float_type, - ).to(device), - ) - if compression_dim == 1: + if self.use_hf_format: self.register_buffer( - "packed_weight", + "scales", torch.zeros( - (out_features, math.ceil(in_features / self.n_pack)), + (math.ceil(in_features / self.groupsize), out_features), + dtype=self.float_type, + ).to(device), + ) + self.scales = self.scales.T + self.register_buffer( + "qweight", + torch.zeros( + (math.ceil(in_features / self.n_pack), out_features), dtype=self.compressed_dtype, ).to(device), ) - if zp: - self.register_buffer( - "packed_zp", - torch.zeros( - (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)), - dtype=self.compressed_dtype, - ).to(device), - ) - else: + self.qweight = self.qweight.T self.register_buffer( - "packed_weight", + "qzeros", torch.zeros( - (math.ceil(out_features / self.n_pack), in_features), + (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)), dtype=self.compressed_dtype, ).to(device), ) - if zp: + self.qzeros = self.qzeros.T + else: + self.register_buffer( + "scales", + torch.zeros( + (out_features, math.ceil(in_features / self.groupsize)), + dtype=self.float_type, + ).to(device), + ) + if compression_dim == 1: self.register_buffer( - "packed_zp", + "qweight", torch.zeros( - (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)), + (out_features, math.ceil(in_features / self.n_pack)), dtype=self.compressed_dtype, ).to(device), ) + if zp: + self.register_buffer( + "qzeros", + torch.zeros( + (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)), + dtype=self.compressed_dtype, + ).to(device), + ) + else: + self.register_buffer( + "qweight", + torch.zeros( + (math.ceil(out_features / self.n_pack), in_features), + dtype=self.compressed_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "qzeros", + torch.zeros( + (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)), + dtype=self.compressed_dtype, + ).to(device), + ) + if g_idx: + self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) + else: + self.g_idx = None if bias: self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: self.bias = None - if gptq_perm: - self.register_buffer("gptq_perm", torch.zeros(in_features, dtype=torch.int32).to(device)) - else: - self.gptq_perm = None - def pack(self, int_weight, scale, zp, bias, gptq_perm=None): + def pack(self, int_weight, scale, zp, bias, g_idx=None): int_weight = int_weight.to(self.device) + if self.use_hf_format and zp is None: + # to avoid overflow + int_weight = int_weight.type(torch.int32) + shift_bias = 2 ** (self.bits - 1) + int_weight += shift_bias + zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias if bias is not None: assert hasattr(self, "bias"), "bias is not set when initializing." self.bias = bias.type(self.float_type).to(self.device) - if gptq_perm is not None: - assert hasattr(self, "gptq_perm"), "gptq_perm is not set when initializing." - self.gptq_perm = gptq_perm.type(torch.int32).to(self.device) - assert scale.shape == self.scale.shape, "Scale shape is mismatched." - self.scale = scale.type(self.float_type).to(self.device) - if self.compression_dim == 0: + if g_idx is not None: + assert hasattr(self, "g_idx"), "g_idx is not set when initializing." 
+ self.g_idx = g_idx.type(torch.int32).to(self.device) + if self.use_hf_format: + invperm = torch.argsort(self.g_idx) + self.g_idx = invperm // self.groupsize + self.g_idx = self.g_idx.type(torch.int32).to(self.device) + assert scale.shape == self.scales.shape, "Scale shape is mismatched." + self.scales = scale.type(self.float_type).to(self.device) + if not self.use_hf_format and self.compression_dim == 0: int_weight = int_weight.T - self.packed_weight = self.packed_weight.T + self.qweight = self.qweight.T origin_shape = int_weight.shape - target_shape = self.packed_weight.shape + target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(self.device) @@ -323,17 +361,19 @@ def pack(self, int_weight, scale, zp, bias, gptq_perm=None): for e in range(tmp.shape[1]): tmp[:, e] &= mask tmp[:, e] = tmp[:, e] << (self.bits * e) - self.packed_weight[:, j] |= tmp[:, e] - if self.compression_dim == 0: - self.packed_weight = self.packed_weight.T + self.qweight[:, j] |= tmp[:, e] + if not self.use_hf_format and self.compression_dim == 0: + self.qweight = self.qweight.T if zp is not None: zp = zp.to(self.device) - if self.compression_dim == 0: + if self.use_hf_format: + zp -= 1 + if self.use_hf_format or self.compression_dim == 0: zp = zp.T - self.packed_zp = self.packed_zp.T - assert hasattr(self, "packed_zp"), "zp is not set when initializing." - target_shape = self.packed_zp.shape + self.qzeros = self.qzeros.T + assert hasattr(self, "qzeros"), "zp is not set when initializing." + target_shape = self.qzeros.shape for j in range(target_shape[1]): start = self.n_pack * j end = self.n_pack * (j + 1) @@ -341,38 +381,53 @@ def pack(self, int_weight, scale, zp, bias, gptq_perm=None): for e in range(tmp.shape[1]): tmp[:, e] &= mask tmp[:, e] = tmp[:, e] << (self.bits * e) - self.packed_zp[:, j] |= tmp[:, e] - if self.compression_dim == 0: - self.packed_zp = self.packed_zp.T + self.qzeros[:, j] |= tmp[:, e] + if self.use_hf_format or self.compression_dim == 0: + self.qzeros = self.qzeros.T + if self.use_hf_format: + self.scales = self.scales.T + self.qweight = self.qweight.T + self.g_idx = self.g_idx + self.qzeros = self.qzeros.T def recover(self): logger.debug(f"Recovering {self} weight") - device = self.scale.device + if self.use_hf_format: + # Prevent broken id links of self.scales and self.scales + self.scales = self.scales.T + self.qweight = self.qweight.T + self.g_idx = self.g_idx + self.qzeros = self.qzeros.T + device = self.scales.device + fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) + if self.g_idx is None: + # used for recovering fp32_weight + self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32) mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(device) - if hasattr(self, "packed_zp"): + if hasattr(self, "qzeros"): weight_dtype = torch.uint8 else: weight_dtype = torch.int8 # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) - packed_weight = self.packed_weight - if self.compression_dim == 0: + qweight = self.qweight + if not self.use_hf_format and self.compression_dim == 0: weight = weight.T - packed_weight = packed_weight.T + qweight = qweight.T origin_shape = weight.shape - target_shape = packed_weight.shape + target_shape = qweight.shape for j in range(target_shape[1]): for e in range(self.n_pack): 
index = j * self.n_pack + e if index >= origin_shape[1]: continue - tmp = packed_weight[:, j] + tmp = qweight[:, j] tmp = tmp << (self.compress_bits - self.bits * (e + 1)) tmp = tmp >> self.compress_bits - self.bits if weight_dtype == torch.uint8: tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) - if self.compression_dim == 0: + if not self.use_hf_format and self.compression_dim == 0: weight = weight.T if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) @@ -380,64 +435,38 @@ def recover(self): new_weight += torch.where(weight == k, v, 0) weight = new_weight # unpack zero_point - if hasattr(self, "packed_zp"): + if hasattr(self, "qzeros"): zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp - zp = torch.zeros(self.scale.shape, dtype=zp_dtype).to(device) - packed_zp = self.packed_zp - if self.compression_dim == 0: + zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device) + qzeros = self.qzeros + if self.use_hf_format or self.compression_dim == 0: zp = zp.T - packed_zp = packed_zp.T + qzeros = qzeros.T origin_shape = zp.shape - target_shape = packed_zp.shape + target_shape = qzeros.shape for j in range(target_shape[1]): for e in range(self.n_pack): index = j * self.n_pack + e if index >= origin_shape[1]: continue - tmp = packed_zp[:, j] + tmp = qzeros[:, j] tmp = tmp << (self.compress_bits - self.bits * (e + 1)) tmp = tmp >> self.compress_bits - self.bits tmp &= mask zp[:, index] = tmp.type(zp_dtype) - if self.compression_dim == 0: + if self.use_hf_format or self.compression_dim == 0: zp = zp.T + if self.use_hf_format: + # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 + zp += 1 + zp = torch.where(zp > (2**self.bits - 1), 0, zp) # recover fp32 weight with int_weight, scale, and zero_point - left_element = self.in_features % self.groupsize - if left_element != 0: - split_index = self.in_features // self.groupsize * self.groupsize - weight1 = weight[:, :-split_index].reshape(-1, self.groupsize) - scale1 = self.scale[:, :-1].reshape(-1, 1) - zp1 = zp[:, :-1].reshape(-1, 1) - weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1) - weight2 = weight[:, -split_index:] - scale2 = self.scale[:, -1:] - zp2 = zp[:, -1].reshape(-1, 1) - weight2 = (weight2 - zp2) * scale2 - fp32_weight = torch.cat((weight1, weight2), dim=1) - else: - weight = weight.reshape(-1, self.groupsize) - scale = self.scale.reshape(-1, 1) - zp = zp.reshape(-1, 1) - fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1) + for idx in range(self.in_features): + fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]] else: # recover fp32 weight with int_weight, scale - left_element = self.in_features % self.groupsize - if left_element != 0: - split_index = self.in_features // self.groupsize * self.groupsize - weight1 = weight[:, :split_index].reshape(-1, self.groupsize) - scale1 = self.scale[:, :-1].reshape(-1, 1) - weight1 = (weight1 * scale1).reshape(self.out_features, -1) - weight2 = weight[:, split_index:] - scale2 = self.scale[:, -1:] - weight2 = weight2 * scale2 - fp32_weight = torch.cat((weight1, weight2), dim=1) - else: - weight = weight.reshape(-1, self.groupsize) - scale = self.scale.reshape(-1, 1) - fp32_weight = (weight * scale).reshape(self.out_features, -1) - if self.gptq_perm is not None: - invperm = torch.argsort(self.gptq_perm) - fp32_weight = fp32_weight[:, invperm] + for idx in range(self.in_features): + fp32_weight[:, idx] = 
weight[:, idx] * self.scales[:, self.g_idx[idx]]
         return fp32_weight
 
     def forward(self, input):
@@ -453,9 +482,16 @@ def forward(self, input):
         return F.linear(input, weight, self.bias)
 
     def extra_repr(self) -> str:
-        return "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
-            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
+        tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
+            self.in_features,
+            self.out_features,
+            self.bits,
+            self.groupsize,
+            self.bias is not None,
         )
+        if self.use_hf_format:
+            tmp_str += ", use_hf_format=True"
+        return tmp_str
 
 
 class FakeAffineTensorQuantFunction(Function):
diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index 7ba86eaa344..f866ac12410 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -396,6 +396,7 @@ def rtn_quantize(
     compression_dim = kwargs.get("compression_dim", 1)
     scale_dtype = kwargs.get("scale_dtype", torch.float32)
     device = kwargs.get("device", "cpu")
+    use_hf_format = kwargs.get("use_hf_format", False)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
@@ -448,6 +449,7 @@ def rtn_quantize(
                 compression_dim=compression_dim,
                 scale_dtype=scale_dtype,
                 device=device,
+                use_hf_format=use_hf_format,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
             if name == "":
diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py
index eeada402f35..fb7046a1607 100644
--- a/neural_compressor/model/torch_model.py
+++ b/neural_compressor/model/torch_model.py
@@ -459,6 +459,7 @@ def export_compressed_model(
         scale_dtype=torch.float32,
         gptq_config_path=None,
         device="cpu",
+        use_hf_format=False,
     ):
         """Convert Linear to WeightOnlyLinear for low memory inference.
@@ -474,6 +475,12 @@ def export_compressed_model(
                 Defaults to torch.float32.
             gptq_config_path (str, optional): Path of gptq_config.json. Defaults to None.
             device (str, optional): choose device for compression. Defaults to cpu.
+            use_hf_format (bool, optional): use the popular HuggingFace compression format.
+                1. compression_dim: weight = 1, zeros = 0, and both are transposed.
+                2. zeros -= 1 before compression.
+                3. g_idx: use the same number for one group instead of recording the channel order.
+                4. parameter names changed, such as 'packed_weight' -> 'qweight'.
+                5. zeros is always stored, even for sym.
""" from ..adaptor.torch_utils.model_wrapper import WeightOnlyLinear from ..adaptor.torch_utils.util import collect_weight_info, fetch_module, set_module @@ -513,6 +520,7 @@ def export_compressed_model( compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, + use_hf_format=use_hf_format, ) set_module(self.model, k, new_module) continue @@ -523,9 +531,13 @@ def export_compressed_model( else: fp32_weight = m.weight.data gptq_perm = None - gptq_scale = torch.tensor(gptq_conf["scale"]) - gptq_zp = None if scheme == "sym" else torch.tensor(gptq_conf["zero"]) + gptq_scale = torch.tensor(gptq_conf["scale"], dtype=torch.float32) + gptq_zp = None if scheme == "sym" else torch.tensor(gptq_conf["zero"], dtype=torch.int32) int_weight = quant_weight_w_scale(fp32_weight, gptq_scale, gptq_zp, group_size) + int_weight = int_weight.type(torch.int32) + if "perm" in gptq_conf: + invperm = torch.argsort(gptq_perm) + int_weight = int_weight[:, invperm] new_module = WeightOnlyLinear( m.in_features, m.out_features, @@ -534,11 +546,12 @@ def export_compressed_model( dtype=dtype, zp=gptq_zp is not None, bias=m.bias is not None, - gptq_perm=gptq_perm is not None, + g_idx=gptq_perm is not None, compression_dtype=compression_dtype, compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, + use_hf_format=use_hf_format, ) new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm) set_module(self.model, k, new_module) @@ -565,6 +578,7 @@ def export_compressed_model( compression_dim=compression_dim, scale_dtype=scale_dtype, device=device, + use_hf_format=use_hf_format, ) set_module(self.model, k, mod) return self.model diff --git a/neural_compressor/utils/load_huggingface.py b/neural_compressor/utils/load_huggingface.py index c68259a4abc..fff4c050603 100644 --- a/neural_compressor/utils/load_huggingface.py +++ b/neural_compressor/utils/load_huggingface.py @@ -230,3 +230,53 @@ def save_for_huggingface_upstream(model, tokenizer, output_dir): model.model.config.architectures = [model.model.__class__.__name__] model.model.config.torch_dtype = "int8" model.model.config.save_pretrained(output_dir) + + +def export_compressed_model( + model, + saved_dir=None, + use_hf_format=False, + enable_full_range=False, + compression_dtype=torch.int32, + compression_dim=1, + scale_dtype=torch.float32, + device="cpu", +): + """Support get compressed model from saved_dir. + + Args: + model (torch.nn.Module): origin fp32 model. + saved_dir (_type_, optional): the dir path of compression info. Defaults to None. + use_hf_format (bool, optional): whether use HuggingFace format. Defaults to False. + enable_full_range (bool, optional): Whether to leverage the full compression range + under symmetric quantization. Defaults to False. + compression_dtype (torch.Tensor, optional): The target dtype after comoression. + Defaults to torch.int32. + compression_dim (int, optional): Select from [0, 1], 0 is output channel, + 1 is input channel. Defaults to 1. + scale_dtype (torch.Tensor, optional): Use float32 or float16. + Defaults to torch.float32. + device (str, optional): choose device for compression. Defaults to cpu. 
+ """ + stat_dict = os.path.join(saved_dir, "best_model.pt") + qweight_config_path = os.path.join(saved_dir, "qconfig.json") + gptq_config_path = os.path.join(saved_dir, "gptq_config.json") + if not os.path.exists(gptq_config_path): + gptq_config_path = None + model.load_state_dict(torch.load(stat_dict)) + + from neural_compressor.model import Model as INCModel + + # pylint: disable=E1101 + inc_model = INCModel(model) + inc_model.export_compressed_model( + qweight_config_path=qweight_config_path, + enable_full_range=enable_full_range, + compression_dtype=compression_dtype, + compression_dim=compression_dim, + scale_dtype=scale_dtype, + gptq_config_path=gptq_config_path, + device=device, + use_hf_format=use_hf_format, + ) + return inc_model.model diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py index 2e6e5b85ee0..47202b86b52 100644 --- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py @@ -8,12 +8,15 @@ from neural_compressor import PostTrainingQuantConfig, quantization from neural_compressor.adaptor.torch_utils.model_wrapper import MulLinear, WeightOnlyLinear +from neural_compressor.model import Model as INCModel +from neural_compressor.utils.load_huggingface import export_compressed_model +from neural_compressor.utils.pytorch import load class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() - self.fc1 = torch.nn.Linear(30, 50) + self.fc1 = torch.nn.Linear(30, 50, bias=True) self.fc2 = torch.nn.Linear(50, 30) self.fc3 = torch.nn.Linear(30, 5) @@ -81,13 +84,29 @@ def test_RTN_int_quant(self): approach="weight_only", ) q_model = quantization.fit(model, conf) + q_model.save("saved") out2 = q_model(input) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) self.assertFalse(torch.all(out1 == out2)) compressed_model = q_model.export_compressed_model() out3 = compressed_model(input) + self.assertTrue("fc1.qweight" in compressed_model.state_dict().keys()) + self.assertTrue("fc1.qzeros" not in compressed_model.state_dict().keys()) + shape2 = compressed_model.state_dict()["fc1.scales"] self.assertTrue(torch.all(out3 == out2)) + # test huggingface popular int4 format + model = Model() + new_model = load("saved", model, weight_only=True) + inc_model = INCModel(new_model) + inc_model.export_compressed_model(qweight_config_path="saved/qconfig.json", use_hf_format=True) + out4 = inc_model.model(input) + self.assertTrue("fc1.qzeros" in inc_model.model.state_dict().keys()) + model = Model() + compressed_model = export_compressed_model(model, saved_dir="saved", use_hf_format=True) + self.assertTrue("fc1.qzeros" in inc_model.model.state_dict().keys()) + self.assertTrue(torch.all(out3 == out4)) + model = Model() out1 = model(input) conf = PostTrainingQuantConfig( @@ -218,7 +237,6 @@ def test_RTN_int_quant(self): self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) self.assertFalse(torch.all(out1 == out2)) q_model.save("saved") - from neural_compressor.utils.pytorch import load new_model = load("saved", model, weight_only=True) out1 = new_model(input) @@ -226,8 +244,6 @@ def test_RTN_int_quant(self): model_size1 = os.path.getsize("saved/best_model.pt") / 1024 print("FP32 Model size:{:.3f}M".format(model_size1)) - from neural_compressor.model import Model as INCModel - inc_model = INCModel(new_model) inc_model.export_compressed_model(qweight_config_path="saved/qconfig.json") torch.save(inc_model.state_dict(), 
"saved/tmp.pt") @@ -521,6 +537,16 @@ def __iter__(self): # # case 2: list or tuple model_2 = copy.deepcopy(self.gptj) input = torch.ones([1, 512], dtype=torch.long) + conf.op_type_dict = { + ".*": { # re.match + "weight": { + "bits": 4, # 1-8 bits + "group_size": 8, # -1 (per-channel) + "scheme": "asym", + "algorithm": "GPTQ", + }, + }, + } q_model = quantization.fit( model_2, conf, @@ -528,7 +554,7 @@ def __iter__(self): ) q_model.save("saved") out1 = q_model.model(input) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_hf_format=True) out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) @@ -543,7 +569,7 @@ def __iter__(self): ) q_model.save("saved") out1 = q_model.model(input) - compressed_model = q_model.export_compressed_model() + compressed_model = q_model.export_compressed_model(use_hf_format=True) out2 = compressed_model(input) torch.save(compressed_model.state_dict(), "saved/compressed_model.pt") self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-05)) @@ -703,7 +729,6 @@ def __iter__(self): calib_dataloader=dataloader, ) out2 = q_model.model(input) - print(out1[0] - out2[0]) self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01)) diff --git a/test/model/test_model_pytorch.py b/test/model/test_model_pytorch.py index 7b42ef63729..05edfd9c6fb 100644 --- a/test/model/test_model_pytorch.py +++ b/test/model/test_model_pytorch.py @@ -123,8 +123,8 @@ def test_WeightOnlyLinear(self): model_size2 = os.path.getsize("saved/tmp.pt") / 1024 print("WeightOnlyLinear Model size:{:.3f}M".format(model_size2)) self.assertTrue(isinstance(inc_model.model.fc1, WeightOnlyLinear)) - self.assertTrue(inc_model.model.fc1.packed_weight.dtype == dtype) - self.assertTrue(inc_model.model.fc1.scale.dtype == torch.float32) + self.assertTrue(inc_model.model.fc1.qweight.dtype == dtype) + self.assertTrue(inc_model.model.fc1.scales.dtype == torch.float32) self.assertTrue(model_size1 / model_size2 > 2) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) @@ -143,9 +143,9 @@ def test_WeightOnlyLinear(self): print("WeightOnlyLinear Model size:{:.3f}M".format(model_size2)) self.assertTrue(isinstance(inc_model.model.fc1, WeightOnlyLinear)) if dim == 1: - self.assertTrue(inc_model.model.fc1.packed_weight.shape[0] == inc_model.model.fc1.out_features) + self.assertTrue(inc_model.model.fc1.qweight.shape[0] == inc_model.model.fc1.out_features) else: - self.assertTrue(inc_model.model.fc1.packed_weight.shape[1] == inc_model.model.fc1.in_features) + self.assertTrue(inc_model.model.fc1.qweight.shape[1] == inc_model.model.fc1.in_features) self.assertTrue(model_size1 / model_size2 > 2) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) @@ -161,7 +161,7 @@ def test_WeightOnlyLinear(self): model_size2 = os.path.getsize("saved/tmp.pt") / 1024 print("WeightOnlyLinear Model size:{:.3f}M".format(model_size2)) self.assertTrue(isinstance(inc_model.model.fc1, WeightOnlyLinear)) - self.assertTrue(inc_model.model.fc1.scale.dtype == torch.float16) + self.assertTrue(inc_model.model.fc1.scales.dtype == torch.float16) self.assertTrue(model_size1 / model_size2 > 2) self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1)))