Skip to content

Commit 4b88373

Browse files
skip torch < 2.5
1 parent 05e15e9 commit 4b88373

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

test/prototype/scaled_grouped_mm/test_kernels.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
triton_fp8_col_major_jagged_colwise_scales,
2525
triton_fp8_row_major_jagged_rowwise_scales,
2626
)
27+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
2728

2829

2930
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
31+
@pytest.mark.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "torch 2.5+ required")
3032
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
3133
def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
3234
# tests case where rowwise scales are computed for multiple distinct subtensors,
@@ -55,6 +57,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
5557

5658

5759
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
60+
@pytest.mark.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "torch 2.5+ required")
5861
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
5962
def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool):
6063
# tests case where colwise scales are computed for multiple distinct subtensors,

0 commit comments

Comments
 (0)