File tree Expand file tree Collapse file tree 2 files changed +11
-2
lines changed
test/prototype/scaled_grouped_mm Expand file tree Collapse file tree 2 files changed +11
-2
lines changed Original file line number Diff line number Diff line change 1111
1212# We need to skip before doing any imports which would use triton, since
1313# triton won't be available on CPU builds and torch < 2.5
14- if not (TORCH_VERSION_AT_LEAST_2_5 and torch .cuda .is_available ()):
14+ if not (
15+ TORCH_VERSION_AT_LEAST_2_5
16+ and torch .cuda .is_available ()
17+ and torch .cuda .get_device_capability ()[0 ] >= 9
18+ ):
1519 pytest .skip ("Unsupported PyTorch version" , allow_module_level = True )
1620
21+
1722from torchao .prototype .scaled_grouped_mm .kernels .jagged_float8_scales import (
1823 triton_fp8_col_major_jagged_colwise_scales ,
1924 triton_fp8_row_major_jagged_rowwise_scales ,
Original file line number Diff line number Diff line change 1111
1212# We need to skip before doing any imports which would use triton, since
1313# triton won't be available on CPU builds and torch < 2.5
14- if not (TORCH_VERSION_AT_LEAST_2_5 and torch .cuda .is_available ()):
14+ if not (
15+ TORCH_VERSION_AT_LEAST_2_5
16+ and torch .cuda .is_available ()
17+ and torch .cuda .get_device_capability ()[0 ] >= 9
18+ ):
1519 pytest .skip ("Unsupported PyTorch version" , allow_module_level = True )
1620
1721
You can’t perform that action at this time.
0 commit comments