Commit 98b8f8c

Update
1 parent fbb2cae commit 98b8f8c

8 files changed, +101 -43 lines changed

test/quantization/test_quant_primitives.py

Lines changed: 3 additions & 2 deletions
@@ -526,8 +526,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            if not is_device(input.device.type, "cpu"):
+                input = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 8 additions & 3 deletions
@@ -2056,9 +2056,14 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
 
     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
-    y = torch.ops.aten._weight_int4pack_mm(
-        act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
-    )
+    if is_device(input_tensor.device.type, "cpu"):
+        y = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
+        )
+    else:
+        y = torch.ops.aten._weight_int4pack_mm(
+            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
+        )
 
     # remove out_feature padding
     orig_out_features = weight_tensor.shape[-2]

torchao/prototype/hqq/README.md

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f
 
 - Times are in `ms`, see `benchmarks/benchmark_hqq.py`.
 - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul).
-- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
+- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
 
 GPU details:
 
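For reference, a minimal sketch of how a caller might dispatch between the two tinygemm entry points mentioned above by device type. The helper name `int4_mm` and the pre-packed inputs are assumptions for illustration, not part of the library:

```python
import torch

def int4_mm(x, packed_weight, groupsize, scales_and_zeros):
    # Hypothetical helper: route to the CPU tinygemm kernel when the
    # activation lives on CPU, otherwise use the tensor-core-tiled kernel.
    if x.device.type == "cpu":
        return torch.ops.aten._weight_int4pack_mm_for_cpu(
            x, packed_weight, groupsize, scales_and_zeros
        )
    return torch.ops.aten._weight_int4pack_mm(
        x, packed_weight, groupsize, scales_and_zeros
    )
```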

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 19 additions & 7 deletions
@@ -13,6 +13,7 @@
 
 import torch.nn.functional as F
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.dtypes.utils import is_device
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            W_q_torch, self.inner_k_tiles
-        )
+        if is_device(W_q.device.type, "cpu"):
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                W_q_torch, self.inner_k_tiles
+            )
+        else:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                W_q_torch, self.inner_k_tiles
+            )
         self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)
 
         del W_q_torch, scales_torch, zeros_torch
@@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+            if not is_device(W_q.device.type, "cpu"):
+                W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
@@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros):
     def matmul(self, x):
         origin_x_size = x.size()
         x = x.reshape(-1, origin_x_size[-1])
-        c = torch.ops.aten._weight_int4pack_mm(
-            x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
-        )
+        if is_device(x.device.type, "cpu"):
+            c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
+        else:
+            c = torch.ops.aten._weight_int4pack_mm(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
         new_shape = origin_x_size[:-1] + (self.out_features,)
         c = c.reshape(new_shape)
         return c
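A short sketch of the weight-packing path these changes imply. It assumes, consistent with the guarded `<< 4 | ...` step above, that on PyTorch >= 2.5 the CUDA packer consumes the nibble-packed uint8 weight while the CPU variant takes the unpacked int4 values directly; the `pack_int4_weight` helper is hypothetical:

```python
import torch

def pack_int4_weight(w_q, inner_k_tiles):
    # Sketch only: on CPU, hand the unpacked int4 weight to the CPU packer;
    # on CUDA (PyTorch >= 2.5), first pack two int4 values per uint8 byte.
    if w_q.device.type == "cpu":
        return torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_q, inner_k_tiles)
    w_q = (w_q[::, ::2] << 4 | w_q[::, 1::2]).to(torch.uint8)
    return torch.ops.aten._convert_weight_to_int4pack(w_q, inner_k_tiles)
```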

torchao/quantization/GPTQ.py

Lines changed: 56 additions & 24 deletions
@@ -36,6 +36,7 @@
     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
+from torchao.dtypes.utils import is_device
 
 aten = torch.ops.aten
 
@@ -542,12 +543,20 @@ def linear_forward_int4(
 ):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
-    c = torch.ops.aten._weight_int4pack_mm(
-        x.to(precision),
-        weight_int4pack,
-        groupsize,
-        scales_and_zeros.to(scales_precision),
-    ).to(dtype=x.dtype)
+    if is_device(x.device.type, "cpu"):
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
@@ -596,19 +605,32 @@ def __init__(
         assert (
             in_features % (inner_k_tiles * 16) == 0
         ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.zeros(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
+        if is_device(device.type, "cpu"):
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features,
+                        in_features // 2,
+                    ),
+                    dtype=torch.uint8,
+                    device=device,
                 ),
-                dtype=torch.int32,
-                device=device,
-            ),
-        )
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features // 8,
+                        in_features // (inner_k_tiles * 16),
+                        32,
+                        inner_k_tiles // 2,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+            )
         self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
@@ -765,9 +787,14 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    w_int4x8.to(self.device), self.inner_k_tiles
-                )
+                if is_device(w_int4x8.device.type, "cpu"):
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
+                else:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                     self.device
@@ -851,9 +878,14 @@ def make_names_and_values_dict_func(q, qparams):
             # how much we need to pad the weight
            delta_k = int((new_k - k) / 2)
             q = q.to(self.device)
-            final_q = torch.ops.aten._convert_weight_to_int4pack(
-                F.pad(q, pad=(0, delta_k)), inner_k_tiles
-            )
+            if is_device(self.device.type, "cpu"):
+                final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
+            else:
+                final_q = torch.ops.aten._convert_weight_to_int4pack(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
             zeros = qparams[1].to(torch.bfloat16).to(self.device)
             scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
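For reference, the two packed-weight buffer layouts registered above, written out with example sizes (the concrete `out_features`/`in_features`/`inner_k_tiles` values are assumptions for illustration; the shapes themselves come from the diff):

```python
import torch

out_features, in_features, inner_k_tiles = 64, 256, 8

# CPU layout: two int4 values per uint8 byte, so half as many columns.
cpu_weight = torch.zeros((out_features, in_features // 2), dtype=torch.uint8)

# CUDA tensor-core-tiled layout consumed by _weight_int4pack_mm.
cuda_weight = torch.zeros(
    (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
)
```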

torchao/quantization/qat/linear.py

Lines changed: 11 additions & 4 deletions
@@ -29,6 +29,7 @@
 from .utils import (
     _get_qmin_qmax,
 )
+from torchao.dtypes.utils import is_device
 
 
 class FakeQuantizedLinear(torch.nn.Linear):
@@ -373,10 +374,16 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                     n_bit,
                     config.group_size,
                 )
-                q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                    q_weight.to(child.weight.device),
-                    child.inner_k_tiles,
-                )
+                if is_device(q_weight.device.type, "cpu"):
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
+                else:
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
                 quantized_linear.weight = q_weight
                 quantized_linear.scales_and_zeros = scales_and_zeros
             else:

torchao/quantization/quant_api.py

Lines changed: 2 additions & 1 deletion
@@ -575,7 +575,8 @@ def int4_weight_only(
         "tensor_core_tiled" layout for speedup with tinygemm kernel
 
     Note:
-        This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference
+        This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`
+        and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference
         of quantization algorithm compared to the more traditional type of integer quantization is the following:
         1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`)
         2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`)
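To illustrate the note above, a minimal per-tensor sketch of tinygemm-style affine quantization with a floating-point zero_point. The formulas are illustrative under stated assumptions (4-bit, symmetric mid-point convention), not the exact `choose_qparams_affine` implementation:

```python
import torch

n_bit = 4
quant_min, quant_max = 0, 2**n_bit - 1
mid_point = (quant_max + quant_min + 1) / 2  # 8.0 for 4-bit

w = torch.randn(128)
min_val, max_val = w.min(), w.max()
scale = (max_val - min_val) / (quant_max - quant_min)
zero_point = min_val + scale * mid_point  # zero_point stays in the float domain

w_q = torch.clamp(torch.round((w - min_val) / scale), quant_min, quant_max)
# Dequantize: a floating point zero does not have to be exactly representable.
w_dq = (w_q - mid_point) * scale + zero_point
```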

torchao/quantization/utils.py

Lines changed: 1 addition & 1 deletion
@@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
     if TORCH_VERSION_AT_LEAST_2_5 and (
         w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
-    ):
+    ) and not is_device(w_int4x8.device.type, "cpu"):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
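A small round-trip sketch of the uint8 nibble packing/unpacking that this condition now skips on CPU (the tensor shape is an arbitrary example):

```python
import torch

# Pack: even columns go to the high nibble, odd columns to the low nibble.
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)
packed = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)  # shape (4, 4)

# Unpack: recover both nibbles and restore the original column order.
data = packed.to(torch.int32)
high_bits, low_bits = data >> 4, data & 0x0F
unpacked = torch.stack([high_bits, low_bits], dim=-1).reshape(w.shape)
assert torch.equal(unpacked, w)
```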
