Skip to content

Commit 6821e44

Browse files
only run tests on compute capability 9.0+
1 parent 8a37453 commit 6821e44

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

test/prototype/scaled_grouped_mm/test_kernels.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111

1212
# We need to skip before doing any imports which would use triton, since
1313
# triton won't be available on CPU builds and torch < 2.5
14-
if not (TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available()):
14+
if not (
15+
TORCH_VERSION_AT_LEAST_2_5
16+
and torch.cuda.is_available()
17+
and torch.cuda.get_device_capability()[0] >= 9
18+
):
1519
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
1620

21+
1722
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
1823
triton_fp8_col_major_jagged_colwise_scales,
1924
triton_fp8_row_major_jagged_rowwise_scales,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
# We need to skip before doing any imports which would use triton, since
1313
# triton won't be available on CPU builds and torch < 2.5
14-
if not (TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available()):
14+
if not (
15+
TORCH_VERSION_AT_LEAST_2_5
16+
and torch.cuda.is_available()
17+
and torch.cuda.get_device_capability()[0] >= 9
18+
):
1519
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
1620

1721

0 commit comments

Comments
 (0)