9393except ModuleNotFoundError :
9494 has_gemlite = False
9595
96- from test_utils import skip_if_rocm
97-
9896logger = logging .getLogger ("INFO" )
9997
10098torch .manual_seed (0 )
@@ -571,7 +569,6 @@ def test_per_token_linear_cpu(self):
571569 self ._test_per_token_linear_impl ("cpu" , dtype )
572570
573571 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
574- @skip_if_rocm ("ROCm development in progress" )
575572 def test_per_token_linear_cuda (self ):
576573 for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
577574 self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -690,7 +687,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
690687 @parameterized .expand (COMMON_DEVICE_DTYPE )
691688 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
692689 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
693- @skip_if_rocm ("ROCm development in progress" )
694690 def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
695691 if device == "cpu" :
696692 self .skipTest (f"Temporarily skipping for { device } " )
@@ -710,7 +706,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
710706 @parameterized .expand (COMMON_DEVICE_DTYPE )
711707 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
712708 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
713- @skip_if_rocm ("ROCm development in progress" )
714709 def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
715710 if device == "cpu" :
716711 self .skipTest (f"Temporarily skipping for { device } " )
@@ -904,7 +899,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
904899 @parameterized .expand (COMMON_DEVICE_DTYPE )
905900 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
906901 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
907- @skip_if_rocm ("ROCm development in progress" )
908902 def test_int4_weight_only_quant_subclass (self , device , dtype ):
909903 if device == "cpu" :
910904 self .skipTest (f"Temporarily skipping for { device } " )
@@ -924,7 +918,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
924918 @parameterized .expand (COMMON_DEVICE_DTYPE )
925919 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
926920 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
927- @skip_if_rocm ("ROCm development in progress" )
928921 def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
929922 if dtype != torch .bfloat16 :
930923 self .skipTest (f"Fails for { dtype } " )
0 commit comments