@@ -662,6 +662,8 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
662662 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
663663 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
664664 def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
665+ if device == "cpu" :
666+ self .skipTest (f"Temporarily skipping for { device } " )
665667 if dtype != torch .bfloat16 :
666668 self .skipTest ("Currently only supports bfloat16." )
667669 for test_shape in ([(16 , 1024 , 16 )] + ([(1 , 1024 , 8 )] if device == 'cuda' else [])):
@@ -673,6 +675,8 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
673675 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
674676 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
675677 def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
678+ if device == "cpu" :
679+ self .skipTest (f"Temporarily skipping for { device } " )
676680 if dtype != torch .bfloat16 :
677681 self .skipTest ("Currently only supports bfloat16." )
678682 m_shapes = [16 , 256 ] + ([1 ] if device == "cuda" else [])
@@ -815,6 +819,8 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
815819 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
816820 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
817821 def test_int4_weight_only_quant_subclass (self , device , dtype ):
822+ if device == "cpu" :
823+ self .skipTest (f"Temporarily skipping for { device } " )
818824 if dtype != torch .bfloat16 :
819825 self .skipTest (f"Fails for { dtype } " )
820826 for test_shape in ([(16 , 1024 , 16 )] + ([(1 , 1024 , 8 )] if device == 'cuda' else [])):
@@ -908,6 +914,8 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
908914 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
909915 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
910916 def test_int4_weight_only_quant_subclass_api (self , device , dtype ):
917+ if device == "cpu" :
918+ self .skipTest (f"Temporarily skipping for { device } " )
911919 if dtype != torch .bfloat16 :
912920 self .skipTest (f"Fails for { dtype } " )
913921 for test_shape in ([(16 , 1024 , 16 )] + ([(1 , 1024 , 256 )] if device == 'cuda' else [])):
@@ -923,6 +931,8 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
923931 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
924932 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
925933 def test_int4_weight_only_quant_subclass_api_grouped (self , device , dtype ):
934+ if device == "cpu" :
935+ self .skipTest (f"Temporarily skipping for { device } " )
926936 if dtype != torch .bfloat16 :
927937 self .skipTest (f"Fails for { dtype } " )
928938 for test_shape in ([(256 , 256 , 16 )] + ([(256 , 256 , 8 )] if device == 'cuda' else [])):
0 commit comments