Skip to content

Commit bc2aaaf

Browse files
authored
Add support for groupwise quantization for int8 weight only quantization (#1121)
Summary: This is to support deprecating torchchat int8 weight only quantization: https://github.com/pytorch/torchchat/blob/ecc628da7c32c486742d92a751ed045b2a2194be/torchchat/utils/quantize.py#L582 Test Plan: python test/integration/test_integration.py -k test_weight_only_groupwise_quant Reviewers: Subscribers: Tasks: Tags:
1 parent 3296749 commit bc2aaaf

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

test/integration/test_integration.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ def _int8wo_api(mod):
109109
else:
110110
change_linear_weights_to_int8_woqtensors(mod)
111111

112+
def _int8wo_groupwise_api(mod):
113+
group_size = 32
114+
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
115+
112116
def _int8da_int8w_api(mod):
113117
if TORCH_VERSION_AT_LEAST_2_4:
114118
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
@@ -927,6 +931,16 @@ def test_weight_only_quant(self):
927931
sqnr = compute_error(y_ref, y_wo)
928932
self.assertGreater(sqnr, 43.0)
929933

934+
def test_weight_only_groupwise_quant(self):
935+
for x_shape in [[128, 512]]:
936+
x = torch.randn(*x_shape)
937+
m = nn.Sequential(nn.Linear(512, 32))
938+
y_ref = m(x)
939+
_int8wo_groupwise_api(m)
940+
y_wo = m(x)
941+
sqnr = compute_error(y_ref, y_wo)
942+
self.assertGreater(sqnr, 45.0)
943+
930944
@parameterized.expand(COMMON_DEVICE_DTYPE)
931945
@torch.no_grad()
932946
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

torchao/quantization/quant_api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,19 +563,22 @@ def apply_int4_weight_only_quant(weight):
563563
return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
564564

565565

566-
def int8_weight_only():
566+
def int8_weight_only(group_size=None):
567567
"""
568568
Applies int8 weight-only symmetric per-channel quantization to linear layers.
569569
"""
570-
def apply_int8wo_quant(weight):
570+
def apply_int8wo_quant(weight, group_size=None):
571571
mapping_type = MappingType.SYMMETRIC
572572
target_dtype = torch.int8
573573
eps = torch.finfo(torch.float32).eps
574574
zero_point_dtype = torch.int64
575+
if group_size is None:
576+
group_size = weight.shape[1]
577+
575578
block_size = (1, weight.shape[1])
576579
return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
577580

578-
return _get_linear_subclass_inserter(apply_int8wo_quant)
581+
return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size)
579582

580583
def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
581584
mapping_type = MappingType.SYMMETRIC

0 commit comments

Comments
 (0)