@@ -129,9 +129,16 @@ def test_quant_llm_linear_correctness(
129129TEST_CONFIGS_DEQUANT = list (itertools .product (SHAPES , INNERKTILES , QGROUP_SIZES ))
130130
131131
def make_test_id(param):
    """Build a readable pytest parameter id.

    A 2-tuple is treated as an (N, K) shape and rendered as
    ``shape_NxK``; any other value is assumed to be an
    inner_k_tiles count and rendered as ``tiles_<value>``.
    """
    is_shape = isinstance(param, tuple) and len(param) == 2
    if not is_shape:  # inner_k_tiles scalar
        return f"tiles_{param}"
    n, k = param
    return f"shape_{n}x{k}"
137+
138+
132139@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
133140# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
134- @pytest .mark .parametrize ("shape, inner_k_tiles" , TEST_CONFIGS_UNPACK , ids = str )
141+ @pytest .mark .parametrize ("shape, inner_k_tiles" , TEST_CONFIGS_UNPACK , ids = make_test_id )
135142def test_unpack_tensor_core_tiled_layout_correctness (shape , inner_k_tiles ):
136143 N , K = shape
137144 assert K % (inner_k_tiles * kTileSizeK ) == 0 and N % kTileSizeN == 0
@@ -149,7 +156,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
149156# TODO: Fix "test_aot_dispatch_dynamic" test failure
150157@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
151158# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
152- @pytest .mark .parametrize ("shape, inner_k_tiles" , TEST_CONFIGS_UNPACK , ids = str )
159+ @pytest .mark .parametrize ("shape, inner_k_tiles" , TEST_CONFIGS_UNPACK , ids = make_test_id )
153160def test_unpack_tensor_core_tiled_layout_op (shape , inner_k_tiles ):
154161 test_utils = [
155162 "test_schema" ,
0 commit comments