Skip to content

Commit 05e15e9

Browse files
conditionally import triton
1 parent 9e95c91 commit 05e15e9

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

test/prototype/scaled_grouped_mm/test_kernels.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,27 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
8+
79
import pytest
810
import torch
911

10-
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
11-
triton_fp8_col_major_jagged_colwise_scales,
12-
triton_fp8_row_major_jagged_rowwise_scales,
13-
)
1412
from torchao.prototype.scaled_grouped_mm.utils import (
1513
_is_column_major,
1614
_to_2d_jagged_float8_tensor_colwise,
1715
_to_2d_jagged_float8_tensor_rowwise,
1816
)
1917

18+
logging.basicConfig(level=logging.INFO)
19+
logger = logging.getLogger(__name__)
20+
21+
# Triton only ships with PyTorch CUDA builds, so import it conditionally.
22+
if torch.cuda.is_available():
23+
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
24+
triton_fp8_col_major_jagged_colwise_scales,
25+
triton_fp8_row_major_jagged_rowwise_scales,
26+
)
27+
2028

2129
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
2230
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])

0 commit comments

Comments (0)