 import torch
 import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -60,10 +64,6 @@
 torch.manual_seed(0)


-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
-
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -219,7 +219,7 @@ def test_axiswise_reshape(self):
         ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -333,7 +333,9 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
         "scaling_type_input",
@@ -415,7 +417,9 @@ def test_linear_from_recipe(
             config,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
@@ -462,7 +466,9 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize(
+        "emulate", [True, False] if is_sm_at_least_89() else [True]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -523,15 +529,15 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s

-    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
     def test_inference_mode(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         with torch.inference_mode(mode=True):
             m(x)

-    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
     def test_quantize(self):
         x = torch.randn(32, 32, device="cuda")
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
@@ -549,7 +555,7 @@ def test_quantize(self):

 class TestScaledMM:
     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -594,7 +600,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
@@ -630,7 +636,7 @@ def test_different_configs_error(self):
             a @ b

     @unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
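For reference, here is a minimal sketch of what the new `is_sm_at_least_89()` / `is_sm_at_least_90()` helpers presumably check, mirroring the module-level `is_cuda_8_9` / `is_cuda_9_0` expressions removed by this diff. The actual implementations live in `torchao.utils` and may include additional checks.

```python
# Sketch only: mirrors the is_cuda_8_9 / is_cuda_9_0 expressions removed above;
# the real helpers are provided by torchao.utils and may differ in detail.
import torch


def is_sm_at_least_89() -> bool:
    # True when a CUDA device with compute capability >= 8.9 (Ada-class) is present.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_at_least_90() -> bool:
    # True when a CUDA device with compute capability >= 9.0 (Hopper-class) is present.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
```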