diff --git a/ruff.toml b/ruff.toml
index 09d0a1ec97..13026345a6 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -9,7 +9,7 @@ include = [
     "torchao/sparsity/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "test/float8/**/*.py",
-    "test/quantization/test_observer.py",
+    "test/quantization/**/*.py",
     "test/dtypes/**/*.py",
     "test/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py
index 37709c4128..3eb9b0a2c5 100644
--- a/test/quantization/test_galore_quant.py
+++ b/test/quantization/test_galore_quant.py
@@ -3,13 +3,16 @@
 import pytest
 
 # Skip entire test if triton is not available, otherwise CI failure
 try:
-    import triton
+    import triton  # noqa: F401
 except ImportError:
     pytest.skip("triton is not installed", allow_module_level=True)
-
-from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise
 
 import torch
+from bitsandbytes.functional import (
+    create_dynamic_map,
+    dequantize_blockwise,
+    quantize_blockwise,
+)
 from torchao.prototype.galore.kernels import (
     triton_dequant_blockwise,
diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py
index 0dcaaf9c8c..ebdf2281e0 100644
--- a/test/quantization/test_marlin_qqq.py
+++ b/test/quantization/test_marlin_qqq.py
@@ -1,4 +1,5 @@
 import copy
+import unittest
 
 import pytest
 import torch
@@ -19,9 +20,12 @@
     choose_qparams_and_quantize_affine_qqq,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
-import unittest
 
-@unittest.skipIf(is_fbcode(), "Skipping the test in fbcode since we don't have TARGET file for kernels")
+
+@unittest.skipIf(
+    is_fbcode(),
+    "Skipping the test in fbcode since we don't have TARGET file for kernels",
+)
 class TestMarlinQQQ(TestCase):
     def setUp(self):
         super().setUp()
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 29f833c9ab..3a998635aa 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -13,9 +13,8 @@
 import torch
 import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
-from torchao.dtypes import (
-    TensorCoreTiledLayout,
-)
+
+from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -26,33 +25,26 @@
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
-)
 from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
+    Int4WeightOnlyQATLinear,
     Int8DynActInt4WeightQATLinear,
-    Int4WeightOnlyQATLinear
 )
 from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _get_qmin_qmax,
     _GenericFakeQuantize,
-)
-from torchao.quantization.quant_api import (
-    int4_weight_only,
-    quantize_,
+    _get_qmin_qmax,
 )
 from torchao.quantization.quant_primitives import (
-    fake_quantize_affine,
     MappingType,
     TorchAODType,
     ZeroPointDomain,
+    fake_quantize_affine,
 )
 from torchao.quantization.unified import (
     TwoStepQuantizer,
@@ -65,17 +57,12 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
-)
-
-from torchao.quantization.GPTQ import (
- _replace_linear_8da4w, - _replace_linear_int4 ) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() + class Sub(torch.nn.Module): def __init__(self): super().__init__() @@ -87,6 +74,7 @@ def example_inputs(self): def forward(self, x): return self.linear(x) + class M(torch.nn.Module): def __init__(self): super().__init__() @@ -103,6 +91,7 @@ def forward(self, x): x = self.linear2(x) return x + class M2(torch.nn.Module): def __init__(self): super().__init__() @@ -118,7 +107,9 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -132,20 +123,40 @@ def test_fake_quantize_per_channel_group(self): # fake quant op out = _fake_quantize_per_channel_group( - x, s, zp, qmin, qmax, group_size, + x, + s, + zp, + qmin, + qmax, + group_size, ) out.sum().backward() # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group( - x2, s, zp, qmin, qmax, torch.int8, group_size, + x2, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -161,10 +172,21 @@ def test_fake_quantize_per_token(self): # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_token( - x2, s, zp, qmin, qmax, torch.int8, + x2, + s, + zp, + qmin, + qmax, + torch.int8, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_token( - out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) @@ -182,9 +204,10 @@ def _set_ptq_weight( WeightOnlyInt4Linear, ) from torchao.quantization.qat.linear import ( - Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, + Int8DynActInt4WeightQATLinear, ) + n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = qat_linear.weight_fake_quantizer.config.group_size @@ -193,7 +216,13 @@ def _set_ptq_weight( fp32_weight = qat_linear.weight (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + fp32_weight, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) ptq_linear.weight = q_weight ptq_linear.scales = s @@ -201,28 +230,39 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, group_size, + qat_linear.weight, + n_bit, + group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to("cuda"), qat_linear.inner_k_tiles, + q_weight.to("cuda"), + qat_linear.inner_k_tiles, ) ptq_linear.weight 
= q_weight ptq_linear.scales_and_zeros = scales_and_zeros else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_linear(self): - from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear group_size = 128 torch.manual_seed(self.SEED) qat_linear = Int8DynActInt4WeightQATLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) ptq_linear = Int8DynActInt4WeightLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) # Force the weights to be the same @@ -236,10 +276,12 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer(self): - from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) @@ -268,9 +310,13 @@ def test_qat_8da4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -282,7 +328,9 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -341,7 +389,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. 
@@ -363,8 +413,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight)
 
         # Simulate training for both models
-        optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
-        optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        optimizer1 = torch.optim.SGD(
+            nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        optimizer2 = torch.optim.SGD(
+            qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
         loss_fn1 = torch.nn.CrossEntropyLoss()
         loss_fn2 = torch.nn.CrossEntropyLoss()
         example_inputs = nn_model.example_inputs()
@@ -382,9 +436,15 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         optimizer2.step()
 
         # After 1 training step, weights should match exactly
-        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
-        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
-        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
+        torch.testing.assert_close(
+            nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0
+        )
+        torch.testing.assert_close(
+            nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0
+        )
+        torch.testing.assert_close(
+            nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0
+        )
 
     def _test_qat_quantized_gradients(self, quantizer):
         """
@@ -394,7 +454,9 @@ def _test_qat_quantized_gradients(self, quantizer):
         torch.manual_seed(self.SEED)
         m = M()
         model = quantizer.prepare(m)
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
         loss_fn = torch.nn.CrossEntropyLoss()
 
         # Simulate training
@@ -426,13 +488,18 @@ def _test_qat_quantized_gradients(self, quantizer):
             optimizer.step()
             current_step += 1
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_qat_8da4w_quantizer_gradients(self):
         from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_qat_generic_fake_quantize(self):
         """
         Test that the generic fake quantize used in 8da4w QAT matches
@@ -443,7 +510,9 @@ def test_qat_generic_fake_quantize(self):
         py_input = torch.randn(16, 64).float().requires_grad_()
         py_s = torch.randn(16).float()
         py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
-        py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax)
+        py_out = torch.fake_quantize_per_channel_affine(
+            py_input, py_s, py_zp, 0, qmin, qmax
+        )
         py_out.sum().backward()
 
         ao_input = copy.deepcopy(py_input)
@@ -451,7 +520,9 @@ def test_qat_generic_fake_quantize(self):
         block_size = (1, ao_input.shape[-1])
         ao_s = copy.deepcopy(py_s)
         ao_zp = copy.deepcopy(py_zp)
-        ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax)
+        ao_out = _GenericFakeQuantize.apply(
+            ao_input, block_size, ao_s, ao_zp, qmin, qmax
+        )
ao_out.sum().backward() torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) @@ -485,10 +556,14 @@ def test_qat_4w_primitives(self): # PTQ (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(device), inner_k_tiles, + q_weight.to(device), + inner_k_tiles, ) ptq_out = torch.ops.aten._weight_int4pack_mm( x, q_weight, group_size, scales_and_zeros @@ -497,9 +572,12 @@ def test_qat_4w_primitives(self): # QAT block_size = (1, group_size) quant_min = 0 - quant_max = 2 ** n_bit - 1 + quant_max = 2**n_bit - 1 scales, zero_points = get_groupwise_affine_qparams( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) w_fq = fake_quantize_affine( weight, @@ -509,27 +587,37 @@ def test_qat_4w_primitives(self): torch.int32, quant_min, quant_max, - zero_point_domain = ZeroPointDomain.FLOAT, + zero_point_domain=ZeroPointDomain.FLOAT, ) qat_out = torch.nn.functional.linear(x, w_fq) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear + from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 device = torch.device("cuda") dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) ptq_linear = WeightOnlyInt4Linear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) # Force the weights to be the same @@ -543,17 +631,22 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer + quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): - from torchao.quantization.qat import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -563,10 +656,12 @@ def test_qat_4w_quantizer(self): m = M().to(device).to(dtype) m2 = copy.deepcopy(m) qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) ptq_quantizer = Int4WeightOnlyQuantizer( - groupsize=group_size, 
inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) qat_model = qat_quantizer.prepare(m) ptq_model = ptq_quantizer.quantize(m2) @@ -589,13 +684,16 @@ def test_qat_4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) class _MyQATQuantizer(TwoStepQuantizer): """ Dummy quantizer that attaches a certain value to each nn.Linear's `_temp_quantizer_values` attribute. """ + ATTR_NAME = "_temp_quantizer_values" def __init__(self, value: str): @@ -626,19 +724,24 @@ def test_composable_qat_quantizer(self): self.assertEqual(values_list, ["quantizer1", "quantizer2"]) composable_quantizer.convert(model) values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) - self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"]) + self.assertEqual( + values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_embedding(self): from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer + model = M2() x = model.example_inputs() - out = model(*x) + model(*x) quantizer = Int4WeightOnlyEmbeddingQATQuantizer() prepared = quantizer.prepare(model) - prepared_out = prepared(*x) + prepared(*x) converted = quantizer.convert(model) - converted_out = converted(*x) + converted(*x) def test_fake_quantize_config_granularity(self): """ @@ -685,7 +788,9 @@ def test_fake_quantize_config_granularity_error_cases(self): Test incorrect settings of `FakeQuantizeConfig`'s granularity. 
""" # no granularity provided - with self.assertRaisesRegex(ValueError, "`granularity` or `group_size` must be set"): + with self.assertRaisesRegex( + ValueError, "`granularity` or `group_size` must be set" + ): FakeQuantizeConfig(torch.int8) # group_size with conflicting granularity @@ -718,8 +823,12 @@ def test_fake_quantize_config_mapping_type(self): """ # symmetric symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") - symmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=True) - symmetric_config3 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC) + symmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=True + ) + symmetric_config3 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC + ) self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config2.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config3.mapping_type, MappingType.SYMMETRIC) @@ -728,8 +837,12 @@ def test_fake_quantize_config_mapping_type(self): self.assertTrue(symmetric_config3.is_symmetric) # asymmetric - asymmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - asymmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.ASYMMETRIC) + asymmetric_config1 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ) + asymmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.ASYMMETRIC + ) self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) self.assertEqual(asymmetric_config2.mapping_type, MappingType.ASYMMETRIC) self.assertFalse(asymmetric_config1.is_symmetric) @@ -743,11 +856,15 @@ def test_fake_quantize_config_mapping_type(self): # bad config1: both mapping_type and is_symmetric are set msg = "Cannot set both `mapping_type` and `is_symmetric`" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False + ) # bad config2: not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR + ) def test_fake_quantize_config_dtype(self): """ @@ -781,7 +898,9 @@ def test_fake_quantize_config_dtype(self): FakeQuantizeConfig(TorchAODType.INT7, "per_token") FakeQuantizeConfig(torch.int8, "per_token") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. @@ -792,7 +911,9 @@ def test_fake_quantized_linear_8da4w(self): 256, 688, bias=False, - activation_config=FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False), + activation_config=FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ), weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), ) @@ -801,7 +922,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. 
""" # activations - (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (s, zp) = _choose_qparams_per_token_asymmetric( + x, torch.float32, torch.int32 + ) (qmin, qmax) = _get_qmin_qmax(8) x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) @@ -809,7 +932,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.linear(x_fq, w_fq) # Compare linear values @@ -820,7 +945,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. @@ -849,7 +976,13 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) w_fq = _fake_quantize_per_channel_group( - weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + weight, + s, + zp, + qmin, + qmax, + group_size, + zero_point_domain=ZeroPointDomain.FLOAT, ) return F.linear(x, w_fq) @@ -860,50 +993,78 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: fq_out = fq_linear(x) baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_8da4w(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(isinstance(module[0], Int8DynActInt4WeightQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert 
isinstance(module[0], Int8DynActInt4WeightQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_int4(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(isinstance(module[0], Int4WeightOnlyQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int4WeightOnlyQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -926,7 +1087,9 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.embedding(x, w_fq) # Compare embedding values @@ -937,59 +1100,15 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. We will remove this test in the near future when we actually break BC. 
""" - from torchao.quantization.prototype.qat import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - ComposableQATQuantizer, - Int8DynActInt4WeightQATLinear, - Int4WeightOnlyEmbeddingQATQuantizer, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, - ) - from torchao.quantization.prototype.qat._module_swap_api import ( - disable_4w_fake_quant_module_swap, - enable_4w_fake_quant_module_swap, - disable_8da4w_fake_quant_module_swap, - enable_8da4w_fake_quant_module_swap, - Int4WeightOnlyQATQuantizerModuleSwap, - Int8DynActInt4WeightQATQuantizerModuleSwap, - ) - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - to_affine_fake_quantized, - ) - from torchao.quantization.prototype.qat.api import ( - ComposableQATQuantizer, - FakeQuantizeConfig, - ) - from torchao.quantization.prototype.qat.embedding import ( - FakeQuantizedEmbedding, - Int4WeightOnlyEmbeddingQATQuantizer, - Int4WeightOnlyEmbedding, - Int4WeightOnlyQATEmbedding, - ) - from torchao.quantization.prototype.qat.fake_quantizer import ( - FakeQuantizer, - ) - from torchao.quantization.prototype.qat.linear import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - FakeQuantizedLinear, - Int4WeightOnlyQATLinear, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATLinear, - Int8DynActInt4WeightQATQuantizer, - ) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 458cd07810..eb5f1337d1 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -6,81 +6,86 @@ # mypy: ignore-errors # This test takes a long time to run +import copy +import gc +import tempfile import unittest +from pathlib import Path + import torch -import os from torch.ao.quantization.quantize_pt2e import ( - prepare_pt2e, convert_pt2e, + prepare_pt2e, ) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase -import torchao +from torchao import quantize_ +from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import ( AffineQuantizedTensor, ) from torchao.quantization import ( LinearActivationQuantizedTensor, ) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) -from torchao.quantization.subclass import ( - Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight, -) -from torchao import quantize_ from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, Quantizer, TwoStepQuantizer, - int8_dynamic_activation_int4_weight, + _replace_with_custom_fn_if_matches_filter, int4_weight_only, - int8_weight_only, + int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, + int8_weight_only, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) +from torchao.quantization.subclass import ( + Int4WeightOnlyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + unwrap_tensor_subclass, ) -from pathlib import Path -from 
torchao._models.llama.tokenizer import get_tokenizer -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao.utils import unwrap_tensor_subclass -import copy -import tempfile -import gc -from torch.testing._internal.common_utils import TestCase -from torch.testing._internal import common_utils def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs).module() - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) return m + def capture_and_prepare(model, example_inputs): m = torch.export.export(model, example_inputs) - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this m(*example_inputs) return m -class XNNPackDynamicQuantizer(TwoStepQuantizer): +class XNNPackDynamicQuantizer(TwoStepQuantizer): def prepare(self, model: torch.nn.Module) -> torch.nn.Module: _replace_with_custom_fn_if_matches_filter( model, - lambda linear_mod: capture_and_prepare(linear_mod, (torch.randn(1, linear_mod.in_features))), + lambda linear_mod: capture_and_prepare( + linear_mod, (torch.randn(1, linear_mod.in_features)) + ), lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) return model @@ -93,11 +98,13 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: ) return model + class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: quantize_(model, int8_dynamic_activation_int8_weight()) return model + class ToyLinearModel(torch.nn.Module): def __init__(self, m=64, n=32, k=64): super().__init__() @@ -105,7 +112,11 @@ def __init__(self, m=64, n=32, k=64): self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float) def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) def forward(self, x): x = self.linear1(x) @@ -118,9 +129,11 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs The deprecated implementation for int8 dynamic quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _in_features_greater_than_16 - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import ( + _get_subclass_inserter, + _in_features_greater_than_16, + _is_linear, + ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight if filter_fn is None: @@ -129,37 +142,49 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs ) _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + model, + _get_subclass_inserter( + Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs + ), + filter_fn, ) + def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def 
_ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ The deprecated implementation for weight only quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs), + _get_subclass_inserter( + deprecated_tenosr_subclass, enable_parametrization=True, **kwargs + ), filter_fn, ) return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) + +_ref_change_linear_weights_to_int8_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) +) +_ref_change_linear_weights_to_int4_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) +) + class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() quantize_(m, int8_dynamic_activation_int8_weight()) - quantized = m(*example_inputs) + m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) # m = torch.compile(m, mode="max-autotune") @@ -182,7 +207,9 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!") + @unittest.skip( + "FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!" 
+ ) def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): quantizer = TorchCompileDynamicQuantizer() m = ToyLinearModel().eval() @@ -196,10 +223,8 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): - from torchao.quantization.quant_api import ( - change_linear_weights_to_int8_woqtensors, - ) m = ToyLinearModel().eval().cpu() + def api(model): quantize_(model, int8_weight_only()) unwrap_tensor_subclass(model) @@ -223,10 +248,12 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" + ) def test_8da4w_quantizer(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) m = ToyLinearModel().eval() @@ -242,8 +269,9 @@ def test_8da4w_quantizer(self): # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_8da4w_gptq_quantizer(self): - from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer from torchao._models._eval import InputRecorder, TransformerEvalWrapper + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer + # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -268,16 +296,20 @@ def test_8da4w_gptq_quantizer(self): input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int8DynActInt4WeightGPTQQuantizer( blocksize, @@ -287,7 +319,7 @@ def test_8da4w_gptq_quantizer(self): ) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -298,15 +330,17 @@ def test_8da4w_gptq_quantizer(self): 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.88, ( - f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.88 + ), f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower" + ) def test_8da4w_quantizer_eval(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from 
torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer precision = torch.bfloat16 device = "cpu" @@ -325,7 +359,7 @@ def test_8da4w_quantizer_eval(self): quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision) q_model = quantizer.quantize(model) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( q_model, tokenizer, q_model.config.block_size, @@ -335,14 +369,18 @@ def test_8da4w_quantizer_eval(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_int4_weight_only(self): + from torchao._models._eval import ( + MultiTensorInputRecorder, + TransformerEvalWrapper, + ) from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer - from torchao._models._eval import MultiTensorInputRecorder, TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -367,18 +405,21 @@ def test_gptq_quantizer_int4_weight_only(self): calibration_seq_length = 100 input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = MultiTensorInputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - device="cpu", - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() - + inputs = ( + MultiTensorInputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + device="cpu", + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int4WeightOnlyGPTQQuantizer( blocksize, @@ -398,14 +439,15 @@ def test_gptq_quantizer_int4_weight_only(self): ["wikitext"], None, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_quantizer_int4_weight_only(self): - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -435,13 +477,14 @@ def test_quantizer_int4_weight_only(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = 
"cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -456,7 +499,7 @@ def test_eval_wrapper(self): tokenizer_path, "Llama-2-7b-chat-hf", ) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -466,17 +509,20 @@ def test_eval_wrapper(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none']<7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper_llama3(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" - checkpoint_path = Path(".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth") + checkpoint_path = Path( + ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth" + ) model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) @@ -498,30 +544,43 @@ def test_eval_wrapper_llama3(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" # TODO: move to a separate test file @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @common_utils.parametrize("mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]) + @common_utils.parametrize( + "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] + ) def test_quantized_tensor_subclass_8da4w(self, mapping_type): group_size = 32 m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size, mapping_type=mapping_type)) + quantize_( + m, + int8_dynamic_activation_int4_weight( + group_size=group_size, mapping_type=mapping_type + ), + ) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer - quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=mapping_type) + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=group_size, mapping_type=mapping_type + ) m_copy = quantizer.quantize(m_copy) assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear) assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear) @@ -552,7 +611,6 @@ 
def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): @@ -568,13 +626,11 @@ def test_quantized_tensor_subclass_int8_wo(self): # reference _ref_change_linear_weights_to_int8_woqtensors(m_copy) - res = m(*example_inputs) ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") @@ -583,13 +639,19 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") + example_inputs = m.example_inputs( + batch_size=20, dtype=torch.bfloat16, device="cuda" + ) quantize_(m, int8_dynamic_activation_int8_weight()) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference _ref_change_linear_weights_to_int8_dqtensors(m_copy) @@ -601,6 +663,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): # workaround for export path from torchao.utils import unwrap_tensor_subclass + m_unwrapped = unwrap_tensor_subclass(m) m = torch.export.export(m_unwrapped, example_inputs).module() @@ -630,12 +693,10 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") quantize_(m, int8_weight_only()) @@ -654,7 +715,6 @@ def test_int4wo_quantized_model_to_device(self): devices = ["cuda", "cuda:0"] for device in devices: m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) quantize_(m, int4_weight_only()) @@ -678,7 +738,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self): f.seek(0) state_dict = torch.load(f.name, map_location="cpu", mmap=True) - with torch.device('meta'): + with torch.device("meta"): m_copy = ToyLinearModel().eval() m_copy.load_state_dict(state_dict, assign=True) @@ -710,12 +770,13 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) -class TestMultiTensorFlow(TestCase): +class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from 
torchao.quantization.GPTQ_MT import MultiTensor
+
         tensor1 = torch.randn(3, 3)
         tensor2 = torch.randn(3, 3)
         mt = MultiTensor(tensor1)
@@ -728,6 +789,7 @@ def test_multitensor_add_tensors(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
+
         tensor1 = torch.randn(3, 3)
         mt = MultiTensor(tensor1)
         mt.pad_to_length(3)
@@ -739,14 +801,13 @@ def test_multitensor_pad_unpad(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
+
         tensor1 = torch.ones(3, 3)
         mt = MultiTensor(tensor1)
         mt += 1  # In-place addition
         self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2)))
 
 
-
-
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 78556772d1..a3fef29fea 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -7,25 +7,27 @@
 # mypy: ignore-errors
 # This test takes a long time to run
 import unittest
+
 import torch
+
+from torchao.dtypes.utils import is_device
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+    choose_qparams_affine,
+    dequantize_affine,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
     quantize_affine,
-    dequantize_affine,
-    choose_qparams_affine,
-    MappingType,
-    ZeroPointDomain,
 )
+
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     get_group_qparams_symmetric,
-    get_groupwise_affine_qparams,
-    groupwise_affine_quantize_tensor_from_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
+    groupwise_affine_quantize_tensor_from_qparams,
     quantize_activation_per_token_absmax,
 )
-
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
@@ -33,11 +35,11 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
 )
-from torchao.dtypes.utils import is_device
 
 _SEED = 1234
 torch.manual_seed(_SEED)
 
+
 # Helper function to run a function twice
 # and verify that the result is the same.
 # Adds some verification to avoid side effects.
@@ -48,9 +50,12 @@ def check_idempotent(self, fn, *args, **kwargs):
     output0 = fn(*args, **kwargs)
     assert torch.is_tensor(output0)
     output1 = fn(*args, **kwargs)
-    self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
+    self.assertTrue(
+        torch.equal(output0, output1), f"Expected given function {fn} to be idempotent."
+    )
     return output1
 
+
 # Legacy tinygemm ops
 def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
     if groupsize > w.shape[-1]:
@@ -71,6 +76,7 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1
         dtype=dtype
     ).reshape(w.shape[0], -1)
 
+
 def _groupwise_affine_quantize_tensor_from_qparams(
     w,
     scales,
@@ -108,6 +114,7 @@ def _groupwise_affine_quantize_tensor_from_qparams(
 
     return w_int4x8
 
+
 def _groupwise_affine_dequantize_tensor_from_qparams(
     w_int4x8,
     scales,
@@ -138,7 +145,9 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower"
+    )
     def test_get_group_qparams_symmetric(self):
         """
         Test that `get_group_qparams_symmetric` produces the exact same scales as
@@ -147,7 +156,6 @@ def test_get_group_qparams_symmetric(self):
         n_bit = 4
         qmin = -(2 ** (n_bit - 1))
         qmax = 2 ** (n_bit - 1) - 1
-        eps = torch.finfo(torch.float32).eps
         groupsize = 256
         torch.manual_seed(self.SEED)
         weight = torch.randn(100, 256).to(torch.float16)
@@ -160,14 +168,16 @@ def test_get_group_qparams_symmetric(self):
             quant_max=qmax,
             # This is needed to ensure `min_val` and `max_val` are fp16,
             # otherwise they default to fp32 and the qparams will be slightly off
-            factory_kwargs={"dtype": torch.float16}
+            factory_kwargs={"dtype": torch.float16},
         )
         obs(weight)
         (scale_obs, _) = obs.calculate_qparams()
         scale_obs = scale_obs.reshape(weight.shape[0], -1)
 
         # assert that scales are identical
-        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
+        (scale_ao, _) = get_group_qparams_symmetric(
+            weight, n_bit, groupsize, precision=torch.float16
+        )
         torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)
 
     def test_choose_qparams_group_sym(self):
@@ -180,9 +190,19 @@ def test_choose_qparams_group_sym(self):
         block_size = (1, 2)
         eps = torch.finfo(torch.float32).eps
         precision = torch.float32
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+        scale, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            eps=eps,
+            scale_dtype=precision,
+            zero_point_dtype=precision,
+        )
 
-        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type)
+        scale_ref, zp_ref = get_group_qparams_symmetric(
+            input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type
+        )
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
 
@@ -197,13 +217,26 @@ def test_choose_qparams_group_sym_no_clipping_err(self):
         block_size = (1, 2)
         eps = torch.finfo(torch.float32).eps
         precision = torch.float32
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+        scale, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            eps=eps,
+            scale_dtype=precision,
+            zero_point_dtype=precision,
+        )
 
-        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type)
+        scale_ref, zp_ref = get_group_qparams_symmetric(
+            input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type
+        )
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower"
+    )
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_token_asym(self):
         input = torch.randn(10, 10)
@@ -211,11 +244,29 @@ def test_choose_qparams_token_asym(self):
         dtype = torch.int8
         block_size = (1, 10)
         if TORCH_VERSION_AT_LEAST_2_6:
-            scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64)
+            scale, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                eps=torch.finfo(torch.float32).eps,
+                scale_dtype=torch.float64,
+                zero_point_dtype=torch.int64,
+            )
         else:
-            scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+            scale, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                eps=torch.finfo(torch.float32).eps,
+            )
 
-        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
+        scale_ref, zp_ref = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
+                input, dtype
+            )
+        )
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()
 
@@ -229,12 +280,15 @@ def test_choose_qparams_tensor_asym(self):
         dtype = torch.int8
         block_size = (10, 10)
         eps = torch.finfo(torch.float32).eps
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
-
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=eps
+        )
         quant_min = -128
         quant_max = 127
-        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype)
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(
+            input, quant_min, quant_max, eps, dtype
+        )
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()
 
@@ -248,18 +302,24 @@ def test_choose_qparams_tensor_sym(self):
         dtype = torch.int8
         block_size = (10, 10)
         eps = torch.finfo(torch.float32).eps
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=eps
+        )
 
         quant_min = -128
         quant_max = 127
-        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype)
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(
+            input, quant_min, quant_max, eps, dtype
+        )
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()
 
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_quantize_activation_per_token_abs_max(self):
         input = torch.randn(10, 10)
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
@@ -272,21 +332,35 @@ def test_quantize_activation_per_token_abs_max(self):
         eps = 1e-5
         quant_min = -127
         quant_max = 127
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+        scale, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps=eps,
+            scale_dtype=torch.float,
+        )
 
-        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        quantized = quantize_affine(
+            input, block_size, scale, zero_point, dtype, quant_min, quant_max
+        )
         self.assertTrue(torch.equal(quantized, quantized_ref))
         self.assertTrue(torch.equal(scale, scale_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_quantize_activation_per_token_abs_max_zero_input(self):
         input = torch.zeros(10, 10)
         # make sure it still works
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
 
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_quantize_activation_per_token_abs_max_dtype(self):
         input = torch.zeros(10, 10, dtype=torch.bfloat16)
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
@@ -300,18 +374,30 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
         self.assertTrue(scale_ref.dtype, torch.float32)
 
-
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_group_sym(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.SYMMETRIC
         dtype = torch.int8
         block_size = (1, 2)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
+        )
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
-        dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+        dequantized = check_idempotent(
+            self,
+            dequantize_affine,
+            quantized,
+            block_size,
+            scale,
+            zero_point,
+            dtype,
+            output_dtype=torch.float32,
+        )
 
         group_size = 2
         quant_min = -128
@@ -320,23 +406,43 @@ def test_quantize_dequantize_group_sym(self):
             input, scale, zero_point, quant_min, quant_max, torch.int8, group_size
         )
         dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group(
-            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32
+            quantized_ref,
+            scale,
+            zero_point,
+            quant_min,
+            quant_max,
+            torch.int8,
+            group_size,
+            output_dtype=torch.float32,
         )
 
         self.assertTrue(torch.equal(quantized, quantized_ref))
         self.assertTrue(torch.equal(dequantized, dequantized_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
    @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_channel_asym(self):
         input = torch.randn(10, 10)
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
         block_size = (10, 1)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
+        )
         output_dtype = torch.float32
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
-        dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
+        dequantized = check_idempotent(
+            self,
+            dequantize_affine,
+            quantized,
+            block_size,
+            scale,
+            zero_point,
+            dtype,
+            output_dtype=output_dtype,
+        )
 
         axis = 1
         quant_min = -128
@@ -345,12 +451,21 @@ def test_quantize_dequantize_channel_asym(self):
             input, scale, zero_point, axis, quant_min, quant_max, torch.int8
         )
         dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
-            quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype
+            quantized_ref,
+            scale,
+            zero_point,
+            axis,
+            quant_min,
+            quant_max,
+            torch.int8,
+            out_dtype=output_dtype,
         )
 
         self.assertTrue(torch.equal(quantized, quantized_ref))
         self.assertTrue(torch.equal(dequantized, dequantized_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_tensor_asym(self):
         input = torch.randn(10, 10)
@@ -358,32 +473,61 @@ def test_quantize_dequantize_tensor_asym(self):
         dtype = torch.int8
         block_size = (10, 10)
         output_dtype = torch.float32
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
+        )
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
-        dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
+        dequantized = check_idempotent(
+            self,
+            dequantize_affine,
+            quantized,
+            block_size,
+            scale,
+            zero_point,
+            dtype,
+            output_dtype=output_dtype,
+        )
 
-        axis = 1
         quant_min = -128
         quant_max = 127
         quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
             input, scale, zero_point, quant_min, quant_max, torch.int8
         )
         dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
-            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype
+            quantized_ref,
+            scale,
+            zero_point,
+            quant_min,
+            quant_max,
+            torch.int8,
+            out_dtype=output_dtype,
        )
 
         self.assertTrue(torch.equal(quantized, quantized_ref))
         self.assertTrue(torch.equal(dequantized, dequantized_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_channel_asym_4d(self):
         input = torch.randn(3, 3, 10, 10)
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
         block_size = (3, 3, 1, 10)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
+        )
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
-        dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+        dequantized = check_idempotent(
+            self,
+            dequantize_affine,
+            quantized,
+            block_size,
+            scale,
+            zero_point,
+            dtype,
+            output_dtype=torch.float32,
+        )
 
         axis = 2
         quant_min = -128
@@ -392,20 +536,40 @@ def test_quantize_dequantize_channel_asym_4d(self):
             input, scale, zero_point, axis, quant_min, quant_max, torch.int8
         )
         dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
-            quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32
+            quantized_ref,
+            scale,
+            zero_point,
+            axis,
+            quant_min,
+            quant_max,
+            torch.int8,
+            out_dtype=torch.float32,
         )
 
         self.assertTrue(torch.equal(quantized, quantized_ref))
         self.assertTrue(torch.equal(dequantized, dequantized_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower"
+    )
     def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
         input = torch.randn(3, 3, 10, 10)
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
         block_size = (3, 3, 2, 2)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
+        )
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
-        dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+        dequantized = check_idempotent(
+            self,
+            dequantize_affine,
+            quantized,
+            block_size,
+            scale,
+            zero_point,
+            dtype,
+            output_dtype=torch.float32,
+        )
         # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
         torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)
 
@@ -414,11 +578,15 @@ def test_choose_qparams_tensor_asym_eps(self):
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
         block_size = (10, 10)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype
+        )
         eps = torch.finfo(torch.float32).eps
         self.assertEqual(scale, eps)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available")
+    @unittest.skipIf(
+        not torch.cuda.is_available(), "skipping when cuda is not available"
+    )
     def test_get_group_qparams_symmetric_memory(self):
         """Check the memory usage of the op"""
         weight = torch.randn(1024, 1024).to(device="cuda")
@@ -430,18 +598,20 @@ def test_get_group_qparams_symmetric_memory(self):
         self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)
 
     def test_raises(self):
-        """Make sure some errors are raised when user requested an unsupported type of quantization
-        """
+        """Make sure some errors are raised when user requested an unsupported type of quantization"""
         input = torch.randn(10, 10)
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
         block_size = (10, 10)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
-
+        scale, zero_point = choose_qparams_affine(
+            input, mapping_type, block_size, dtype
+        )
         # make sure we can't quantize int32 tensors:
         with self.assertRaisesRegex(AssertionError, "Unsupported input dtype:"):
-            _ = quantize_affine(input.to(torch.int32), block_size, scale, zero_point, dtype)
+            _ = quantize_affine(
+                input.to(torch.int32), block_size, scale, zero_point, dtype
+            )
 
         # block_size and scale/zero_point shape mismatch
         block_size = (1, 1)
@@ -460,7 +630,10 @@ def test_not_preserve_zero_not_supported(self):
         eps = 1e-6
         scale_dtype = torch.bfloat16
         zero_point_dtype = torch.bfloat16
-        with self.assertRaisesRegex(ValueError, "preserve_zero == False is not supported for symmetric quantization"):
+        with self.assertRaisesRegex(
+            ValueError,
+            "preserve_zero == False is not supported for symmetric quantization",
+        ):
             choose_qparams_affine(
                 input,
                 mapping_type,
@@ -474,11 +647,12 @@ def test_not_preserve_zero_not_supported(self):
                 preserve_zero=False,
             )
 
-
     def test_get_groupwise_affine_qparams(self):
         input = torch.randn(10, 256)
         n_bit = 4
-        scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
+        scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
+            input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16
+        )
 
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
@@ -488,20 +662,19 @@ def test_get_groupwise_affine_qparams(self):
         eps = 1e-6
         scale_dtype = torch.bfloat16
         zero_point_dtype = torch.bfloat16
-        scale, zero_point = \
-            choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=False,
-                zero_point_domain=ZeroPointDomain.FLOAT,
-            )
+        scale, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=False,
+            zero_point_domain=ZeroPointDomain.FLOAT,
+        )
 
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zero_point_ref))
@@ -513,8 +686,12 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self):
         n_bit = 4
         groupsize = 128
 
-        w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
-        w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+        w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
+            input, scales, zeros, n_bit, groupsize
+        )
+        w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
+            input, scales, zeros, n_bit, groupsize
+        )
 
         self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
 
@@ -529,14 +706,22 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         input_tmp = input
         if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
             input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
+                input_tmp, scales, zeros, n_bit, groupsize
+            )
         else:
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
-        w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
+                input, scales, zeros, n_bit, groupsize
+            )
+        w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
+            input, scales, zeros, n_bit, groupsize
+        )
 
         self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_fake_quantize_affine(self):
         input = torch.randn(10, 10)
 
@@ -548,14 +733,31 @@ def test_fake_quantize_affine(self):
         eps = 1e-5
         quant_min = -127
         quant_max = 127
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+        scale, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps=eps,
+            scale_dtype=torch.float,
+        )
 
-        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
-        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
-        fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        quantized = quantize_affine(
+            input, block_size, scale, zero_point, dtype, quant_min, quant_max
+        )
+        dequantized = dequantize_affine(
+            quantized, block_size, scale, zero_point, dtype, quant_min, quant_max
+        )
+        fake_quantized = fake_quantize_affine(
+            input, block_size, scale, zero_point, dtype, quant_min, quant_max
+        )
         torch.testing.assert_close(dequantized, fake_quantized)
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
     def test_fake_quantize_affine_cachemask(self):
         input = torch.randn(10, 10)
 
@@ -567,16 +769,36 @@ def test_fake_quantize_affine_cachemask(self):
         eps = 1e-5
         quant_min = -127
         quant_max = 127
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+        scale, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps=eps,
+            scale_dtype=torch.float,
+        )
 
-        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
-        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        quantized = quantize_affine(
+            input, block_size, scale, zero_point, dtype, quant_min, quant_max
+        )
+        dequantized = dequantize_affine(
+            quantized, block_size, scale, zero_point, dtype, quant_min, quant_max
+        )
         (fake_quantized, mask) = fake_quantize_affine_cachemask(
-            input, block_size, scale, zero_point, dtype, quant_min, quant_max,
+            input,
+            block_size,
+            scale,
+            zero_point,
+            dtype,
+            quant_min,
+            quant_max,
         )
         expected_mask = torch.full(input.shape, True)
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+
 if __name__ == "__main__":
     unittest.main()
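
For reviewers who want to sanity-check the reformatted tests locally, here is a minimal sketch (not part of the patch) of the affine quantize/dequantize round trip that test_quant_primitives.py exercises. It only uses calls, parameters, and tolerances that already appear in the diff above; treat it as an illustrative example rather than a reference implementation.

# Minimal sketch of the affine round trip exercised by test_quant_primitives.py.
# Assumes a torchao installation exposing the same quant_primitives API the tests import.
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 2)  # group-wise quantization, group size 2 along the last dim

# Pick scale/zero_point for each block, then quantize and dequantize.
scale, zero_point = choose_qparams_affine(
    input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(
    quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32
)

# The round trip should stay close to the original tensor, mirroring the
# tolerance used in the multi-dim-reduction test above.
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)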