From e092ec1ef4e689e1d033b0a0afda6a3b4431abd3 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 18 Aug 2025 20:13:04 -0700 Subject: [PATCH 1/7] [CPU] Introduce Int4WoqCpuTensor to replace Int4CPULayout in AQT --- .../int4/test_int4_woq_cpu_tensor.py | 72 +++++++ torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 7 + .../quantize_/common/packing_format.py | 5 + .../quantize_/workflows/__init__.py | 4 + .../workflows/int4/int4_woq_cpu_tensor.py | 194 ++++++++++++++++++ 6 files changed, 284 insertions(+) create mode 100644 test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py diff --git a/test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py new file mode 100644 index 0000000000..d013a473b7 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + torch_version_at_least, +) + + +def get_config(group_size): + return Int4WeightOnlyConfig( + group_size=group_size, + packing_format="int4_woq_cpu", + version=2, + ) + + +@unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+") +class TestInt4WoqCpuTensor(TestCase): + @parametrize("group_size", [32, 64, 128]) + def test_linear(self, group_size): + dtype = torch.bfloat16 + device = "cpu" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, get_config(group_size)) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + @parametrize("group_size", [32, 64, 128]) + def test_module_path(self, group_size): + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, get_config(group_size)) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +instantiate_parametrized_tests(TestInt4WoqCpuTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 8e98e55178..d38e886265 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -93,6 +93,7 @@ Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, + Int4WoqCpuTensor, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -162,6 +163,7 @@ "Int4PreshuffledTensor", "Int4MarlinSparseTensor", "Float8Tensor", + "Int4WoqCpuTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ed5abb7333..ad5c5db6c8 100644 --- a/torchao/quantization/quant_api.py +++ 
b/torchao/quantization/quant_api.py @@ -75,6 +75,7 @@ Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, + Int4WoqCpuTensor, QuantizeTensorToFloat8Kwargs, ) from torchao.quantization.transform_module import ( @@ -1075,6 +1076,12 @@ def _int4_weight_only_quantize_tensor(weight, config): block_size, ) return new_weight + elif packing_format == PackingFormat.INT4_WOQ_CPU: + new_weight = Int4WoqCpuTensor.from_hp( + weight, + block_size, + ) + return new_weight else: raise ValueError(f"Unsupported packing format: {packing_format}") diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index 96a29d2990..0a1f93d30b 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -35,3 +35,8 @@ class PackingFormat(str, Enum): marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization """ MARLIN_SPARSE = "marlin_sparse" + + """ + int4_woq_cpu is referring to the format used by int4 weight-only quantization on CPU, which is a groupwise quantization format. + """ + INT4_WOQ_CPU = "int4_woq_cpu" diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 8441382243..e193e783ad 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -11,6 +11,9 @@ from .int4.int4_tensor import ( Int4Tensor, ) +from .int4.int4_woq_cpu_tensor import ( + Int4WoqCpuTensor, +) __all__ = [ "Int4Tensor", @@ -18,4 +21,5 @@ "Int4MarlinSparseTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", + "Int4WoqCpuTensor", ] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py new file mode 100644 index 0000000000..353ffecd57 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + _choose_qparams_affine_tinygemm, + _quantize_affine_tinygemm, +) +from torchao.utils import ( + TorchAOBaseTensor, +) + +__all__ = [ + "Int4WoqCpuTensor", +] + +aten = torch.ops.aten + + +class Int4WoqCpuTensor(TorchAOBaseTensor): + """ + int4 weight-only quantization on CPU (groupwise quantization only) + + Tensor Attributes: + qdata: preshuffled and packed int4 weight, always viewed as a 2D (N, K/2) tensor, last dimension is packed + preshuffling is specific to CPU kernels, see Note below. + qscale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size). + we only support group_size = 32/64/128. + shape: shape of the original Tensor + + Note on Details for data layout for CPU kernel: + + We use AVX512, AVX512_VNNI and AMX instructions (torch.compile and max-autotune needed for the latter two) to compute GEMM on CPU. + For data locality, we preshuffle the data in plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64. 
And when packing + the last dimension, data are shuffled by lanes before packing two int4 to one int8: + block_n = 64 = 16 * 4, so we have 4 lanes, each lane has 16 int4s = [lane0, lane1, lane2, lane3]. We pack them as [lane0|lane2, lane1|lane3]. + See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details. + """ + + tensor_data_names = ["qdata", "qscale_and_zero"] + optional_tensor_data_names = [] + tensor_attribute_names = ["block_size", "shape"] + + def __new__( + cls, + qdata, + qscale_and_zero, + block_size, + shape, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = qscale_and_zero.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + qscale_and_zero: torch.Tensor, + block_size: List[int], + shape: List[int], + ): + self.qdata = qdata + self.qscale_and_zero = qscale_and_zero + self.block_size = block_size + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + assert len(block_size) == w.ndim + assert block_size[0] == 1 and block_size[1] in (32, 64, 128), ( + f"Expecting groupwise quantization with group size = 32/64/128, but got block_size: {block_size}" + ) + original_shape = w.shape + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + scale_dtype = None + zero_point_dtype = w.dtype + scale, zero_point = _choose_qparams_affine_tinygemm( + w, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + ) + int_data = _quantize_affine_tinygemm( + w, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + ) + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) + return Int4WoqCpuTensor( + qdata=packed_weight, + qscale_and_zero=scale_and_zero, + block_size=block_size, + shape=original_shape, + ) + + +implements = Int4WoqCpuTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert input_tensor.device.type == "cpu", ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert isinstance(weight_tensor, Int4WoqCpuTensor), ( + f"Expected weight_tensor to be Int4WoqCpuTensor, got: {type(weight_tensor)}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + packed_weight = weight_tensor.qdata.contiguous() + scale_and_zero = 
weight_tensor.qscale_and_zero.contiguous() + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias + return y.to(orig_dtype) + + +Int4WoqCpuTensor.__module__ = "torchao.quantization" + +# Allow a model with Int4WoqCpuTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4WoqCpuTensor]) From 6c16b26d90f2af25ba672c7127f54682701c5f9e Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 18 Aug 2025 20:25:51 -0700 Subject: [PATCH 2/7] refine code --- .../quantize_/workflows/int4/int4_woq_cpu_tensor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py index 353ffecd57..29fa27a37f 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py @@ -122,7 +122,7 @@ def from_hp( ) packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( int_data, - 1, # TODO:remove + 1, # innerKTiles is not needed for CPU ) scale = scale.reshape(int_data.shape[0], -1) @@ -158,8 +158,7 @@ def _(func, types, args, kwargs): f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" ) assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" ) act_mat = input_tensor From 9012a61d44066eff61a1990f540429002773be8d Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 18 Aug 2025 20:30:13 -0700 Subject: [PATCH 3/7] refine code --- .../quantize_/workflows/int4/int4_woq_cpu_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py index 29fa27a37f..5a2bf4fba5 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py @@ -85,6 +85,9 @@ def from_hp( w: torch.Tensor, block_size: List[int], ): + assert w.ndim == 2 and w.device.type == "cpu", ( + f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}" + ) assert len(block_size) == w.ndim assert block_size[0] == 1 and block_size[1] in (32, 64, 128), ( f"Expecting groupwise quantization with group size = 32/64/128, but got block_size: {block_size}" @@ -178,6 +181,7 @@ def _(func, types, args, kwargs): ) # remove out_feature padding + assert weight_tensor.ndim == 2 orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] y = y.reshape(*orig_act_size[:-1], orig_out_features) From d41e0a836e18edcf626ecdd8ad3dc652e62364f3 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 19 Aug 2025 19:39:59 -0700 Subject: [PATCH 4/7] Refine code --- ...or.py => test_int4_tinygemm_cpu_tensor.py} | 39 
+++++++++----- torchao/quantization/__init__.py | 4 +- torchao/quantization/quant_api.py | 6 +-- .../quantize_/common/packing_format.py | 4 +- .../quantize_/workflows/__init__.py | 6 +-- ..._tensor.py => int4_tinygemm_cpu_tensor.py} | 51 +++++++++---------- 6 files changed, 60 insertions(+), 50 deletions(-) rename test/quantization/quantize_/workflows/int4/{test_int4_woq_cpu_tensor.py => test_int4_tinygemm_cpu_tensor.py} (55%) rename torchao/quantization/quantize_/workflows/int4/{int4_woq_cpu_tensor.py => int4_tinygemm_cpu_tensor.py} (74%) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py similarity index 55% rename from test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py rename to test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py index d013a473b7..ce2a862e32 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_woq_cpu_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py @@ -28,31 +28,44 @@ def get_config(group_size): return Int4WeightOnlyConfig( group_size=group_size, - packing_format="int4_woq_cpu", + packing_format="int4_tinygemm_cpu", version=2, ) @unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+") -class TestInt4WoqCpuTensor(TestCase): +class TestInt4TinyGemmCpuTensor(TestCase): + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 12), + ], + ) + @parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @parametrize("group_size", [32, 64, 128]) - def test_linear(self, group_size): - dtype = torch.bfloat16 + def test_linear(self, sizes, dtype, group_size): device = "cpu" - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) original = linear(input) quantize_(linear, get_config(group_size)) quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) - @parametrize("group_size", [32, 64, 128]) - def test_module_path(self, group_size): - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, get_config(group_size)) + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + + @parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + def test_module_path(self, dtype): + linear = torch.nn.Linear(128, 256, dtype=dtype) + quantize_(linear, get_config(group_size=128)) self.assertEqual( str(type(linear.weight)), - "", + "", ) with tempfile.NamedTemporaryFile() as f: @@ -61,11 +74,11 @@ def test_module_path(self, group_size): state_dict = torch.load(f) self.assertEqual( str(type(state_dict["weight"])), - "", + "", ) -instantiate_parametrized_tests(TestInt4WoqCpuTensor) +instantiate_parametrized_tests(TestInt4TinyGemmCpuTensor) if __name__ == "__main__": diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 791d05bee0..ea3bed2fa9 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -93,7 +93,7 @@ Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, - Int4WoqCpuTensor, + Int4TinyGemmCpuTensor, IntxUnpackedTensor, ) from .smoothquant import ( @@ -165,7 +165,7 @@ "Int4MarlinSparseTensor", 
"IntxUnpackedTensor", "Float8Tensor", - "Int4WoqCpuTensor", + "Int4TinyGemmCpuTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 97ba588751..dd3da17527 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -75,7 +75,7 @@ Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, - Int4WoqCpuTensor, + Int4TinyGemmCpuTensor, IntxUnpackedTensor, QuantizeTensorToFloat8Kwargs, ) @@ -1081,8 +1081,8 @@ def _int4_weight_only_quantize_tensor(weight, config): block_size, ) return new_weight - elif packing_format == PackingFormat.INT4_WOQ_CPU: - new_weight = Int4WoqCpuTensor.from_hp( + elif packing_format == PackingFormat.INT4_TINYGEMM_CPU: + new_weight = Int4TinyGemmCpuTensor.from_hp( weight, block_size, ) diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index a47906294e..4aad8b0f0c 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -42,6 +42,6 @@ class PackingFormat(str, Enum): UNPACKED_TO_INT8 = "unpacked_to_int8" """ - int4_woq_cpu is referring to the format used by int4 weight-only quantization on CPU, which is a groupwise quantization format. + int4_tinygemm_cpu is referring to the format used by int4 weight-only quantization on CPU, which is a groupwise quantization format. """ - INT4_WOQ_CPU = "int4_woq_cpu" + INT4_TINYGEMM_CPU = "int4_tinygemm_cpu" diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index fa8ce21afd..57197f35fa 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -11,8 +11,8 @@ from .int4.int4_tensor import ( Int4Tensor, ) -from .int4.int4_woq_cpu_tensor import ( - Int4WoqCpuTensor, +from .int4.int4_tinygemm_cpu_tensor import ( + Int4TinyGemmCpuTensor, ) from .intx.intx_unpacked_tensor import ( IntxUnpackedTensor, @@ -24,6 +24,6 @@ "Int4MarlinSparseTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", - "Int4WoqCpuTensor", + "Int4TinyGemmCpuTensor", "IntxUnpackedTensor", ] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py similarity index 74% rename from torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py rename to torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py index 5a2bf4fba5..c6069bf8ad 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_woq_cpu_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py @@ -19,61 +19,58 @@ ) __all__ = [ - "Int4WoqCpuTensor", + "Int4TinyGemmCpuTensor", ] aten = torch.ops.aten -class Int4WoqCpuTensor(TorchAOBaseTensor): +class Int4TinyGemmCpuTensor(TorchAOBaseTensor): """ - int4 weight-only quantization on CPU (groupwise quantization only) + int4 weight-only quantization on CPU with tinygemm (groupwise quantization only) Tensor Attributes: - qdata: preshuffled and packed int4 weight, always viewed as a 2D (N, K/2) tensor, last dimension is packed + qdata: preshuffled and packed int4 weight for tinygemm, always viewed as a 2D (N, K/2) tensor, last dimension is packed preshuffling is specific to CPU kernels, see Note below. 
- qscale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype + scale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype Non-Tensor Attributes: block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size). we only support group_size = 32/64/128. shape: shape of the original Tensor - Note on Details for data layout for CPU kernel: + Note on Details for data layout for CPU tinygemm kernel: - We use AVX512, AVX512_VNNI and AMX instructions (torch.compile and max-autotune needed for the latter two) to compute GEMM on CPU. - For data locality, we preshuffle the data in plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64. And when packing - the last dimension, data are shuffled by lanes before packing two int4 to one int8: - block_n = 64 = 16 * 4, so we have 4 lanes, each lane has 16 int4s = [lane0, lane1, lane2, lane3]. We pack them as [lane0|lane2, lane1|lane3]. + We use AVX512 to compute TINYGEMM on CPU. We can also leverage AVX512_VNNI and AMX instructions with torch.compile and max-autotune. + For data locality, we preshuffle the data in plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64/32/16. See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details. """ - tensor_data_names = ["qdata", "qscale_and_zero"] - optional_tensor_data_names = [] + tensor_data_names = ["qdata", "scale_and_zero"] tensor_attribute_names = ["block_size", "shape"] def __new__( cls, qdata, - qscale_and_zero, + scale_and_zero, block_size, shape, ): kwargs = {} kwargs["device"] = qdata.device - kwargs["dtype"] = qscale_and_zero.dtype + kwargs["dtype"] = scale_and_zero.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, qdata: torch.Tensor, - qscale_and_zero: torch.Tensor, + scale_and_zero: torch.Tensor, block_size: List[int], - shape: List[int], + shape: torch.Size, ): self.qdata = qdata - self.qscale_and_zero = qscale_and_zero + self.scale_and_zero = scale_and_zero self.block_size = block_size def _quantization_type(self): @@ -133,15 +130,15 @@ def from_hp( from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) - return Int4WoqCpuTensor( + return Int4TinyGemmCpuTensor( qdata=packed_weight, - qscale_and_zero=scale_and_zero, + scale_and_zero=scale_and_zero, block_size=block_size, shape=original_shape, ) -implements = Int4WoqCpuTensor.implements +implements = Int4TinyGemmCpuTensor.implements @implements([torch.nn.functional.linear, aten.linear.default]) @@ -154,8 +151,8 @@ def _(func, types, args, kwargs): assert input_tensor.device.type == "cpu", ( f"For CPU device only but got: {input_tensor.device}" ) - assert isinstance(weight_tensor, Int4WoqCpuTensor), ( - f"Expected weight_tensor to be Int4WoqCpuTensor, got: {type(weight_tensor)}" + assert isinstance(weight_tensor, Int4TinyGemmCpuTensor), ( + f"Expected weight_tensor to be Int4TinyGemmCpuTensor, got: {type(weight_tensor)}" ) assert weight_tensor.block_size[0] == 1, ( f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" @@ -165,8 +162,8 @@ def _(func, types, args, kwargs): ) act_mat = input_tensor - packed_weight = weight_tensor.qdata.contiguous() - scale_and_zero = 
weight_tensor.qscale_and_zero.contiguous() + packed_weight = weight_tensor.qdata + scale_and_zero = weight_tensor.scale_and_zero orig_act_size = act_mat.size() orig_dtype = act_mat.dtype @@ -191,7 +188,7 @@ def _(func, types, args, kwargs): return y.to(orig_dtype) -Int4WoqCpuTensor.__module__ = "torchao.quantization" +Int4TinyGemmCpuTensor.__module__ = "torchao.quantization" -# Allow a model with Int4WoqCpuTensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([Int4WoqCpuTensor]) +# Allow a model with Int4TinyGemmCpuTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4TinyGemmCpuTensor]) From 969c46aa439a54bebbc08a6e468cae70a11b50e6 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sun, 24 Aug 2025 18:54:43 -0700 Subject: [PATCH 5/7] Update UT --- .../quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py index ce2a862e32..7e663bb90f 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py @@ -43,7 +43,7 @@ class TestInt4TinyGemmCpuTensor(TestCase): ((2, 32, 128), 256, 12), ], ) - @parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) @parametrize("group_size", [32, 64, 128]) def test_linear(self, sizes, dtype, group_size): device = "cpu" @@ -59,7 +59,7 @@ def test_linear(self, sizes, dtype, group_size): quantized_and_compiled = compiled_linear(input) self.assertTrue(compute_error(original, quantized_and_compiled) > 20) - @parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) def test_module_path(self, dtype): linear = torch.nn.Linear(128, 256, dtype=dtype) quantize_(linear, get_config(group_size=128)) From ade2c32dc9127d32981406f0c54173470e687897 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 26 Aug 2025 10:29:07 +0000 Subject: [PATCH 6/7] Rename tensor & format to opaque --- ...mm_cpu_tensor.py => test_opaque_tensor.py} | 10 ++++----- torchao/quantization/__init__.py | 4 ++-- torchao/quantization/quant_api.py | 6 +++--- .../quantize_/workflows/__init__.py | 6 +++--- ...inygemm_cpu_tensor.py => opaque_tensor.py} | 21 ++++++++++--------- 5 files changed, 24 insertions(+), 23 deletions(-) rename test/quantization/quantize_/workflows/int4/{test_int4_tinygemm_cpu_tensor.py => test_opaque_tensor.py} (88%) rename torchao/quantization/quantize_/workflows/int4/{int4_tinygemm_cpu_tensor.py => opaque_tensor.py} (90%) diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py b/test/quantization/quantize_/workflows/int4/test_opaque_tensor.py similarity index 88% rename from test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py rename to test/quantization/quantize_/workflows/int4/test_opaque_tensor.py index 7e663bb90f..5cb9e440f0 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_tinygemm_cpu_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_opaque_tensor.py @@ -28,13 +28,13 @@ def get_config(group_size): return Int4WeightOnlyConfig( group_size=group_size, - packing_format="int4_tinygemm_cpu", + packing_format="opaque", version=2, ) 
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+") -class TestInt4TinyGemmCpuTensor(TestCase): +class TestOpaqueTensor(TestCase): @parametrize( "sizes", [ @@ -65,7 +65,7 @@ def test_module_path(self, dtype): quantize_(linear, get_config(group_size=128)) self.assertEqual( str(type(linear.weight)), - "", + "", ) with tempfile.NamedTemporaryFile() as f: @@ -74,11 +74,11 @@ def test_module_path(self, dtype): state_dict = torch.load(f) self.assertEqual( str(type(state_dict["weight"])), - "", + "", ) -instantiate_parametrized_tests(TestInt4TinyGemmCpuTensor) +instantiate_parametrized_tests(TestOpaqueTensor) if __name__ == "__main__": diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ea3bed2fa9..b9baec8281 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -93,8 +93,8 @@ Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, - Int4TinyGemmCpuTensor, IntxUnpackedTensor, + OpaqueTensor, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -165,7 +165,7 @@ "Int4MarlinSparseTensor", "IntxUnpackedTensor", "Float8Tensor", - "Int4TinyGemmCpuTensor", + "OpaqueTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8ca0cf4a59..260ad2f7f8 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -75,8 +75,8 @@ Int4MarlinSparseTensor, Int4PreshuffledTensor, Int4Tensor, - Int4TinyGemmCpuTensor, IntxUnpackedTensor, + OpaqueTensor, QuantizeTensorToFloat8Kwargs, ) from torchao.quantization.transform_module import ( @@ -1081,8 +1081,8 @@ def _int4_weight_only_quantize_tensor(weight, config): block_size, ) return new_weight - elif packing_format == PackingFormat.INT4_TINYGEMM_CPU: - new_weight = Int4TinyGemmCpuTensor.from_hp( + elif packing_format == PackingFormat.OPAQUE: + new_weight = OpaqueTensor.from_hp( weight, block_size, ) diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 57197f35fa..9f705f17bf 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -11,8 +11,8 @@ from .int4.int4_tensor import ( Int4Tensor, ) -from .int4.int4_tinygemm_cpu_tensor import ( - Int4TinyGemmCpuTensor, +from .int4.opaque_tensor import ( + OpaqueTensor, ) from .intx.intx_unpacked_tensor import ( IntxUnpackedTensor, @@ -24,6 +24,6 @@ "Int4MarlinSparseTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", - "Int4TinyGemmCpuTensor", + "OpaqueTensor", "IntxUnpackedTensor", ] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py b/torchao/quantization/quantize_/workflows/int4/opaque_tensor.py similarity index 90% rename from torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py rename to torchao/quantization/quantize_/workflows/int4/opaque_tensor.py index c6069bf8ad..a5186f1e5b 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/opaque_tensor.py @@ -19,15 +19,16 @@ ) __all__ = [ - "Int4TinyGemmCpuTensor", + "OpaqueTensor", ] aten = torch.ops.aten -class Int4TinyGemmCpuTensor(TorchAOBaseTensor): +class OpaqueTensor(TorchAOBaseTensor): """ - int4 weight-only quantization on CPU with tinygemm (groupwise quantization only) + int4 weight-only quantization on CPU with tinygemm (groupwise 
quantization only). The packing format is determined on ISA and shape. + This is an opaque tensor subclass, the packing format is not exposed to the rest of the system. See the note below for more details. Tensor Attributes: qdata: preshuffled and packed int4 weight for tinygemm, always viewed as a 2D (N, K/2) tensor, last dimension is packed @@ -130,7 +131,7 @@ def from_hp( from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) - return Int4TinyGemmCpuTensor( + return OpaqueTensor( qdata=packed_weight, scale_and_zero=scale_and_zero, block_size=block_size, @@ -138,7 +139,7 @@ def from_hp( ) -implements = Int4TinyGemmCpuTensor.implements +implements = OpaqueTensor.implements @implements([torch.nn.functional.linear, aten.linear.default]) @@ -151,8 +152,8 @@ def _(func, types, args, kwargs): assert input_tensor.device.type == "cpu", ( f"For CPU device only but got: {input_tensor.device}" ) - assert isinstance(weight_tensor, Int4TinyGemmCpuTensor), ( - f"Expected weight_tensor to be Int4TinyGemmCpuTensor, got: {type(weight_tensor)}" + assert isinstance(weight_tensor, OpaqueTensor), ( + f"Expected weight_tensor to be OpaqueTensor, got: {type(weight_tensor)}" ) assert weight_tensor.block_size[0] == 1, ( f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" @@ -188,7 +189,7 @@ def _(func, types, args, kwargs): return y.to(orig_dtype) -Int4TinyGemmCpuTensor.__module__ = "torchao.quantization" +OpaqueTensor.__module__ = "torchao.quantization" -# Allow a model with Int4TinyGemmCpuTensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([Int4TinyGemmCpuTensor]) +# Allow a model with OpaqueTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([OpaqueTensor]) From c81880e8c6204ef2aed677551315dd4dee9576e9 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 26 Aug 2025 11:11:27 +0000 Subject: [PATCH 7/7] Rename OpaqueTensor -> Int4OpaqueTensor --- ...e_tensor.py => test_int4_opaque_tensor.py} | 8 +++---- torchao/quantization/__init__.py | 4 ++-- torchao/quantization/quant_api.py | 4 ++-- .../quantize_/workflows/__init__.py | 8 +++---- ...opaque_tensor.py => int4_opaque_tensor.py} | 22 +++++++++---------- 5 files changed, 23 insertions(+), 23 deletions(-) rename test/quantization/quantize_/workflows/int4/{test_opaque_tensor.py => test_int4_opaque_tensor.py} (91%) rename torchao/quantization/quantize_/workflows/int4/{opaque_tensor.py => int4_opaque_tensor.py} (89%) diff --git a/test/quantization/quantize_/workflows/int4/test_opaque_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py similarity index 91% rename from test/quantization/quantize_/workflows/int4/test_opaque_tensor.py rename to test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py index 5cb9e440f0..58ec391038 100644 --- a/test/quantization/quantize_/workflows/int4/test_opaque_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py @@ -34,7 +34,7 @@ def get_config(group_size): @unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+") -class TestOpaqueTensor(TestCase): +class TestInt4OpaqueTensor(TestCase): @parametrize( "sizes", [ @@ -65,7 +65,7 @@ def test_module_path(self, dtype): quantize_(linear, get_config(group_size=128)) self.assertEqual( str(type(linear.weight)), - "", + "", ) with tempfile.NamedTemporaryFile() as f: @@ -74,11 +74,11 @@ def 
test_module_path(self, dtype): state_dict = torch.load(f) self.assertEqual( str(type(state_dict["weight"])), - "", + "", ) -instantiate_parametrized_tests(TestOpaqueTensor) +instantiate_parametrized_tests(TestInt4OpaqueTensor) if __name__ == "__main__": diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b9baec8281..a692eaff26 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -91,10 +91,10 @@ from .quantize_.workflows import ( Float8Tensor, Int4MarlinSparseTensor, + Int4OpaqueTensor, Int4PreshuffledTensor, Int4Tensor, IntxUnpackedTensor, - OpaqueTensor, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -165,7 +165,7 @@ "Int4MarlinSparseTensor", "IntxUnpackedTensor", "Float8Tensor", - "OpaqueTensor", + "Int4OpaqueTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 260ad2f7f8..e524bdaca1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -73,10 +73,10 @@ from torchao.quantization.quantize_.workflows import ( Float8Tensor, Int4MarlinSparseTensor, + Int4OpaqueTensor, Int4PreshuffledTensor, Int4Tensor, IntxUnpackedTensor, - OpaqueTensor, QuantizeTensorToFloat8Kwargs, ) from torchao.quantization.transform_module import ( @@ -1082,7 +1082,7 @@ def _int4_weight_only_quantize_tensor(weight, config): ) return new_weight elif packing_format == PackingFormat.OPAQUE: - new_weight = OpaqueTensor.from_hp( + new_weight = Int4OpaqueTensor.from_hp( weight, block_size, ) diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 9f705f17bf..850c856298 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -5,15 +5,15 @@ from .int4.int4_marlin_sparse_tensor import ( Int4MarlinSparseTensor, ) +from .int4.int4_opaque_tensor import ( + Int4OpaqueTensor, +) from .int4.int4_preshuffled_tensor import ( Int4PreshuffledTensor, ) from .int4.int4_tensor import ( Int4Tensor, ) -from .int4.opaque_tensor import ( - OpaqueTensor, -) from .intx.intx_unpacked_tensor import ( IntxUnpackedTensor, ) @@ -24,6 +24,6 @@ "Int4MarlinSparseTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", - "OpaqueTensor", + "Int4OpaqueTensor", "IntxUnpackedTensor", ] diff --git a/torchao/quantization/quantize_/workflows/int4/opaque_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py similarity index 89% rename from torchao/quantization/quantize_/workflows/int4/opaque_tensor.py rename to torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py index a5186f1e5b..ace8745175 100644 --- a/torchao/quantization/quantize_/workflows/int4/opaque_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py @@ -19,20 +19,20 @@ ) __all__ = [ - "OpaqueTensor", + "Int4OpaqueTensor", ] aten = torch.ops.aten -class OpaqueTensor(TorchAOBaseTensor): +class Int4OpaqueTensor(TorchAOBaseTensor): """ int4 weight-only quantization on CPU with tinygemm (groupwise quantization only). The packing format is determined on ISA and shape. This is an opaque tensor subclass, the packing format is not exposed to the rest of the system. See the note below for more details. 
Tensor Attributes: - qdata: preshuffled and packed int4 weight for tinygemm, always viewed as a 2D (N, K/2) tensor, last dimension is packed - preshuffling is specific to CPU kernels, see Note below. + qdata: preshuffled and packed int4 weight for CPU tinygemm kernel, always viewed as a 2D (N, K/2) tensor, last dimension is packed + preshuffling is specific to CPU kernels based on ISA and shape, see Note below. scale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype Non-Tensor Attributes: @@ -131,7 +131,7 @@ def from_hp( from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) - return OpaqueTensor( + return Int4OpaqueTensor( qdata=packed_weight, scale_and_zero=scale_and_zero, block_size=block_size, @@ -139,7 +139,7 @@ def from_hp( ) -implements = OpaqueTensor.implements +implements = Int4OpaqueTensor.implements @implements([torch.nn.functional.linear, aten.linear.default]) @@ -152,8 +152,8 @@ def _(func, types, args, kwargs): assert input_tensor.device.type == "cpu", ( f"For CPU device only but got: {input_tensor.device}" ) - assert isinstance(weight_tensor, OpaqueTensor), ( - f"Expected weight_tensor to be OpaqueTensor, got: {type(weight_tensor)}" + assert isinstance(weight_tensor, Int4OpaqueTensor), ( + f"Expected weight_tensor to be Int4OpaqueTensor, got: {type(weight_tensor)}" ) assert weight_tensor.block_size[0] == 1, ( f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" @@ -189,7 +189,7 @@ def _(func, types, args, kwargs): return y.to(orig_dtype) -OpaqueTensor.__module__ = "torchao.quantization" +Int4OpaqueTensor.__module__ = "torchao.quantization" -# Allow a model with OpaqueTensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([OpaqueTensor]) +# Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4OpaqueTensor])
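
Usage example (for reference): a minimal sketch of the workflow after this series is applied (the patch 7/7 state), mirroring the test file added above. It assumes PyTorch 2.6+ on CPU with these patches in torchao; the file name passed to torch.save is illustrative only.

    # Minimal sketch based on test_int4_opaque_tensor.py from this series.
    # Assumes torchao with patches 1-7 applied and PyTorch 2.6+ on CPU.
    import torch

    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    config = Int4WeightOnlyConfig(
        group_size=128,           # group_size must be 32, 64 or 128
        packing_format="opaque",  # selects Int4OpaqueTensor (CPU tinygemm path)
        version=2,
    )

    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")
    quantize_(linear, config)  # linear.weight is now an Int4OpaqueTensor

    x = torch.randn(32, 128, dtype=torch.bfloat16)
    y = linear(x)  # dispatches to torch.ops.aten._weight_int4pack_mm_for_cpu

    # Int4OpaqueTensor is registered via torch.serialization.add_safe_globals,
    # so the state dict round-trips with weights_only=True.
    torch.save(linear.state_dict(), "linear_int4_cpu.pt")
    state_dict = torch.load("linear_int4_cpu.pt", weights_only=True)

The "opaque" packing format keeps the CPU-specific preshuffled layout internal to Int4OpaqueTensor, so callers interact only with the config and the quantized module rather than with the ISA- and shape-dependent weight layout.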