9393except ModuleNotFoundError :
9494 has_gemlite = False
9595
96+ from test_utils import skip_if_rocm
97+
9698logger = logging .getLogger ("INFO" )
9799
98100torch .manual_seed (0 )
@@ -569,6 +571,7 @@ def test_per_token_linear_cpu(self):
569571 self ._test_per_token_linear_impl ("cpu" , dtype )
570572
571573 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
574+ @skip_if_rocm ("ROCm development in progress" )
572575 def test_per_token_linear_cuda (self ):
573576 for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
574577 self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -687,6 +690,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
687690 @parameterized .expand (COMMON_DEVICE_DTYPE )
688691 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
689692 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
693+ @skip_if_rocm ("ROCm development in progress" )
690694 def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
691695 if device == "cpu" :
692696 self .skipTest (f"Temporarily skipping for { device } " )
@@ -706,6 +710,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
706710 @parameterized .expand (COMMON_DEVICE_DTYPE )
707711 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
708712 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
713+ @skip_if_rocm ("ROCm development in progress" )
709714 def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
710715 if device == "cpu" :
711716 self .skipTest (f"Temporarily skipping for { device } " )
@@ -899,6 +904,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
899904 @parameterized .expand (COMMON_DEVICE_DTYPE )
900905 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
901906 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
907+ @skip_if_rocm ("ROCm development in progress" )
902908 def test_int4_weight_only_quant_subclass (self , device , dtype ):
903909 if device == "cpu" :
904910 self .skipTest (f"Temporarily skipping for { device } " )
@@ -918,6 +924,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
918924 @parameterized .expand (COMMON_DEVICE_DTYPE )
919925 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
920926 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
927+ @skip_if_rocm ("ROCm development in progress" )
921928 def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
922929 if dtype != torch .bfloat16 :
923930 self .skipTest (f"Fails for { dtype } " )
0 commit comments