
Commit 5099846

Fix int4pack_mm error
1 parent 591df26 commit 5099846

8 files changed (+21, -18 lines)

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape

test/integration/test_integration.py

Lines changed: 8 additions & 8 deletions
@@ -631,7 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -642,7 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest("Currently only supports bfloat16.")
@@ -737,7 +737,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -748,7 +748,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -823,7 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -838,7 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
@@ -1028,7 +1028,7 @@ def test_save_load_int8woqtensors(self, device, dtype):
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
@@ -1488,7 +1488,7 @@ def test_get_model_size_autoquant(self, device, dtype):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_get_model_size_aqt(self, api, test_device, test_dtype):
         if test_dtype != torch.bfloat16:
             self.skipTest(f"{api} in {test_dtype} is not supported yet")

test/quantization/test_quant_api.py

Lines changed: 1 addition & 1 deletion
@@ -523,7 +523,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         self.assertTrue(torch.equal(res, ref))
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
+    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding

test/quantization/test_quant_primitives.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .to(torch.int32)
         .reshape_as(w)
     )
+    w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 

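The line added above packs two int4 values into each uint8 byte, with the even columns in the high nibble and the odd columns in the low nibble, so the test helper now emits the same packed layout the updated int4 kernels consume. A minimal standalone sketch of that packing and its inverse (not from the commit; the shape is made up and the values are assumed to already be in the int4 range [0, 15]):

import torch

# Toy int4 values, one per element (illustrative shape only).
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Pack: even columns -> high nibble, odd columns -> low nibble, two values per byte.
packed = (w[:, ::2] << 4 | w[:, 1::2]).to(torch.uint8)
assert packed.shape == (4, 4)  # last dimension is halved

# Unpack to check the round trip.
unpacked = torch.stack([packed >> 4, packed & 0xF], dim=-1).reshape(w.shape)
assert torch.equal(unpacked.to(torch.int32), w)
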
test/test_ops.py

Lines changed: 5 additions & 5 deletions
@@ -95,7 +95,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
@@ -108,7 +108,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
@@ -151,7 +151,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -210,7 +210,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -264,7 +264,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     assert diff_op_ao < 1e-1
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
+# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     n, k = shape

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 3 deletions
@@ -232,6 +232,7 @@ def from_float(
 
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
         int_data = layout_type.post_process(int_data)
 
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -500,9 +501,8 @@ def from_plain(
         layout_type: LayoutType
     ):
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-        # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
-        # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
+        assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)

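The `from_plain` hunk above removes the temporary `int_data.to(torch.int32)` workaround: the weight is now expected to arrive already nibble-packed as uint8 (that packing is added in `from_float`, right after `quantize_affine`) and is passed straight to `torch.ops.aten._convert_weight_to_int4pack`. A rough sketch of the resulting call path, assuming a CUDA device, a PyTorch build whose `_convert_weight_to_int4pack` takes the packed uint8 layout (the 2.5+ nightlies this commit targets), and illustrative sizes:

import torch

# Illustrative sizes; inner_k_tiles must be a value the tinygemm kernel supports (e.g. 2, 4, 8).
n, k, inner_k_tiles = 256, 512, 8
int4_vals = torch.randint(0, 16, (n, k), dtype=torch.int32, device="cuda")

# Nibble-pack to uint8 first, mirroring the new step in from_float.
int_data = (int4_vals[:, ::2] << 4 | int4_vals[:, 1::2]).to(torch.uint8)
assert int_data.dtype == torch.uint8

# 2.5+ expects the packed uint8 weight here; the old code cast to int32 instead.
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
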
torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ def hqq_quants_to_torch_quants(
             .reshape(shape)
             .contiguous()
         )
+        W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val

torchao/quantization/utils.py

Lines changed: 1 addition & 0 deletions
@@ -349,6 +349,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_max = 2 ** n_bit - 1
 
     int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
+    int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
 
 def groupwise_affine_dequantize_tensor_from_qparams(

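With the packing step added here, `groupwise_affine_quantize_tensor_from_qparams` returns a nibble-packed uint8 tensor whose last dimension is half the weight width, instead of one integer per element; the test helper and HQQ path above were updated to match. A hedged usage sketch with made-up shapes and group size, assuming `get_groupwise_affine_qparams` from the same module provides the qparams:

import torch
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
)

# Toy weight; the shape and group size are illustrative only.
w = torch.randn(32, 128, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=64)
q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=64)

# After this commit the result is packed: uint8, two int4 values per byte.
assert q.dtype == torch.uint8 and q.shape == (32, 64)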