File tree Expand file tree Collapse file tree 1 file changed +12
-4
lines changed
test/prototype/scaled_grouped_mm Expand file tree Collapse file tree 1 file changed +12
-4
lines changed Original file line number Diff line number Diff line change 44# This source code is licensed under the BSD 3-Clause license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import logging
8+
79import pytest
810import torch
911
10- from torchao .prototype .scaled_grouped_mm .kernels .jagged_float8_scales import (
11- triton_fp8_col_major_jagged_colwise_scales ,
12- triton_fp8_row_major_jagged_rowwise_scales ,
13- )
1412from torchao .prototype .scaled_grouped_mm .utils import (
1513 _is_column_major ,
1614 _to_2d_jagged_float8_tensor_colwise ,
1715 _to_2d_jagged_float8_tensor_rowwise ,
1816)
1917
18+ logging .basicConfig (level = logging .INFO )
19+ logger = logging .getLogger (__name__ )
20+
21+ # triton only ships with pytorch cuda builds, so do import conditionally.
22+ if torch .cuda .is_available ():
23+ from torchao .prototype .scaled_grouped_mm .kernels .jagged_float8_scales import (
24+ triton_fp8_col_major_jagged_colwise_scales ,
25+ triton_fp8_row_major_jagged_rowwise_scales ,
26+ )
27+
2028
2129@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
2230@pytest .mark .parametrize ("round_scales_to_power_of_2" , [True , False ])
You can’t perform that action at this time.
0 commit comments