
Commit 552fa00

yanbing-j authored and RUFF-bot committed
Move Int4CPULayout to int4_cpu_layout.py (#1419)
* Move Int4CPULayout to int4_cpu_layout.py

* Apply automatic Ruff fixes

Co-authored-by: Ruff Auto-fixes <[email protected]>
1 parent 1a825e1 commit 552fa00

File tree

3 files changed: +266, -249 lines changed


torchao/dtypes/uintx/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,9 @@
 from .block_sparse_layout import (
     BlockSparseLayout,
 )
+from .int4_cpu_layout import (
+    Int4CPULayout,
+)
 from .marlin_qqq_tensor import (
     MarlinQQQLayout,
     MarlinQQQTensor,
@@ -13,7 +16,6 @@
     SemiSparseLayout,
 )
 from .tensor_core_tiled_layout import (
-    Int4CPULayout,
     TensorCoreTiledLayout,
 )
 from .uintx_layout import (
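
The re-export above keeps the public import path stable: downstream code that pulls Int4CPULayout from the package keeps working, only the defining module changes. A minimal check, assuming torchao is installed at this commit:

from torchao.dtypes.uintx import Int4CPULayout  # now defined in int4_cpu_layout.py

print(Int4CPULayout())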
torchao/dtypes/uintx/int4_cpu_layout.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+from torchao.dtypes.affine_quantized_tensor import register_layout
+from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    fill_defaults,
+)
+
+aten = torch.ops.aten
+
+
+@dataclass(frozen=True)
+class Int4CPULayout(Layout):
+    """Only for PyTorch version at least 2.6"""
+
+    pass
+
+
+@register_layout(Int4CPULayout)
+class Int4CPUAQTTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
+    used by tinygemm kernels `_weight_int4pack_mm_for_cpu`
+    It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
+    dimension: [n][k / 2] (uint8 dtype)
+    (unpacked Tensor shape is n * k)
+    Note: we also pack scale and zero point together here for tinygemm kernel
+    Note: technically Int4 CPU layout should be the layout for the underlying packed weight
+    (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used
+    in plain layout, we just created a layout for AQT right now, this could be improved if we split out
+    int4 aqt into a separate tensor subclass
+    fields:
+      packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout
+      scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor
+    """
+
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_weight.layout
+        )
+        kwargs["dtype"] = packed_weight.dtype
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        self.packed_weight = packed_weight
+        self.scale_and_zero = scale_and_zero
+        self.transposed = False
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict["scale_and_zero"],
+        )
+        (
+            transposed,
+            _layout,
+        ) = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, _layout)
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Int4CPULayout)
+
+        if TORCH_VERSION_AT_LEAST_2_6:
+            assert (
+                int_data.dtype == torch.int32
+            ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                int_data,
+                1,  # TODO:remove
+            )
+        elif TORCH_VERSION_AT_LEAST_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert (
+                int_data.dtype == torch.uint8
+            ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+                int_data, _layout.inner_k_tiles
+            )
+        else:
+            assert (
+                int_data.dtype == torch.int32
+            ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+                int_data, _layout.inner_k_tiles
+            )
+
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
+
+        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        return cls(packed_weight, scale_and_zero, False, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        if not is_device(torch.device(self.device).type, device):
+            raise ValueError(
+                f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
+            )
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+            self.transposed,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.packed_weight),
+            fn(self.scale_and_zero),
+            self.transposed,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            """we don't need to repack the weight and just rely on external
+            shape being changed and record the status of transpose/no-transpose
+            """
+            transposed = Int4CPUAQTTensorImpl(
+                args[0].packed_weight,
+                args[0].scale_and_zero,
+                not args[0].transposed,
+                args[0]._layout,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, transposed)
+
+        if func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                int_data, scale, zero_point = self.get_plain()
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            elif dim == 1:
+                int_data, scale, zero_point = self.get_plain()
+                assert step == 1, "Only step == 1 is supported in slicing right now"
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return sliced
+            else:
+                raise NotImplementedError(
+                    f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        from torchao.quantization.quant_primitives import (
+            ZeroPointDomain,
+            quantize_affine,
+        )
+        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+
+        cur_shape = self.shape
+        assert len(cur_shape) == 2
+        original_shape = (cur_shape[0], cur_shape[1] * 2)
+        eye_shape = original_shape[1]
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
+        device = self.device
+        original_dtype = torch.bfloat16
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        zero_point_domain = ZeroPointDomain.FLOAT
+        assert len(block_size) == 2 and block_size[0] == 1
+        dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            torch.eye(eye_shape, device=device, dtype=original_dtype),
+            self.packed_weight,
+            groupsize,
+            self.scale_and_zero,
+        )
+        dequantized = dequantized.t().contiguous()
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()
+        int_data = quantize_affine(
+            dequantized,
+            block_size,
+            scale,
+            zero,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        return int_data, scale, zero
+
+    def get_layout(self) -> Layout:
+        return self._layout
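
For reference, a minimal sketch (toy values, not part of this diff) of the nibble packing used by the torch-2.5 fallback in from_plain above: adjacent int4 values along dim 1 are fused into one uint8, with the even column landing in the high nibble.

import torch

int_data = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
packed = (int_data[:, ::2] << 4 | int_data[:, 1::2]).to(torch.uint8)
print(packed)  # tensor([[18, 52]], dtype=torch.uint8) == [[0x12, 0x34]]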

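A hypothetical end-to-end sketch (not part of this diff; assumes torchao's int4 weight-only API at this commit accepts a layout argument) of selecting the CPU layout when quantizing a model:

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes.uintx import Int4CPULayout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)
quantize_(model, int4_weight_only(group_size=128, layout=Int4CPULayout()))  # requires PyTorch >= 2.6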