@@ -570,7 +570,6 @@ def test_per_token_linear_cpu(self):
570570 self ._test_per_token_linear_impl ("cpu" , dtype )
571571
572572 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
573- @skip_if_rocm ("ROCm development in progress" )
574573 def test_per_token_linear_cuda (self ):
575574 for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
576575 self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -689,7 +688,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
689688 @parameterized .expand (COMMON_DEVICE_DTYPE )
690689 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
691690 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
692- @skip_if_rocm ("ROCm development in progress" )
693691 def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
694692 if device == "cpu" :
695693 self .skipTest (f"Temporarily skipping for { device } " )
@@ -709,7 +707,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
709707 @parameterized .expand (COMMON_DEVICE_DTYPE )
710708 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
711709 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
712- @skip_if_rocm ("ROCm development in progress" )
713710 def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
714711 if device == "cpu" :
715712 self .skipTest (f"Temporarily skipping for { device } " )
@@ -903,7 +900,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
903900 @parameterized .expand (COMMON_DEVICE_DTYPE )
904901 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
905902 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
906- @skip_if_rocm ("ROCm development in progress" )
907903 def test_int4_weight_only_quant_subclass (self , device , dtype ):
908904 if device == "cpu" :
909905 self .skipTest (f"Temporarily skipping for { device } " )
@@ -923,7 +919,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
923919 @parameterized .expand (COMMON_DEVICE_DTYPE )
924920 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
925921 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
926- @skip_if_rocm ("ROCm development in progress" )
927922 def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
928923 if dtype != torch .bfloat16 :
929924 self .skipTest (f"Fails for { dtype } " )
@@ -1827,7 +1822,7 @@ def test_autoquant_int4wo(self, device, dtype):
18271822 self .assertGreater (compute_error (ref , out ), 20 )
18281823
18291824 @parameterized .expand (COMMON_DEVICE_DTYPE )
1830- @unittest .skipIf (not torch . cuda . is_available (), "Need CUDA available " )
1825+ @unittest .skipIf (not is_sm_at_least_90 (), "Need cuda arch greater than SM90 " )
18311826 @unittest .skipIf (
18321827 not TORCH_VERSION_AT_LEAST_2_5 , "autoquant int4 option requires 2.5+."
18331828 )
0 commit comments