@@ -803,6 +803,8 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 8)] if device == 'cuda' else [])):
@@ -896,6 +898,8 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device == 'cuda' else [])):
@@ -911,6 +915,8 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device == 'cuda' else [])):