9090except ModuleNotFoundError :
9191 has_gemlite = False
9292
93+ from test_utils import skip_if_rocm
94+
9395logger = logging .getLogger ("INFO" )
9496
9597torch .manual_seed (0 )
@@ -566,6 +568,7 @@ def test_per_token_linear_cpu(self):
566568 self ._test_per_token_linear_impl ("cpu" , dtype )
567569
568570 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
571+ @skip_if_rocm ("ROCm development in progress" )
569572 def test_per_token_linear_cuda (self ):
570573 for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
571574 self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -684,6 +687,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
684687 @parameterized .expand (COMMON_DEVICE_DTYPE )
685688 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
686689 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
690+ @skip_if_rocm ("ROCm development in progress" )
687691 def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
688692 if device == "cpu" :
689693 self .skipTest (f"Temporarily skipping for { device } " )
@@ -703,6 +707,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
703707 @parameterized .expand (COMMON_DEVICE_DTYPE )
704708 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
705709 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
710+ @skip_if_rocm ("ROCm development in progress" )
706711 def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
707712 if device == "cpu" :
708713 self .skipTest (f"Temporarily skipping for { device } " )
@@ -896,6 +901,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
896901 @parameterized .expand (COMMON_DEVICE_DTYPE )
897902 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
898903 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
904+ @skip_if_rocm ("ROCm development in progress" )
899905 def test_int4_weight_only_quant_subclass (self , device , dtype ):
900906 if device == "cpu" :
901907 self .skipTest (f"Temporarily skipping for { device } " )
@@ -915,6 +921,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
915921 @parameterized .expand (COMMON_DEVICE_DTYPE )
916922 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
917923 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
924+ @skip_if_rocm ("ROCm development in progress" )
918925 def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
919926 if dtype != torch .bfloat16 :
920927 self .skipTest (f"Fails for { dtype } " )
0 commit comments