Commit bd3b79a

Fix CI
1 parent 5f41c1e commit bd3b79a

2 files changed: +23 -6 lines changed

test/test_ops.py

Lines changed: 9 additions & 0 deletions
@@ -102,8 +102,12 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
     assert torch.equal(t, unpacked)

     # TODO: Fix "test_aot_dispatch_dynamic" test failure

@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils.append("test_aot_dispatch_dynamic")

     t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    if TORCH_VERSION_AFTER_2_5:
+        t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
     packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

     opcheck(

@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap

     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
+    if TORCH_VERSION_AFTER_2_5:
+        unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
+
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
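The packing expression added to these tests reflects that, for PyTorch 2.5 and later, the int4 weight is handled as a uint8 tensor carrying two 4-bit values per byte (even columns in the high nibble, odd columns in the low nibble). Below is a minimal, self-contained sketch of that round trip, independent of the CUDA ops used in the tests; the shape and values are illustrative only.

    import torch

    # toy int4 values stored in an int32 tensor, as in the tests (shape is made up)
    t = torch.randint(0, 16, (4, 8), dtype=torch.int)

    # pack pairs of int4 values into one uint8: even columns -> high nibble, odd -> low nibble
    packed = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
    assert packed.shape == (4, 4)  # last dimension is halved

    # unpack again to confirm the round trip
    high = packed.to(torch.int32) >> 4
    low = packed.to(torch.int32) & 0x0F
    restored = torch.stack([high, low], dim=-1).reshape(t.shape)
    assert torch.equal(restored, t)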

torchao/quantization/utils.py

Lines changed: 14 additions & 6 deletions
@@ -362,18 +362,26 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     groupsize=128,
 ):
     assert groupsize > 1
-    # needed for GPTQ single column dequantize
-    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
-        groupsize = w_int4x8.shape[-1]
-    assert w_int4x8.shape[-1] % groupsize == 0
     assert w_int4x8.dim() == 2
+    if TORCH_VERSION_AFTER_2_5:
+        data = w_int4x8.to(torch.int32)
+        high_bits = data >> 4
+        low_bits = data & 0x0F
+        w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32, device=w_int4x8.device)
+        w_int32[::, ::2] = high_bits
+        w_int32[::, 1::2] = low_bits
+    else:
+        w_int32 = w_int4x8

+    # needed for GPTQ single column dequantize
+    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+        groupsize = w_int32.shape[-1]
+    assert w_int32.shape[-1] % groupsize == 0
     block_size = (1, groupsize)
     input_dtype = torch.int32
     quant_min = 0
     quant_max = 2**n_bit - 1
-    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
-
+    return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
     scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
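Under TORCH_VERSION_AFTER_2_5, groupwise_affine_dequantize_tensor_from_qparams now first expands the packed uint8 input back into one int4 value per element before calling dequantize_affine. A minimal sketch of that unpack step, showing that it is the inverse of the packing used in test_ops.py (values and shapes are illustrative, and the dequantize_affine call itself is omitted):

    import torch

    # toy packed weight: each uint8 byte carries two int4 values (shape is made up)
    w_int4x8 = torch.randint(0, 256, (4, 4), dtype=torch.uint8)

    # split each byte into its high and low nibble, doubling the last dimension
    data = w_int4x8.to(torch.int32)
    high_bits = data >> 4
    low_bits = data & 0x0F
    w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32)
    w_int32[::, ::2] = high_bits
    w_int32[::, 1::2] = low_bits

    # re-packing reproduces the original bytes, i.e. this is the inverse of the test's packing
    repacked = (w_int32[::, ::2] << 4 | w_int32[::, 1::2]).to(torch.uint8)
    assert torch.equal(repacked, w_int4x8)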
