
import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89, is_sm_90

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config

-# TODO(future PR): standardize IS_H100 with the rest of the codebase
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-

def _test_compile_base(
    backend: str,
@@ -99,7 +95,7 @@ def _test_compile_base(
9995 "scaling_type_grad_output" ,
10096 [ScalingType .DELAYED , ScalingType .DYNAMIC , ScalingType .STATIC ],
10197)
102- @pytest .mark .parametrize ("emulate" , [False , True ] if is_cuda_8_9 else [True ])
98+ @pytest .mark .parametrize ("emulate" , [False , True ] if is_sm_89 () else [True ])
10399@pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float32 ])
104100@unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
105101def test_eager_only (
@@ -126,7 +122,7 @@ def test_eager_only(


@pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_sm_89() else [True])
@pytest.mark.parametrize(
    "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@@ -177,7 +173,7 @@ def test_aot_eager(
    [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_89(),
    "CUDA with float8 support not available",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -215,7 +211,7 @@ def test_inductor_from_config_params(
        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
    ],
)
-@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(not is_sm_90(), "CUDA with capability 9.0 or greater not available")
def test_inductor_from_recipe(recipe_name):
    torch._dynamo.reset()
    config = recipe_name_to_linear_config(recipe_name)
@@ -253,7 +249,7 @@ def forward(self, x):

    # TODO(future): figure out why the test below fails on CUDA capability 8.9
    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
+        not torch.cuda.is_available() or not is_sm_90(),
        "CUDA with capability 9.0 or greater not available",
    )
    def test_float8_with_graph_break_in_the_middle(self):
@@ -269,7 +265,7 @@ def test_float8_with_graph_break_in_the_middle(self):
        torch.testing.assert_close(y_eager, y_compiled)

    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_89(),
        "CUDA with float8 support not available",
    )
    def test_float8_graph_input(self):
@@ -293,7 +289,7 @@ def to_float(x):
        torch.testing.assert_close(y2_eager, y2_compiled)

    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_89(),
        "CUDA with float8 support not available",
    )
    def test_float8_graph_output(self):
@@ -323,7 +319,7 @@ def test_float8_graph_output(self):


@unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_89(),
    "CUDA with float8 support not available",
)
def test_sync_amax_func():
@@ -364,7 +360,7 @@ def __exit__(self, *args):


@unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_89(),
    "CUDA with float8 support not available",
)
def test_sync_amax_func_cuda_graph_success():
@@ -396,7 +392,7 @@ def test_sync_amax_func_cuda_graph_success():


@unittest.skipIf(
-    not is_cuda_8_9,
+    not is_sm_89(),
    "CUDA not available",
)
@pytest.mark.parametrize(
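
Note on the helpers used above: `is_sm_89` and `is_sm_90` are imported from `torchao.utils` and replace the module-local `is_cuda_8_9` / `is_H100` flags that this diff deletes, which is what the removed TODO asked for. As a rough sketch only, assuming the helpers mirror the deleted expressions (the actual `torchao.utils` implementations may differ), they would look roughly like this:

```python
import torch

# Hypothetical sketch of the torchao.utils helpers referenced in the diff,
# reconstructed from the module-local checks the diff removes; not the
# library's actual code.


def is_sm_89() -> bool:
    # CUDA device with compute capability >= 8.9 (Ada and newer), the minimum
    # for the non-emulated float8 paths (emulate=False) in these tests.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_90() -> bool:
    # CUDA device with compute capability >= 9.0 (Hopper/H100 and newer),
    # required by the recipe-based tests such as test_inductor_from_recipe.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
```

Centralizing the capability check in `torchao.utils` keeps every test module skipping on the same logic instead of each one redefining its own flags.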