@@ -95,6 +95,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
@@ -107,14 +108,15 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
         "test_schema",
         "test_autograd_registration",
         "test_faketensor",
     ]
-
+
     # TODO: Figure out why test fails unless torch >= 2.5
     if TORCH_VERSION_AFTER_2_5:
         test_utils.append("test_aot_dispatch_dynamic")
@@ -137,10 +139,10 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     assert scales.shape == zeros.shape
 
     midpoint = 2 ** (nbits - 1)
-
+
     # Convert from u4 -> s4 and upcast to bfloat16
     q = q.sub(midpoint).to(dtype)
-
+
     # Dequantize
     q = q.reshape(-1, group_size)
     dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
@@ -149,21 +151,22 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
-    dtype = torch.bfloat16
+    dtype = torch.bfloat16
 
     device = "cuda"
 
     t = torch.randn(n, k, dtype=dtype, device=device)
     scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
-
+
     # Quantize
     q = groupwise_affine_quantize_tensor_from_qparams(
         t, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Pack to tensor core layout
     packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
@@ -174,7 +177,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         q, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Dequantize by passing in an identity matrix as the activation
     a_eye = torch.eye(k, device=device, dtype=dtype)
     dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -183,34 +186,35 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
         group_size,
         scales_and_zeros,
     ).t()
-
+
     # Actual operation to test
     dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
-
+
     # Compare results
     diff_ao_id = (dq_id - dq_ao).abs().max()
     diff_op_id = (dq_op - dq_id).abs().max()
     diff_op_ao = (dq_op - dq_ao).abs().max()
-
+
     # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
     # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
     # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
     # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
-
-    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
+
+    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
     assert diff_op_id == 0
-
+
     # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
     assert diff_op_ao == diff_ao_id
 
     assert diff_op_ao < 1e-1
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
-    dtype = torch.bfloat16
+    dtype = torch.bfloat16
 
     device = "cuda"
 
     # Quantize and pack
@@ -222,13 +226,13 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
 
     packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-
+
     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Dequantize by passing in an identity matrix as the activation
     a_eye = torch.eye(k, device=device, dtype=dtype)
     dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -237,29 +241,30 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
         group_size,
         scales_and_zeros,
     ).t()
-
+
     # Actual operation to test
     dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
-
+
     # Compare results
     diff_ao_id = (dq_id - dq_ao).abs().max()
     diff_op_id = (dq_op - dq_id).abs().max()
     diff_op_ao = (dq_op - dq_ao).abs().max()
-
+
     # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
     # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
     # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
     # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
-
-    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
+
+    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
     assert diff_op_id == 0
-
+
     # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
     assert diff_op_ao == diff_ao_id
 
     assert diff_op_ao < 1e-1
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -271,7 +276,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
     scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
     zeros = torch.randn_like(scales)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-
+
     test_utils = [
         "test_schema",
         "test_autograd_registration",
@@ -287,4 +292,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
     )
 
 if __name__ == "__main__":
-    run_tests()
+    run_tests()