From ec3e0659b8ea858c37cf6487fc2a2a49cb7122c2 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Fri, 22 Aug 2025 09:55:55 +0000
Subject: [PATCH 01/26] Add Int4XPUTensorIntZP

---
 torchao/quantization/__init__.py                           | 1 +
 torchao/quantization/quant_api.py                          | 8 +++++++-
 torchao/quantization/quantize_/common/packing_format.py    | 3 +++
 torchao/quantization/quantize_/workflows/__init__.py       | 4 ++++
 torchao/quantization/quantize_/workflows/int4/__init__.py  | 2 ++
 5 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 3c541deb83..609dae39d8 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -93,6 +93,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    Int4XPUTensorIntZP,
     IntxUnpackedTensor,
 )
 from .smoothquant import (
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9bdd8133aa..d0b8cf1a33 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -76,6 +76,7 @@
     Int4PreshuffledTensor,
     Int4Tensor,
     IntxUnpackedTensor,
+    Int4XPUTensorIntZP,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -518,7 +519,6 @@ def quantize_(
     torch._C._log_api_usage_once("torchao.quantization.quantize_")
 
     filter_fn = _is_linear if filter_fn is None else filter_fn
-
     if isinstance(config, ModuleFqnToConfig):
         _replace_with_custom_fn_if_matches_filter_with_name(
             model,
@@ -1080,6 +1080,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.INT4_XPU_INT_ZP:
+        new_weight = Int4XPUTensorIntZP.from_hp(
+            weight,
+            block_size
+        )
+        return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")
 
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 89acf4eff3..d761f22df8 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -40,3 +40,6 @@ class PackingFormat(str, Enum):
     Unpacked means the subbyte quantized data is stored as int8
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"
+
+    "int4_xpu_int_zp is referring to the format used by int4 weight-only quantization on XPU with int zero point, which is a groupwise quantization format."
+    INT4_XPU_INT_ZP = "int4_xpu_int_zp"
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 9eeb0e7dc5..c86d10fbbe 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -11,6 +11,9 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
+from .int4.int4_xpu_tensor import (
+    Int4XPUTensorIntZP,
+)
 from .intx.intx_unpacked_tensor import (
     IntxUnpackedTensor,
 )
@@ -19,6 +22,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "Int4XPUTensorIntZP",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
     "IntxUnpackedTensor",
diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py
index 3394822214..2c290c5e35 100644
--- a/torchao/quantization/quantize_/workflows/int4/__init__.py
+++ b/torchao/quantization/quantize_/workflows/int4/__init__.py
@@ -1,7 +1,9 @@
 from .int4_preshuffled_tensor import Int4PreshuffledTensor
 from .int4_tensor import Int4Tensor
+from .int4_xpu_tensor import Int4XPUTensorIntZP
 
 __all__ = [
     "Int4PreshuffledTensor",
     "Int4Tensor",
+    "Int4XPUTensorIntZP",
 ]
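Patch 01 only wires up the new enum member and the dispatch branch in `_int4_weight_only_quantize_tensor`; the tensor subclass itself lands in the next patch. For orientation, a minimal sketch of how a user selects the new packing format, based on the `Int4WeightOnlyConfig` arguments that the test added in patch 02 exercises (the toy module and shapes here are illustrative, not taken from the series):

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    # hypothetical toy model; any module with nn.Linear children works
    model = torch.nn.Sequential(
        torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="xpu")
    )

    # version=2 selects the tensor-subclass workflow; the string must match
    # PackingFormat.INT4_XPU_INT_ZP ("int4_xpu_int_zp")
    config = Int4WeightOnlyConfig(
        group_size=128,
        packing_format="int4_xpu_int_zp",
        version=2,
    )
    quantize_(model, config)  # swaps Linear weights for Int4XPUTensorIntZP
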
From 1dc5b2ce876890059aea36b5a53fcbf370b77744 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Fri, 22 Aug 2025 10:28:50 +0000
Subject: [PATCH 02/26] Add int4_xpu_tensor

---
 .../quantize_/workflows/int4/test_int4_xpu.py |  85 ++++++++
 .../workflows/int4/int4_xpu_tensor.py         | 185 ++++++++++++++++++
 2 files changed, 270 insertions(+)
 create mode 100644 test/quantization/quantize_/workflows/int4/test_int4_xpu.py
 create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py b/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
new file mode 100644
index 0000000000..f0e7580ead
--- /dev/null
+++ b/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+
+from torchao.quantization import (
+    Int4WeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    torch_version_at_least,
+)
+
+
+def get_config(group_size):
+    return Int4WeightOnlyConfig(
+        group_size=group_size,
+        packing_format="int4_xpu_int_zp",
+        version=2,
+    )
+
+
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
+class Int4XPUTensorIntZP(TestCase):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 12),
+        ],
+    )
+    @parametrize("dtype", [torch.bfloat16, torch.half])
+    @parametrize("group_size", [32, 64, 128])
+    def test_linear(self, sizes, dtype, group_size):
+        device = "xpu"
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(group_size))
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+        compiled_linear = torch.compile(linear)
+        quantized_and_compiled = compiled_linear(input)
+        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+
+    @parametrize("dtype", [torch.bfloat16, torch.half])
+    def test_module_path(self, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
+        quantize_(linear, get_config(group_size=128))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Int4XPUTensorIntZP'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4XPUTensorIntZP'>",
+            )
+
+
+instantiate_parametrized_tests(Int4XPUTensorIntZP)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
new file mode 100644
index 0000000000..007e1fc3fd
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+import torch
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    _choose_qparams_affine,
+    _quantize_affine,
+)
+from torchao.utils import (
+    TorchAOBaseTensor,
+)
+
+__all__ = [
+    "Int4TinyGemmCpuTensor",
+]
+
+aten = torch.ops.aten
+
+
+class Int4XPUTensorIntZP(TorchAOBaseTensor):
+    """
+    int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
+
+    Tensor Attributes:
+    qdata: preshuffled and packed int4 weight for tinygemm, always viewed as a 2D (N, K/2) tensor, last dimension is packed
+          preshuffling is specific to CPU kernels, see Note below.
+    scale: (K/group_size, N), dtype is the same as the original Tensor dtype
+    zero_point: (K/group_size, N)
+
+    Non-Tensor Attributes:
+    block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size).
+                we only support group_size = 32/64/128.
+    shape: shape of the original Tensor
+
+    """
+
+    tensor_data_names = ["qdata", "scale", "zero_point"]
+    tensor_attribute_names = ["block_size", "shape"]
+
+    def __new__(
+        cls,
+        qdata,
+        scale,
+        zero_point,
+        block_size,
+        shape,
+    ):
+        kwargs = {}
+        kwargs["device"] = qdata.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, qdata, scale, zero_point, block_size, shape):
+        self.qdata = qdata
+        self.scale = scale
+        self.zero_point = zero_point
+        self.block_size = block_size
+        #self.shape = shape
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
+
+    @classmethod
+    def from_hp(
+        cls,
+        w: torch.Tensor,
+        block_size: List[int],
+    ):
+        assert w.ndim == 2 and w.device.type == "xpu", (
+            f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
+        )
+        assert len(block_size) == w.ndim
+
+        original_shape = w.shape
+        mapping_type = MappingType.ASYMMETRIC
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        scale_dtype = None
+        zero_point_dtype = torch.int32
+        scale, zero_point = _choose_qparams_affine(
+            w,
+            mapping_type.name,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+        )
+        int_data = _quantize_affine(
+            w,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min,
+            quant_max,
+        )
+        assert int_data.dtype == torch.int32, (
+            "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+        )
+        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
+            torch.uint8
+        )
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+            packed_weight.contiguous(), 8
+        )
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        return Int4XPUTensorIntZP(
+            packed_weight,
+            scale.transpose(0, 1).contiguous(),
+            zero_point.transpose(0, 1).contiguous().to(torch.int8),
+            block_size,
+            original_shape,
+        )
+
+
+implements = Int4XPUTensorIntZP.implements
+
+
+@implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    assert input_tensor.device.type == "xpu", (
+        f"For XPU device only but got: {input_tensor.device}"
+    )
+    assert isinstance(weight_tensor, Int4XPUTensorIntZP), (
+        f"Expected weight_tensor to be Int4XPUTensorIntZP, got: {type(weight_tensor)}"
+    )
+    assert weight_tensor.block_size[0] == 1, (
+        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
+    )
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
+    )
+
+    act_mat = input_tensor
+    packed_weight = weight_tensor.qdata
+    scale = weight_tensor.scale
+    zero_point = weight_tensor.zero_point
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+
+    # reshape to 2D
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])
+
+    # groupwise int4 quantization
+    groupsize = weight_tensor.block_size[1]
+    y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+        act_mat, packed_weight, groupsize, scale, zero_point
+    )
+
+    # remove out_feature padding
+    assert weight_tensor.ndim == 2
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    if bias is not None:
+        y += bias
+    return y.to(orig_dtype)
+
+
+Int4XPUTensorIntZP.__module__ = "torchao.quantization"
+
+# Allow a model with Int4XPUTensorIntZP weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Int4XPUTensorIntZP])
\ No newline at end of file

From e63b100263d297b7318e13b17bb8aae2191953ee Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Mon, 25 Aug 2025 15:37:25 +0800
Subject: [PATCH 03/26] Update int4_xpu_tensor.py

---
 .../quantization/quantize_/workflows/int4/int4_xpu_tensor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
index 007e1fc3fd..471f9ab764 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
@@ -19,7 +19,7 @@
 )
 
 __all__ = [
-    "Int4TinyGemmCpuTensor",
+    "Int4XPUTensorIntZP",
 ]
 
 aten = torch.ops.aten
@@ -182,4 +182,4 @@ def _(func, types, args, kwargs):
 Int4XPUTensorIntZP.__module__ = "torchao.quantization"
 
 # Allow a model with Int4XPUTensorIntZP weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([Int4XPUTensorIntZP])
\ No newline at end of file
+torch.serialization.add_safe_globals([Int4XPUTensorIntZP])
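The heart of `from_hp` above is two steps: per-group asymmetric quantization of the bf16/fp16 weight to unsigned int4 in [0, 15], then packing two int4 values per byte via `int_data[::, 1::2] << 4 | int_data[::, ::2]` before handing the bytes to `_convert_weight_to_int4pack`. The byte layout is easy to verify in isolation; a self-contained sketch with hand-picked values (independent of the XPU kernel):

    import torch

    # two int4 values per byte: element 2*i goes to the low nibble,
    # element 2*i+1 to the high nibble, matching
    #   packed = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
    int_data = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)  # values in [0, 15]
    packed = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
    assert packed.tolist() == [[0x21, 0x43]]  # 2<<4 | 1 and 4<<4 | 3

    # unpacking reverses it
    low = packed & 0xF
    high = (packed >> 4) & 0xF
    assert torch.equal(
        torch.stack([low, high], dim=-1).reshape(int_data.shape).to(torch.int32),
        int_data,
    )

Dequantization is the usual affine map, `(q - zero_point) * scale`, applied per group of `group_size` consecutive weights along K.
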
From a28dd8948a003e8b3af9f5cc3ec732efd9d3df25 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 25 Aug 2025 15:54:36 +0000
Subject: [PATCH 04/26] Fix typo

---
 .../quantization/quantize_/workflows/int4/int4_xpu_tensor.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
index 471f9ab764..39975f37bd 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
@@ -30,7 +30,7 @@ class Int4XPUTensorIntZP(TorchAOBaseTensor):
     int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
 
     Tensor Attributes:
-    qdata: preshuffled and packed int4 weight for tinygemm, always viewed as a 2D (N, K/2) tensor, last dimension is packed
+    qdata: packed int4 weight, always viewed as a 2D (N, K/2) tensor, last dimension is packed
           preshuffling is specific to CPU kernels, see Note below.
     scale: (K/group_size, N), dtype is the same as the original Tensor dtype
     zero_point: (K/group_size, N)
@@ -64,7 +64,6 @@ def __init__(self, qdata, scale, zero_point, block_size, shape):
         self.qdata = qdata
         self.scale = scale
         self.zero_point = zero_point
         self.block_size = block_size
-        #self.shape = shape
 
     def _quantization_type(self):
         return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

From a0ff36fe84cd4422c300d8093ce491ebbcfa06e3 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 25 Aug 2025 15:59:46 +0000
Subject: [PATCH 05/26] Fix code format issue

---
 torchao/quantization/quant_api.py                        | 5 +----
 torchao/quantization/quantize_/common/packing_format.py | 2 +-
 .../quantize_/workflows/int4/int4_xpu_tensor.py          | 6 ++----
 3 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index d0b8cf1a33..a78d810b91 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1081,10 +1081,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif packing_format == PackingFormat.INT4_XPU_INT_ZP:
-        new_weight = Int4XPUTensorIntZP.from_hp(
-            weight,
-            block_size
-        )
+        new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
         return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index d761f22df8..005639570f 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -40,6 +40,6 @@ class PackingFormat(str, Enum):
     Unpacked means the subbyte quantized data is stored as int8
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"
-
+
     "int4_xpu_int_zp is referring to the format used by int4 weight-only quantization on XPU with int zero point, which is a groupwise quantization format."
     INT4_XPU_INT_ZP = "int4_xpu_int_zp"
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
index 39975f37bd..268ff6fd43 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
@@ -78,7 +78,7 @@ def from_hp(
             f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
         )
         assert len(block_size) == w.ndim
-
+
         original_shape = w.shape
         mapping_type = MappingType.ASYMMETRIC
         target_dtype = torch.int32
@@ -110,9 +110,7 @@ def from_hp(
         assert int_data.dtype == torch.int32, (
             "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
         )
-        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
-            torch.uint8
-        )
+        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(
             packed_weight.contiguous(), 8
         )

From 8a0f124b3021b7cd858cdf452ae1d7e374c668d1 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 25 Aug 2025 16:06:20 +0000
Subject: [PATCH 06/26] fix bug

---
 torchao/quantization/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 609dae39d8..23725abb48 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -163,6 +163,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "Int4XPUTensorIntZP",
     "IntxUnpackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change

From 5ef1ca214a166fc44c1809ac09e509b258867816 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 25 Aug 2025 16:23:17 +0000
Subject: [PATCH 07/26] Fix code format

---
 torchao/quantization/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 23725abb48..28d569fc4a 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -163,7 +163,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
-    "Int4XPUTensorIntZP", 
+    "Int4XPUTensorIntZP",
     "IntxUnpackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change
From 2c4c2cec974d5744118cc9d4c9975161e72bdd63 Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Tue, 26 Aug 2025 10:37:58 +0800
Subject: [PATCH 08/26] Update int4_xpu_tensor.py

---
 .../quantization/quantize_/workflows/int4/int4_xpu_tensor.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
index 268ff6fd43..ebeaf70661 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
@@ -30,8 +30,7 @@ class Int4XPUTensorIntZP(TorchAOBaseTensor):
     int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
 
     Tensor Attributes:
-    qdata: packed int4 weight, always viewed as a 2D (N, K/2) tensor, last dimension is packed
-          preshuffling is specific to CPU kernels, see Note below.
+    qdata: packed int4 weight, always viewed as a 2D (N, K/2) tensor
     scale: (K/group_size, N), dtype is the same as the original Tensor dtype
     zero_point: (K/group_size, N)
@@ -108,7 +107,7 @@ def from_hp(
             quant_max,
         )
         assert int_data.dtype == torch.int32, (
-            "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+            "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
         )
         packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(

From e48ea0ba271c8a8343cff6f8d6b774db597d6886 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Tue, 26 Aug 2025 13:36:49 +0000
Subject: [PATCH 09/26] change the pack format to plain

---
 .../quantize_/workflows/int4/test_int4_xpu.py |  6 +++++-
 torchao/quantization/quant_api.py             | 17 +++++++++--------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py b/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
index f0e7580ead..149e2fb7cb 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
@@ -19,6 +19,9 @@
     Int4WeightOnlyConfig,
     quantize_,
 )
+from torchao.quantization.quant_primitives import (
+    ZeroPointDomain,
+)
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     torch_version_at_least,
 )
@@ -28,7 +31,8 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
-        packing_format="int4_xpu_int_zp",
+        packing_format="plain",
+        zero_point_domain=ZeroPointDomain.INT,
         version=2,
     )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a78d810b91..80197f178b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -75,8 +75,8 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
-    IntxUnpackedTensor,
     Int4XPUTensorIntZP,
+    IntxUnpackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -1069,10 +1069,14 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif packing_format == PackingFormat.PLAIN:
-        new_weight = Int4Tensor.from_hp(
-            weight,
-            block_size,
-        )
+        if "xpu" in weight.device and zero_point_domain == ZeroPointDomain.INT:
+            new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
+            return new_weight
+        else:
+            new_weight = Int4Tensor.from_hp(
+                weight,
+                block_size,
+            )
         return new_weight
     elif packing_format == PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.INT4_XPU_INT_ZP:
-        new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
-        return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")

From c4e5b9db85a552b6f56a148b95fcee0681effc04 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Tue, 26 Aug 2025 13:45:32 +0000
Subject: [PATCH 10/26] fix typo

---
 torchao/quantization/quant_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 80197f178b..8e2d0efa84 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1069,7 +1069,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif packing_format == PackingFormat.PLAIN:
-        if "xpu" in weight.device and zero_point_domain == ZeroPointDomain.INT:
+        if "xpu" in weight.device.type and zero_point_domain == ZeroPointDomain.INT:
             new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
             return new_weight
         else:

From 7063e56f44950a7597c9330fcf14279a7f1094fe Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Tue, 26 Aug 2025 13:47:57 +0800
Subject: [PATCH 11/26] Update quant_api.py

---
 torchao/quantization/quant_api.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 8e2d0efa84..713222517c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1071,7 +1071,6 @@ def _int4_weight_only_quantize_tensor(weight, config):
     elif packing_format == PackingFormat.PLAIN:
         if "xpu" in weight.device.type and zero_point_domain == ZeroPointDomain.INT:
             new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
-            return new_weight
         else:
             new_weight = Int4Tensor.from_hp(
                 weight,

From 5b87d8bb254daa84911e48da9c93bd5d26906712 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Thu, 28 Aug 2025 07:31:19 +0000
Subject: [PATCH 12/26] merge main branch

---
 .../quantize_/workflows/int4/test_int4_xpu.py |  2 +-
 torchao/quantization/quant_api.py             | 15 +++++++++++----
 .../quantize_/common/packing_format.py        |  9 ++++++---
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py b/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
index 149e2fb7cb..06ee086069 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
@@ -31,7 +31,7 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
-        packing_format="plain",
+        packing_format="plain_int32",
         zero_point_domain=ZeroPointDomain.INT,
         version=2,
     )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 713222517c..05aa10763c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1069,10 +1069,17 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif packing_format == PackingFormat.PLAIN:
-        if "xpu" in weight.device.type and zero_point_domain == ZeroPointDomain.INT:
-            new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
-        else:
-            new_weight = Int4Tensor.from_hp(
+        new_weight = Int4Tensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
+    elif packing_format == PackingFormat.PLAIN_INT32:
+        if (
+            "xpu" in weight.device.dtype
+            and zero_point_domain == ZeroPointDomain.INT
+        ):
+            new_weight = Int4XPUTensorIntZP.from_hp(
                 weight,
                 block_size,
             )
         return new_weight
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 005639570f..ed061b626a 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -41,9 +41,12 @@ class PackingFormat(str, Enum):
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"
 
-    "int4_xpu_int_zp is referring to the format used by int4 weight-only quantization on XPU with int zero point, which is a groupwise quantization format."
-    INT4_XPU_INT_ZP = "int4_xpu_int_zp"
-
+    """
+    plain_int32 refers to the format used by int4 weight-only quantization,
+    which is a groupwise quantization format: 2*int4 is stored in a byte and 4*(int4*2) is stored in an int32.
+    """
+    PLAIN_INT32 = "plain_int32"
+
     """
     Opaque packing format that's used for tensors that does not have a predefined packing format
     (that may be decided on hardware, tensor shape, library availability etc.) and it's not
From a047c0078c65fa2b4721e16e30c01a02d5f54c37 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Fri, 29 Aug 2025 09:45:21 +0000
Subject: [PATCH 13/26] change Int4XPUTensorIntZP to Int4PlainInt32

---
 ..._xpu.py => test_int4_plain_int32_tensor.py} |  8 ++++----
 torchao/quantization/__init__.py               | 10 ++--------
 torchao/quantization/quant_api.py              | 14 +++++---------
 .../quantize_/workflows/__init__.py            |  9 +++------
 .../quantize_/workflows/int4/__init__.py       |  4 ++--
 ...pu_tensor.py => int4_plain_int32_tensor.py} | 18 +++++++++---------
 6 files changed, 25 insertions(+), 38 deletions(-)
 rename test/quantization/quantize_/workflows/int4/{test_int4_xpu.py => test_int4_plain_int32_tensor.py} (91%)
 rename torchao/quantization/quantize_/workflows/int4/{int4_xpu_tensor.py => int4_plain_int32_tensor.py} (90%)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
similarity index 91%
rename from test/quantization/quantize_/workflows/int4/test_int4_xpu.py
rename to test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index 06ee086069..22870244da 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_xpu.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -38,7 +38,7 @@ def get_config(group_size):
 
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-class Int4XPUTensorIntZP(TestCase):
+class Int4PlainInt32(TestCase):
     @parametrize(
         "sizes",
         [
@@ -69,7 +69,7 @@ def test_module_path(self, dtype):
         quantize_(linear, get_config(group_size=128))
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4XPUTensorIntZP'>",
+            "<class 'torchao.quantization.Int4PlainInt32'>",
         )
 
         with tempfile.NamedTemporaryFile() as f:
@@ -78,11 +78,11 @@ def test_module_path(self, dtype):
             state_dict = torch.load(f)
             self.assertEqual(
                 str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4XPUTensorIntZP'>",
+                "<class 'torchao.quantization.Int4PlainInt32'>",
             )
 
 
-instantiate_parametrized_tests(Int4XPUTensorIntZP)
+instantiate_parametrized_tests(Int4PlainInt32)
 
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index e369d14b70..c8bf61ab34 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -91,12 +91,9 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
-    Int4OpaqueTensor,
+    Int4PlainInt32,
     Int4PreshuffledTensor,
     Int4Tensor,
-    Int4XPUTensorIntZP,
-    IntxOpaqueTensor,
-    IntxUnpackedToInt8Tensor,
 )
 from .smoothquant import (
@@ -165,11 +162,8 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
-    "Int4XPUTensorIntZP",
-    "IntxOpaqueTensor",
-    "IntxUnpackedToInt8Tensor",
+    "Int4PlainInt32",
     "Float8Tensor",
-    "Int4OpaqueTensor",
     # smooth quant - subject to change
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 05aa10763c..6458f7c4b7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -74,9 +74,9 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PlainInt32,
     Int4PreshuffledTensor,
     Int4Tensor,
-    Int4XPUTensorIntZP,
     IntxOpaqueTensor,
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -1131,14 +1131,10 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif packing_format == PackingFormat.PLAIN_INT32:
-        if (
-            "xpu" in weight.device.dtype
-            and zero_point_domain == ZeroPointDomain.INT
-        ):
-            new_weight = Int4XPUTensorIntZP.from_hp(
-                weight,
-                block_size,
-            )
+        new_weight = Int4PlainInt32.from_hp(
+            weight,
+            block_size,
+        )
         return new_weight
     elif packing_format == PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index e3b4c2307e..bbcef7acda 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -5,8 +5,8 @@
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
 )
-from .int4.int4_opaque_tensor import (
-    Int4OpaqueTensor,
+from .int4.int4_plain_int32_tensor import (
+    Int4PlainInt32,
 )
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
@@ -14,9 +14,6 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
-from .int4.int4_xpu_tensor import (
-    Int4XPUTensorIntZP,
-)
 from .intx.intx_opaque_tensor import (
     IntxOpaqueTensor,
 )
@@ -28,7 +25,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
-    "Int4XPUTensorIntZP",
+    "Int4PlainInt32",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
     "IntxOpaqueTensor",
diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py
index 2c290c5e35..37ff03a915 100644
--- a/torchao/quantization/quantize_/workflows/int4/__init__.py
+++ b/torchao/quantization/quantize_/workflows/int4/__init__.py
@@ -1,9 +1,9 @@
+from .int4_plain_int32_tensor import Int4PlainInt32
 from .int4_preshuffled_tensor import Int4PreshuffledTensor
 from .int4_tensor import Int4Tensor
-from .int4_xpu_tensor import Int4XPUTensorIntZP
 
 __all__ = [
     "Int4PreshuffledTensor",
     "Int4Tensor",
-    "Int4XPUTensorIntZP",
+    "Int4PlainInt32",
 ]
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
similarity index 90%
rename from torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
rename to torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
index ebeaf70661..f94445bace 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -19,13 +19,13 @@
 )
 
 __all__ = [
-    "Int4XPUTensorIntZP",
+    "Int4PlainInt32",
 ]
 
 aten = torch.ops.aten
 
 
-class Int4XPUTensorIntZP(TorchAOBaseTensor):
+class Int4PlainInt32(TorchAOBaseTensor):
     """
     int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
@@ -115,7 +115,7 @@ def from_hp(
         )
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
-        return Int4XPUTensorIntZP(
+        return Int4PlainInt32(
             packed_weight,
             scale.transpose(0, 1).contiguous(),
             zero_point.transpose(0, 1).contiguous().to(torch.int8),
@@ -124,7 +124,7 @@ def from_hp(
         )
 
 
-implements = Int4XPUTensorIntZP.implements
+implements = Int4PlainInt32.implements
 
 
 @implements([torch.nn.functional.linear, aten.linear.default])
@@ -137,8 +137,8 @@ def _(func, types, args, kwargs):
     assert input_tensor.device.type == "xpu", (
         f"For XPU device only but got: {input_tensor.device}"
     )
-    assert isinstance(weight_tensor, Int4XPUTensorIntZP), (
-        f"Expected weight_tensor to be Int4XPUTensorIntZP, got: {type(weight_tensor)}"
+    assert isinstance(weight_tensor, Int4PlainInt32), (
+        f"Expected weight_tensor to be Int4PlainInt32, got: {type(weight_tensor)}"
     )
     assert weight_tensor.block_size[0] == 1, (
         f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
@@ -175,7 +175,7 @@ def _(func, types, args, kwargs):
     return y.to(orig_dtype)
 
 
-Int4XPUTensorIntZP.__module__ = "torchao.quantization"
+Int4PlainInt32.__module__ = "torchao.quantization"
 
-# Allow a model with Int4XPUTensorIntZP weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([Int4XPUTensorIntZP])
+# Allow a model with Int4PlainInt32 weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Int4PlainInt32])

From 3f70b2b17f2a534aab0e4852a62c7f16963c7a61 Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Fri, 29 Aug 2025 09:52:34 +0800
Subject: [PATCH 14/26] Update __init__.py

---
 torchao/quantization/__init__.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index c8bf61ab34..7eec555461 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -91,9 +91,12 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
-    Int4PlainInt32,
+    Int4OpaqueTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    Int4PlainInt32,
+    IntxOpaqueTensor,
+    IntxUnpackedToInt8Tensor,
 )
 from .smoothquant import (

From 8d2acd2fc0145efabdaff923d2ff71bf598f3afb Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Fri, 29 Aug 2025 09:54:24 +0800
Subject: [PATCH 15/26] Update __init__.py

---
 torchao/quantization/__init__.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 7eec555461..1cf1ee43a5 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -163,10 +163,13 @@
     "FbgemmConfig",
     # tensor subclasses
     "Int4Tensor",
+    "Int4PlainInt32",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
-    "Int4PlainInt32",
+    "IntxOpaqueTensor", 
+    "IntxUnpackedToInt8Tensor",
     "Float8Tensor",
+    "Int4OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

From 43acd6694fa69427a677d48ba2526c16ba06d8d0 Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Fri, 29 Aug 2025 09:55:21 +0800
Subject: [PATCH 16/26] Update __init__.py

---
 torchao/quantization/quantize_/workflows/__init__.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index bbcef7acda..a9158296ab 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -5,6 +5,9 @@
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
 )
+from .int4.int4_opaque_tensor import (
+    Int4OpaqueTensor,
+)
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32,
 )
From 402dd721b7525352f8374eb0313ab5899e5db742 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Fri, 29 Aug 2025 14:03:53 +0000
Subject: [PATCH 17/26] Refine code

---
 .../int4/test_int4_plain_int32_tensor.py       | 12 ++++--------
 torchao/quantization/__init__.py               |  6 +++---
 torchao/quantization/quant_api.py              |  4 ++--
 .../workflows/int4/int4_plain_int32_tensor.py  | 15 +++++++--------
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index 22870244da..2d971be2d8 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -19,9 +19,6 @@
     Int4WeightOnlyConfig,
     quantize_,
 )
-from torchao.quantization.quant_primitives import (
-    ZeroPointDomain,
-)
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     torch_version_at_least,
 )
@@ -32,13 +29,12 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         packing_format="plain_int32",
-        zero_point_domain=ZeroPointDomain.INT,
         version=2,
     )
 
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-class Int4PlainInt32(TestCase):
+class Int4PlainInt32Tensor(TestCase):
     @parametrize(
         "sizes",
         [
@@ -69,7 +65,7 @@ def test_module_path(self, dtype):
         quantize_(linear, get_config(group_size=128))
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4PlainInt32'>",
+            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
 
         with tempfile.NamedTemporaryFile() as f:
@@ -78,11 +74,11 @@ def test_module_path(self, dtype):
             state_dict = torch.load(f)
             self.assertEqual(
                 str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4PlainInt32'>",
+                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
             )
 
 
-instantiate_parametrized_tests(Int4PlainInt32)
+instantiate_parametrized_tests(Int4PlainInt32Tensor)
 
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 1cf1ee43a5..5d99fc02f6 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -92,9 +92,9 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
-    Int4PlainInt32,
     IntxOpaqueTensor,
     IntxUnpackedToInt8Tensor,
 )
@@ -163,10 +163,10 @@
     "FbgemmConfig",
     # tensor subclasses
     "Int4Tensor",
-    "Int4PlainInt32",
+    "Int4PlainInt32Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
-    "IntxOpaqueTensor", 
+    "IntxOpaqueTensor",
     "IntxUnpackedToInt8Tensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 6458f7c4b7..8d9c4658cc 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -74,7 +74,7 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
-    Int4PlainInt32,
+    Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     IntxOpaqueTensor,
@@ -1131,7 +1131,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif packing_format == PackingFormat.PLAIN_INT32:
-        new_weight = Int4PlainInt32.from_hp(
+        new_weight = Int4PlainInt32Tensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
index f94445bace..73493496d4 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -11,8 +11,8 @@
 
 from torchao.quantization.quant_primitives import (
     MappingType,
-    _choose_qparams_affine,
-    _quantize_affine,
+    choose_qparams_affine,
+    quantize_affine,
 )
 from torchao.utils import (
     TorchAOBaseTensor,
 )
@@ -30,13 +30,12 @@ class Int4PlainInt32(TorchAOBaseTensor):
     int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
 
     Tensor Attributes:
-    qdata: packed int4 weight, always viewed as a 2D (N, K/2) tensor
+    qdata: (N, K/8), packed int4 weight, the data type is int32 here with 4*(int4*2)
     scale: (K/group_size, N), dtype is the same as the original Tensor dtype
-    zero_point: (K/group_size, N)
+    zero_point: (K/group_size, N), dtype is int8
 
     Non-Tensor Attributes:
-    block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size).
-                we only support group_size = 32/64/128.
+    block_size: the block size for quantization, representing the granularity.
     shape: shape of the original Tensor
 
@@ -86,7 +85,7 @@ def from_hp(
         eps = 1e-6
         scale_dtype = None
         zero_point_dtype = torch.int32
-        scale, zero_point = _choose_qparams_affine(
+        scale, zero_point = choose_qparams_affine(
             w,
             mapping_type.name,
             block_size,
@@ -97,7 +96,7 @@ def from_hp(
             scale_dtype,
             zero_point_dtype,
         )
-        int_data = _quantize_affine(
+        int_data = quantize_affine(
             w,
             block_size,
             scale,

From 282f1a8c017973e31adcbe67559c2773265b10bd Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Fri, 29 Aug 2025 14:07:53 +0000
Subject: [PATCH 18/26] Refine code

---
 .../quantize_/workflows/__init__.py            |  4 ++--
 .../quantize_/workflows/int4/__init__.py       |  4 ++--
 .../workflows/int4/int4_plain_int32_tensor.py  | 20 ++++++++++----------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index a9158296ab..92a2307fb9 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -9,7 +9,7 @@
 from .int4.int4_plain_int32_tensor import (
-    Int4PlainInt32,
+    Int4PlainInt32Tensor,
 )
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -28,7 +28,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
-    "Int4PlainInt32",
+    "Int4PlainInt32Tensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
     "IntxOpaqueTensor",
diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py
index 37ff03a915..abb3e79554 100644
--- a/torchao/quantization/quantize_/workflows/int4/__init__.py
+++ b/torchao/quantization/quantize_/workflows/int4/__init__.py
@@ -1,9 +1,9 @@
-from .int4_plain_int32_tensor import Int4PlainInt32
+from .int4_plain_int32_tensor import Int4PlainInt32Tensor
 from .int4_preshuffled_tensor import Int4PreshuffledTensor
 from .int4_tensor import Int4Tensor
 
 __all__ = [
     "Int4PreshuffledTensor",
     "Int4Tensor",
-    "Int4PlainInt32",
+    "Int4PlainInt32Tensor",
 ]
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
index 73493496d4..5e84be446f 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -19,13 +19,13 @@
 )
 
 __all__ = [
-    "Int4PlainInt32",
+    "Int4PlainInt32Tensor",
 ]
 
 aten = torch.ops.aten
 
 
-class Int4PlainInt32(TorchAOBaseTensor):
+class Int4PlainInt32Tensor(TorchAOBaseTensor):
     """
     int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
@@ -87,7 +87,7 @@ def from_hp(
         zero_point_dtype = torch.int32
         scale, zero_point = choose_qparams_affine(
             w,
-            mapping_type.name,
+            mapping_type,
             block_size,
             target_dtype,
             quant_min,
@@ -114,7 +114,7 @@ def from_hp(
         )
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
-        return Int4PlainInt32(
+        return Int4PlainInt32Tensor(
             packed_weight,
             scale.transpose(0, 1).contiguous(),
             zero_point.transpose(0, 1).contiguous().to(torch.int8),
@@ -123,7 +123,7 @@ def from_hp(
         )
 
 
-implements = Int4PlainInt32.implements
+implements = Int4PlainInt32Tensor.implements
 
 
 @implements([torch.nn.functional.linear, aten.linear.default])
@@ -136,8 +136,8 @@ def _(func, types, args, kwargs):
     assert input_tensor.device.type == "xpu", (
         f"For XPU device only but got: {input_tensor.device}"
     )
-    assert isinstance(weight_tensor, Int4PlainInt32), (
-        f"Expected weight_tensor to be Int4PlainInt32, got: {type(weight_tensor)}"
+    assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
+        f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}"
     )
     assert weight_tensor.block_size[0] == 1, (
         f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
@@ -174,7 +174,7 @@ def _(func, types, args, kwargs):
     return y.to(orig_dtype)
 
 
-Int4PlainInt32.__module__ = "torchao.quantization"
+Int4PlainInt32Tensor.__module__ = "torchao.quantization"
 
-# Allow a model with Int4PlainInt32 weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([Int4PlainInt32])
+# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
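Patches 17 and 18 also move the tensor onto the public `choose_qparams_affine`/`quantize_affine` helpers and pass the `MappingType` enum directly instead of its `.name`. For a single group, the asymmetric mapping those helpers implement reduces to a few lines; a simplified reference version, assuming the standard torchao formulas and omitting the dtype plumbing:

    import torch

    def qparams_for_group(w_group: torch.Tensor, quant_min=0, quant_max=15, eps=1e-6):
        # asymmetric affine mapping; zero is kept representable by widening
        # the range to include it
        min_val = torch.minimum(w_group.min(), torch.zeros(()))
        max_val = torch.maximum(w_group.max(), torch.zeros(()))
        scale = torch.clamp((max_val - min_val) / (quant_max - quant_min), min=eps)
        zero_point = quant_min - torch.round(min_val / scale)
        zero_point = torch.clamp(zero_point, quant_min, quant_max).to(torch.int32)
        return scale, zero_point

    g = torch.randn(128)                      # one quantization group
    scale, zp = qparams_for_group(g)
    q = torch.clamp(torch.round(g / scale) + zp, 0, 15).to(torch.int32)
    deq = (q - zp) * scale                    # dequantized approximation of g
    assert (deq - g).abs().max() <= scale     # rounding error is bounded by one step
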
From b68beefda9b695a3d4c33b802018ebf6c08a11ff Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 1 Sep 2025 09:13:03 +0000
Subject: [PATCH 19/26] Add more comments about the original weight dtype

---
 .../quantize_/workflows/int4/int4_plain_int32_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
index 5e84be446f..598faee932 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -30,7 +30,7 @@ class Int4PlainInt32Tensor(TorchAOBaseTensor):
     int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
 
     Tensor Attributes:
-    qdata: (N, K/8), packed int4 weight, the data type is int32 here with 4*(int4*2)
+    qdata: (N, K/8), packed int4 weight, the data type is int32 here with 4*(int4*2), the original data type can be half or bfloat16
     scale: (K/group_size, N), dtype is the same as the original Tensor dtype
     zero_point: (K/group_size, N), dtype is int8

From cd781fcc827ef516969c2bb95738dc04b625ac1e Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Mon, 1 Sep 2025 09:14:20 +0800
Subject: [PATCH 20/26] Update __init__.py

From afadf6956abca8c9511ca511ac2aa82bf6e7cbfc Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Mon, 1 Sep 2025 09:15:21 +0800
Subject: [PATCH 21/26] Update __init__.py

From 105b4b9aca81ca597a4eee139ac05952d4e89dc2 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 1 Sep 2025 09:20:54 +0000
Subject: [PATCH 22/26] fix code format issue

---
 torchao/quantization/quantize_/common/packing_format.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 480b6a7050..94d45917b9 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -42,12 +42,12 @@ class PackingFormat(str, Enum):
     UNPACKED_TO_INT8 = "unpacked_to_int8"
 
     """
     plain_int32 refers to the format used by int4 weight-only quantization,
     which is a groupwise quantization format: 2*int4 is stored in a byte and 4*(int4*2) is stored in an int32.
     """
     PLAIN_INT32 = "plain_int32"
-
+
     """
     tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
     """
     TILE_PACKED_TO_4D = "tile_packed_to_4d"

From b24ff1ae815709ca895dfc35a9d6edfe05604ba6 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 1 Sep 2025 09:22:35 +0000
Subject: [PATCH 23/26] fix code format issue

---
 torchao/quantization/quantize_/workflows/int4/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py
index 8b13789179..e69de29bb2 100644
--- a/torchao/quantization/quantize_/workflows/int4/__init__.py
+++ b/torchao/quantization/quantize_/workflows/int4/__init__.py
@@ -1 +0,0 @@
-

From 77868bc39420b9e467843c28350fb86e50dbdf36 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Mon, 1 Sep 2025 09:56:09 +0000
Subject: [PATCH 24/26] skip ut if no xpu

---
 .../quantize_/workflows/int4/test_int4_plain_int32_tensor.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index 2d971be2d8..6c5b843333 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -34,6 +34,7 @@ def get_config(group_size):
 
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
+@unittest.skipIf(not torch.xpu.is_available(), "CUDA not available")
 class Int4PlainInt32Tensor(TestCase):
     @parametrize(
         "sizes",

From 970aa17de2fea24c743036b16761d0df90e6b929 Mon Sep 17 00:00:00 2001
From: "Zhang, Liangang"
Date: Mon, 1 Sep 2025 09:55:24 +0800
Subject: [PATCH 25/26] Update test_int4_plain_int32_tensor.py

---
 .../quantize_/workflows/int4/test_int4_plain_int32_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
index 6c5b843333..d7d793685e 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -34,7 +34,7 @@ def get_config(group_size):
 
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "CUDA not available")
+@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
 class Int4PlainInt32Tensor(TestCase):
     @parametrize(
         "sizes",
From 78f6bb267c4a57e0ecb93f52f523c01b4a658942 Mon Sep 17 00:00:00 2001
From: "Liangang,Zhang"
Date: Thu, 4 Sep 2025 08:16:18 +0000
Subject: [PATCH 26/26] Add assert for the original weight data type

---
 .../quantize_/workflows/int4/int4_plain_int32_tensor.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
index 598faee932..388134f040 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -76,7 +76,9 @@ def from_hp(
             f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
         )
         assert len(block_size) == w.ndim
-
+        assert w.dtype in [torch.float16, torch.bfloat16], (
+            f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
+        )
         original_shape = w.shape
         mapping_type = MappingType.ASYMMETRIC
         target_dtype = torch.int32
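Taking the series as a whole, the final user-facing surface is small: an `Int4WeightOnlyConfig` with `packing_format="plain_int32"` applied to an XPU module, producing weights of type `Int4PlainInt32Tensor`. An end-to-end sketch mirroring the final test (the shapes and the > 20 dB bar come from the test; `compute_error` returns an SQNR):

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_
    from torchao.quantization.utils import compute_error

    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="xpu")
    x = torch.randn(32, 128, dtype=torch.bfloat16, device="xpu")
    ref = linear(x)

    quantize_(
        linear,
        Int4WeightOnlyConfig(group_size=128, packing_format="plain_int32", version=2),
    )
    out = linear(x)
    assert compute_error(ref, out) > 20        # same SQNR threshold as the test

    # torch.compile runs through the same aten.linear override
    compiled = torch.compile(linear)
    assert compute_error(ref, compiled(x)) > 20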