@@ -107,12 +107,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.linear(x, w)
 
 
+class _LegacyQATQuantizer(TwoStepQuantizer):
+    """
+    Base class for sharing common methods across legacy QAT quantizers.
+    """
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return None
+
+
 # =========================================================
 # | Linear int8 dynamic activations + int4 weight QAT |
 # =========================================================
 
 
-class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
+class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have int8
     dynamic per token fake quantized activations and int4 fake quantized
@@ -189,6 +200,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_8da4w(child)
 
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_activation_config(self.scales_precision)
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
     """
@@ -211,22 +228,8 @@ def __init__(
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
     ) -> None:
-        activation_config = FakeQuantizeConfig(
-            dtype=torch.int8,
-            granularity="per_token",
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
-        weight_config = FakeQuantizeConfig(
-            dtype=TorchAODType.INT4,
-            group_size=groupsize,
-            is_symmetric=True,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-        )
+        activation_config = _get_8da4w_activation_config(scales_precision)
+        weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -261,12 +264,43 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.disable_fake_quant()
 
 
+def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
+    """
+    Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=torch.int8,
+        granularity="per_token",
+        is_symmetric=False,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
+def _get_8da4w_weight_config(
+    group_size: int,
+    qparams_precision: torch.dtype,
+) -> FakeQuantizeConfig:
+    """
+    Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
+    """
+    return FakeQuantizeConfig(
+        dtype=TorchAODType.INT4,
+        group_size=group_size,
+        is_symmetric=True,
+        is_dynamic=True,
+        scale_precision=qparams_precision,
+        zero_point_precision=qparams_precision,
+    )
+
+
 # ===================================
 # | Linear int4 weight-only QAT |
 # ===================================
 
 
-class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
+class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have
     int4 fake quantized grouped per channel weights.
@@ -348,6 +382,9 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
             else:
                 self._convert_qat_linear_4w(child)
 
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return _get_4w_weight_config(self.groupsize, self.scales_precision)
+
 
 class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
     """
@@ -376,15 +413,7 @@ def __init__(
         if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
             raise ValueError("Padding for QAT 4w is not supported yet")
         self.inner_k_tiles = inner_k_tiles
-        weight_config = FakeQuantizeConfig(
-            dtype=torch.uint4,
-            group_size=groupsize,
-            is_symmetric=False,
-            is_dynamic=True,
-            scale_precision=scales_precision,
-            zero_point_precision=scales_precision,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
+        weight_config = _get_4w_weight_config(groupsize, scales_precision)
         super().__init__(
             in_features,
             out_features,
@@ -417,3 +446,21 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
417446 """
418447 if isinstance (mod , Int4WeightOnlyQATLinear ):
419448 mod .disable_fake_quant ()
449+
450+
451+ def _get_4w_weight_config (
452+ group_size : int ,
453+ qparams_precision : torch .dtype ,
454+ ) -> FakeQuantizeConfig :
455+ """
456+ Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
457+ """
458+ return FakeQuantizeConfig (
459+ dtype = torch .uint4 ,
460+ group_size = group_size ,
461+ is_symmetric = False ,
462+ is_dynamic = True ,
463+ scale_precision = qparams_precision ,
464+ zero_point_precision = qparams_precision ,
465+ zero_point_domain = ZeroPointDomain .FLOAT ,
466+ )
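
A minimal usage sketch for the accessors added above (not part of this diff; the import path and constructor keyword arguments are assumptions inferred from the attributes these methods read):

    import torch
    # Assumed import path for the quantizer class touched in this diff.
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

    # Construct the quantizer; the groupsize / scales_precision kwargs are assumed
    # to populate the self.groupsize / self.scales_precision attributes used above.
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32, scales_precision=torch.float32)

    # The new accessors expose the FakeQuantizeConfig objects the quantizer builds internally.
    act_config = quantizer.get_activation_fake_quantize_config()   # int8, per-token, asymmetric, dynamic
    weight_config = quantizer.get_weight_fake_quantize_config()    # int4, group_size=32, symmetric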