@@ -903,6 +903,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
903903 @parameterized .expand (COMMON_DEVICE_DTYPE )
904904 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
905905 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
906+ @skip_if_rocm ("ROCm enablement in progress" )
906907 def test_int4_weight_only_quant_subclass (self , device , dtype ):
907908 if device == "cpu" :
908909 self .skipTest (f"Temporarily skipping for { device } " )
@@ -922,6 +923,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
922923 @parameterized .expand (COMMON_DEVICE_DTYPE )
923924 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
924925 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
926+ @skip_if_rocm ("ROCm enablement in progress" )
925927 def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
926928 if dtype != torch .bfloat16 :
927929 self .skipTest (f"Fails for { dtype } " )
@@ -1075,6 +1077,7 @@ def test_gemlite_layout(self, device, dtype):
10751077 @parameterized .expand (COMMON_DEVICE_DTYPE )
10761078 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
10771079 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
1080+ @skip_if_rocm ("ROCm enablement in progress" )
10781081 def test_int4_weight_only_quant_subclass_api_grouped (self , device , dtype ):
10791082 if device == "cpu" :
10801083 self .skipTest (f"Temporarily skipping for { device } " )
0 commit comments