|
| 1 | +from dataclasses import dataclass |
| 2 | +from typing import Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
| 6 | + |
| 7 | +from torchao.dtypes.affine_quantized_tensor import register_layout |
| 8 | +from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device |
| 9 | +from torchao.utils import ( |
| 10 | + TORCH_VERSION_AT_LEAST_2_5, |
| 11 | + TORCH_VERSION_AT_LEAST_2_6, |
| 12 | + fill_defaults, |
| 13 | +) |
| 14 | + |
| 15 | +aten = torch.ops.aten |
| 16 | + |
| 17 | + |
| 18 | +@dataclass(frozen=True) |
| 19 | +class Int4CPULayout(Layout): |
| 20 | + """Only for PyTorch version at least 2.6""" |
| 21 | + |
| 22 | + pass |
| 23 | + |
| 24 | + |
| 25 | +@register_layout(Int4CPULayout) |
| 26 | +class Int4CPUAQTTensorImpl(AQTTensorImpl): |
| 27 | + """ |
| 28 | + TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, |
| 29 | + used by tinygemm kernels `_weight_int4pack_mm_for_cpu` |
| 30 | + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of |
| 31 | + dimension: [n][k / 2] (uint8 dtype) |
| 32 | + (unpacked Tensor shape is n * k) |
| 33 | + Note: we also pack scale and zero point together here for tinygemm kernel |
| 34 | + Note: technically Int4 CPU layout should be the layout for the underlying packed weight |
| 35 | + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used |
| 36 | + in plain layout, we just created a layout for AQT right now, this could be improved if we split out |
| 37 | + int4 aqt into a separate tensor subclass |
| 38 | + fields: |
| 39 | + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout |
| 40 | + scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor |
| 41 | + """ |
| 42 | + |
| 43 | + def __new__( |
| 44 | + cls, |
| 45 | + packed_weight: torch.Tensor, |
| 46 | + scale_and_zero: torch.Tensor, |
| 47 | + transposed: bool, |
| 48 | + _layout: Layout, |
| 49 | + ): |
| 50 | + kwargs = {} |
| 51 | + kwargs["device"] = packed_weight.device |
| 52 | + kwargs["layout"] = ( |
| 53 | + kwargs.get("layout") |
| 54 | + if kwargs.get("layout", False) |
| 55 | + else packed_weight.layout |
| 56 | + ) |
| 57 | + kwargs["dtype"] = packed_weight.dtype |
| 58 | + kwargs["requires_grad"] = False |
| 59 | + shape = packed_weight.shape |
| 60 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 61 | + |
| 62 | + def __init__( |
| 63 | + self, |
| 64 | + packed_weight: torch.Tensor, |
| 65 | + scale_and_zero: torch.Tensor, |
| 66 | + transposed: bool, |
| 67 | + _layout: Layout, |
| 68 | + ): |
| 69 | + self.packed_weight = packed_weight |
| 70 | + self.scale_and_zero = scale_and_zero |
| 71 | + self.transposed = False |
| 72 | + self._layout = _layout |
| 73 | + |
| 74 | + def __tensor_flatten__(self): |
| 75 | + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] |
| 76 | + |
| 77 | + @classmethod |
| 78 | + def __tensor_unflatten__( |
| 79 | + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride |
| 80 | + ): |
| 81 | + packed_weight, scale_and_zero = ( |
| 82 | + tensor_data_dict["packed_weight"], |
| 83 | + tensor_data_dict["scale_and_zero"], |
| 84 | + ) |
| 85 | + ( |
| 86 | + transposed, |
| 87 | + _layout, |
| 88 | + ) = tensor_attributes |
| 89 | + return cls(packed_weight, scale_and_zero, transposed, _layout) |
| 90 | + |
| 91 | + @classmethod |
| 92 | + def from_plain( |
| 93 | + cls, |
| 94 | + int_data: torch.Tensor, |
| 95 | + scale: torch.Tensor, |
| 96 | + zero_point: Optional[torch.Tensor], |
| 97 | + _layout: Layout, |
| 98 | + ): |
| 99 | + assert isinstance(_layout, Int4CPULayout) |
| 100 | + |
| 101 | + if TORCH_VERSION_AT_LEAST_2_6: |
| 102 | + assert ( |
| 103 | + int_data.dtype == torch.int32 |
| 104 | + ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" |
| 105 | + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( |
| 106 | + int_data, |
| 107 | + 1, # TODO:remove |
| 108 | + ) |
| 109 | + elif TORCH_VERSION_AT_LEAST_2_5: |
| 110 | + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) |
| 111 | + assert ( |
| 112 | + int_data.dtype == torch.uint8 |
| 113 | + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" |
| 114 | + packed_weight = torch.ops.aten._convert_weight_to_int4pack( |
| 115 | + int_data, _layout.inner_k_tiles |
| 116 | + ) |
| 117 | + else: |
| 118 | + assert ( |
| 119 | + int_data.dtype == torch.int32 |
| 120 | + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" |
| 121 | + packed_weight = torch.ops.aten._convert_weight_to_int4pack( |
| 122 | + int_data, _layout.inner_k_tiles |
| 123 | + ) |
| 124 | + |
| 125 | + scale = scale.reshape(int_data.shape[0], -1) |
| 126 | + zero_point = zero_point.reshape(int_data.shape[0], -1) |
| 127 | + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros |
| 128 | + |
| 129 | + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) |
| 130 | + return cls(packed_weight, scale_and_zero, False, _layout) |
| 131 | + |
| 132 | + def to(self, *args, **kwargs): |
| 133 | + kwargs = self._get_to_kwargs(*args, **kwargs) |
| 134 | + device = kwargs["device"] |
| 135 | + if not is_device(torch.device(self.device).type, device): |
| 136 | + raise ValueError( |
| 137 | + f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" |
| 138 | + ) |
| 139 | + return self.__class__( |
| 140 | + self.packed_weight.to(device), |
| 141 | + self.scale_and_zero.to(device), |
| 142 | + self.transposed, |
| 143 | + self._layout, |
| 144 | + ) |
| 145 | + |
| 146 | + def _apply_fn_to_data(self, fn): |
| 147 | + return self.__class__( |
| 148 | + fn(self.packed_weight), |
| 149 | + fn(self.scale_and_zero), |
| 150 | + self.transposed, |
| 151 | + self._layout, |
| 152 | + ) |
| 153 | + |
| 154 | + @classmethod |
| 155 | + def __torch_dispatch__(cls, func, types, args, kwargs): |
| 156 | + kwargs = {} if kwargs is None else kwargs |
| 157 | + |
| 158 | + if func is aten.detach.default: |
| 159 | + return return_and_correct_aliasing( |
| 160 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
| 161 | + ) |
| 162 | + |
| 163 | + if func is aten.clone.default: |
| 164 | + return return_and_correct_aliasing( |
| 165 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
| 166 | + ) |
| 167 | + |
| 168 | + if func is aten.t.default: |
| 169 | + """we don't need to repack the weight and just rely on external |
| 170 | + shape being changed and record the status of transpose/no-transpose |
| 171 | + """ |
| 172 | + transposed = Int4CPUAQTTensorImpl( |
| 173 | + args[0].packed_weight, |
| 174 | + args[0].scale_and_zero, |
| 175 | + not args[0].transposed, |
| 176 | + args[0]._layout, |
| 177 | + ) |
| 178 | + return return_and_correct_aliasing(func, args, kwargs, transposed) |
| 179 | + |
| 180 | + if func is aten.slice.Tensor: |
| 181 | + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) |
| 182 | + if dim == 0: |
| 183 | + int_data, scale, zero_point = self.get_plain() |
| 184 | + int_data = aten.slice.Tensor(int_data, dim, start, end, step) |
| 185 | + # this is to handle padding |
| 186 | + int_data = self._layout.post_process(int_data) |
| 187 | + sliced = self.from_plain(int_data, scale, zero_point, self._layout) |
| 188 | + return return_and_correct_aliasing(func, args, kwargs, sliced) |
| 189 | + elif dim == 1: |
| 190 | + int_data, scale, zero_point = self.get_plain() |
| 191 | + assert step == 1, "Only step == 1 is supported in slicing right now" |
| 192 | + data_len = int_data.shape[dim] |
| 193 | + scale_len = scale.shape[dim] |
| 194 | + ratio = data_len / scale_len |
| 195 | + start_scale = int(start / ratio) |
| 196 | + end_scale = int(end / ratio) |
| 197 | + |
| 198 | + int_data = aten.slice.Tensor(int_data, dim, start, end, step) |
| 199 | + # this is to handle padding |
| 200 | + int_data = self._layout.post_process(int_data) |
| 201 | + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) |
| 202 | + zero_point = aten.slice.Tensor( |
| 203 | + zero_point, dim, start_scale, end_scale, step |
| 204 | + ) |
| 205 | + sliced = self.from_plain(int_data, scale, zero_point, self._layout) |
| 206 | + return sliced |
| 207 | + else: |
| 208 | + raise NotImplementedError( |
| 209 | + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" |
| 210 | + ) |
| 211 | + |
| 212 | + raise NotImplementedError( |
| 213 | + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" |
| 214 | + ) |
| 215 | + |
| 216 | + __torch_function__ = torch._C._disabled_torch_function_impl |
| 217 | + |
| 218 | + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 219 | + from torchao.quantization.quant_primitives import ( |
| 220 | + ZeroPointDomain, |
| 221 | + quantize_affine, |
| 222 | + ) |
| 223 | + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros |
| 224 | + |
| 225 | + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) |
| 226 | + |
| 227 | + cur_shape = self.shape |
| 228 | + assert len(cur_shape) == 2 |
| 229 | + original_shape = (cur_shape[0], cur_shape[1] * 2) |
| 230 | + eye_shape = original_shape[1] |
| 231 | + groupsize = int(original_shape[1] / scale.shape[-2]) |
| 232 | + block_size = (1, groupsize) |
| 233 | + device = self.device |
| 234 | + original_dtype = torch.bfloat16 |
| 235 | + target_dtype = torch.int32 |
| 236 | + quant_min = 0 |
| 237 | + quant_max = 15 |
| 238 | + zero_point_domain = ZeroPointDomain.FLOAT |
| 239 | + assert len(block_size) == 2 and block_size[0] == 1 |
| 240 | + dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu( |
| 241 | + torch.eye(eye_shape, device=device, dtype=original_dtype), |
| 242 | + self.packed_weight, |
| 243 | + groupsize, |
| 244 | + self.scale_and_zero, |
| 245 | + ) |
| 246 | + dequantized = dequantized.t().contiguous() |
| 247 | + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? |
| 248 | + scale = scale.reshape(scale.shape[:-1]).contiguous() |
| 249 | + zero = zero.reshape(zero.shape[:-1]).contiguous() |
| 250 | + int_data = quantize_affine( |
| 251 | + dequantized, |
| 252 | + block_size, |
| 253 | + scale, |
| 254 | + zero, |
| 255 | + target_dtype, |
| 256 | + quant_min, |
| 257 | + quant_max, |
| 258 | + zero_point_domain, |
| 259 | + ) |
| 260 | + return int_data, scale, zero |
| 261 | + |
| 262 | + def get_layout(self) -> Layout: |
| 263 | + return self._layout |
0 commit comments