@@ -102,8 +102,12 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
102102 assert K % (inner_k_tiles * kTileSizeK ) == 0 and N % kTileSizeN == 0
103103
104104 t = torch .randint (0 , 16 , dtype = torch .int , size = shape , device = "cuda" )
105+ if TORCH_VERSION_AFTER_2_5 :
106+ t = (t [::, ::2 ] << 4 | t [::, 1 ::2 ]).to (torch .uint8 )
105107 packed_w = torch .ops .aten ._convert_weight_to_int4pack (t , inner_k_tiles )
106108 unpacked = torchao .ops .unpack_tensor_core_tiled_layout (packed_w , inner_k_tiles )
109+ if TORCH_VERSION_AFTER_2_5 :
110+ unpacked = (unpacked [::, ::2 ] << 4 | unpacked [::, 1 ::2 ]).to (torch .uint8 )
107111 assert torch .equal (t , unpacked )
108112
109113# TODO: Fix "test_aot_dispatch_dynamic" test failure
@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
122126 test_utils .append ("test_aot_dispatch_dynamic" )
123127
124128 t = torch .randint (0 , 16 , dtype = torch .int , size = shape , device = "cuda" )
129+ if TORCH_VERSION_AFTER_2_5 :
130+ t = (t [::, ::2 ] << 4 | t [::, 1 ::2 ]).to (torch .uint8 )
125131 packed_w = torch .ops .aten ._convert_weight_to_int4pack (t , inner_k_tiles )
126132
127133 opcheck (
@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
229235
230236 # Unpack and dequantize
231237 unpacked = torchao .ops .unpack_tensor_core_tiled_layout (packed , inner_k_tiles )
238+ if TORCH_VERSION_AFTER_2_5 :
239+ unpacked = (unpacked [::, ::2 ] << 4 | unpacked [::, 1 ::2 ]).to (torch .uint8 )
240+
232241 dq_ao = groupwise_affine_dequantize_tensor_from_qparams (
233242 unpacked , scales , zeros , n_bit = 4 , groupsize = group_size
234243 )
0 commit comments