Commit 88d604f

Expose FakeQuantizeConfigs in QAT quantizers (#1214)
Summary: This commit exposes the activation and weight FakeQuantizeConfigs in the existing QAT quantizers. These are helpful for implementing advanced functionality based on the quantization schemes represented by these quantizers, such as composing QAT + LoRA (see the usage sketch below).

Test Plan: python test/quantization/test_qat.py
1 parent 1fbf788 commit 88d604f
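
As a point of reference, here is a minimal usage sketch (not part of this commit) of the new getters; it assumes the import path of the file changed below and the quantizer's default constructor arguments:

# Hypothetical usage sketch: query the fake-quantize configs exposed by this commit.
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

quantizer = Int8DynActInt4WeightQATQuantizer()
activation_config = quantizer.get_activation_fake_quantize_config()  # int8, per-token, asymmetric, dynamic
weight_config = quantizer.get_weight_fake_quantize_config()          # int4, per-group, symmetric
print(activation_config)
print(weight_config)

A downstream caller (for example, a QAT + LoRA integration) can read these configs instead of reconstructing the quantization scheme by hand.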

File tree: 1 file changed (+74, -27 lines)


torchao/quantization/qat/linear.py

Lines changed: 74 additions & 27 deletions
@@ -107,12 +107,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.linear(x, w)
 
 
+class _LegacyQATQuantizer(TwoStepQuantizer):
+    """
+    Base class for sharing common methods across legacy QAT quantizers.
+    """
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+
 # =========================================================
 # |   Linear int8 dynamic activations + int4 weight QAT  |
 # =========================================================
 
 
-class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
+class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have int8
     dynamic per token fake quantized activations and int4 fake quantized
@@ -189,6 +200,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_8da4w(child)
 
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_activation_config(self.scales_precision)
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
     """
@@ -211,22 +228,8 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = FakeQuantizeConfig(
-            dtype=torch.int8,
-            granularity="per_token",
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
-        weight_config = FakeQuantizeConfig(
-            dtype=TorchAODType.INT4,
-            group_size=groupsize,
-            is_symmetric=True,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
+        activation_config = _get_8da4w_activation_config(scales_precision)
+        weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -261,12 +264,43 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.disable_fake_quant()
 
 
+def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
+    """
+    Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.int8,
+        granularity="per_token",
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
+def _get_8da4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=TorchAODType.INT4,
+        group_size=group_size,
+        is_symmetric=True,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
 # ===================================
 # |   Linear int4 weight-only QAT  |
 # ===================================
 
 
-class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
+class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have
     int4 fake quantized grouped per channel weights.
@@ -348,6 +382,9 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_4w(child)
 
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
     """
@@ -376,15 +413,7 @@ def __init__(
         if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
             raise ValueError("Padding for QAT 4w is not supported yet")
         self.inner_k_tiles = inner_k_tiles
-        weight_config = FakeQuantizeConfig(
-            dtype=torch.uint4,
-            group_size=groupsize,
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
+        weight_config = _get_4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -417,3 +446,21 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     if isinstance(mod, Int4WeightOnlyQATLinear):
         mod.disable_fake_quant()
+
+
+def _get_4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.uint4,
+        group_size=group_size,
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+        zero_point_domain=ZeroPointDomain.FLOAT,
+    )
