|
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import logging |
8 | | - |
9 | 7 | import pytest |
10 | 8 | import torch |
11 | 9 |
|
| 10 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 |
| 11 | + |
| 12 | +# We need to skip before doing any imports that would use triton, since
| 13 | +# triton isn't available on CPU-only builds or with torch < 2.5.
| 14 | +if not (TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available()): |
| 15 | + pytest.skip("Requires torch 2.5+ and CUDA", allow_module_level=True)
| 16 | + |
| 17 | +from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import ( |
| 18 | + triton_fp8_col_major_jagged_colwise_scales, |
| 19 | + triton_fp8_row_major_jagged_rowwise_scales, |
| 20 | +) |
12 | 21 | from torchao.prototype.scaled_grouped_mm.utils import ( |
13 | 22 | _is_column_major, |
14 | 23 | _to_2d_jagged_float8_tensor_colwise, |
15 | 24 | _to_2d_jagged_float8_tensor_rowwise, |
16 | 25 | ) |
17 | 26 |
|
18 | | -logging.basicConfig(level=logging.INFO) |
19 | | -logger = logging.getLogger(__name__) |
20 | | - |
21 | | -# triton only ships with pytorch cuda builds, so do import conditionally. |
22 | | -if torch.cuda.is_available(): |
23 | | - from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import ( |
24 | | - triton_fp8_col_major_jagged_colwise_scales, |
25 | | - triton_fp8_row_major_jagged_rowwise_scales, |
26 | | - ) |
27 | | -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 |
28 | | - |
29 | 27 |
|
30 | 28 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") |
31 | 29 | @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch 2.5+ required")
|