
import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -95,7 +99,7 @@ def _test_compile_base(
9599 "scaling_type_grad_output" ,
96100 [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
97101)
98- @pytest .mark .parametrize ("emulate" , [False , True ] if is_sm_89 () else [True ])
102+ @pytest .mark .parametrize ("emulate" , [False , True ] if is_sm_at_least_89 () else [True ])
99103@pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float32 ])
100104@unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
101105def test_eager_only (
@@ -122,7 +126,7 @@ def test_eager_only(


@pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
@pytest.mark.parametrize(
    "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@@ -173,7 +177,7 @@ def test_aot_eager(
    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_89(),
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
    "CUDA with float8 support not available",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -211,7 +215,9 @@ def test_inductor_from_config_params(
        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
    ],
)
-@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(
+    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+)
def test_inductor_from_recipe(recipe_name):
    torch._dynamo.reset()
    config = recipe_name_to_linear_config(recipe_name)
@@ -249,7 +255,7 @@ def forward(self, x):

    # TODO(future): figure out why the test below fails on CUDA capability 8.9
    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_90(),
+        not torch.cuda.is_available() or not is_sm_at_least_90(),
        "CUDA with capability 9.0 or greater not available",
    )
    def test_float8_with_graph_break_in_the_middle(self):
@@ -265,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self):
        torch.testing.assert_close(y_eager, y_compiled)

    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_89(),
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
        "CUDA with float8 support not available",
    )
    def test_float8_graph_input(self):
@@ -289,7 +295,7 @@ def to_float(x):
        torch.testing.assert_close(y2_eager, y2_compiled)

    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_89(),
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
        "CUDA with float8 support not available",
    )
    def test_float8_graph_output(self):
@@ -319,7 +325,7 @@ def test_float8_graph_output(self):


@unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_89(),
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
    "CUDA with float8 support not available",
)
def test_sync_amax_func():
@@ -360,7 +366,7 @@ def __exit__(self, *args):


@unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_89(),
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
    "CUDA with float8 support not available",
)
def test_sync_amax_func_cuda_graph_success():
@@ -392,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success():


@unittest.skipIf(
-    not is_sm_89(),
+    not is_sm_at_least_89(),
    "CUDA not available",
)
@pytest.mark.parametrize(
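
Note on the renamed helpers: the diff only swaps is_sm_89 / is_sm_90 for is_sm_at_least_89 / is_sm_at_least_90 from torchao.utils. As a minimal sketch (not the actual torchao.utils source), assuming only the public torch.cuda.is_available() and torch.cuda.get_device_capability() APIs, such "at least this SM version" gates could be written like this:

# Hedged sketch of the capability gates used in the skip conditions above;
# torch.cuda.get_device_capability() returns a (major, minor) tuple for the
# current device, and Python tuple comparison orders versions correctly.
import torch


def is_sm_at_least_89() -> bool:
    # Ada (SM 8.9) and newer, where hardware float8 matmul is available.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_at_least_90() -> bool:
    # Hopper (SM 9.0) and newer.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)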