
Commit 200589b

andrewor14 authored and pytorchmergebot committed
Add new QAT API through quantize_ (#1415)
**Summary:** This commit adds a new QAT API that can be used with the existing `quantize_`. This is an alternative to the old QAT `*Quantizer` APIs, which are much less flexible. The new API can be used as follows:

```
from torchao import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)

my_model = ...
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_quantize_api
```

Pull Request resolved: #1415
Approved by: https://github.com/jerryzh168
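Editor's note, not part of the commit message: the new test also composes the same transform with a `filter_fn` to apply weight-only fake quantization to embedding layers only. A minimal sketch of that pattern, assuming `my_model` is defined as above and this patch is installed:

```python
# Sketch mirroring test_quantize_api: weight-only QAT for embeddings via filter_fn.
# Assumes `my_model` is an already-constructed model containing nn.Embedding layers.
import torch
from torchao import quantize_
from torchao.quantization.qat import FakeQuantizeConfig, intx_quantization_aware_training

weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```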
1 parent 46b8796 commit 200589b

File tree: 7 files changed, +223 -4 lines changed


test/quantization/test_qat.py

Lines changed: 106 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
+from torchao import quantize_
 from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4
 from torchao.quantization.granularity import (
     PerAxis,
@@ -24,6 +25,7 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
@@ -104,6 +106,25 @@ def forward(self, x):
         return self.embedding(x)
 
 
+class M3(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.embedding = torch.nn.Embedding(10, 512)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
+        self.relu = torch.nn.ReLU()
+
+    def example_inputs(self):
+        return (torch.randint(1, 10, (1, 512)),)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.relu(x)
+        return x
+
+
 class TestQAT(unittest.TestCase):
     SEED = 123
 
@@ -1156,6 +1177,91 @@ def test_qat_prototype_bc(self):
             Int8DynActInt4WeightQATQuantizer,
         )
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api(self):
+        """
+        Test that the following:
+
+            quantize_(model, intx_quantization_aware_training(...))
+
+        can produce the same results as `ComposableQATQuantizer`.
+        """
+        from torchao.quantization.qat import (
+            ComposableQATQuantizer,
+            Int4WeightOnlyEmbeddingQATQuantizer,
+            Int8DynActInt4WeightQATQuantizer,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M3()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline quantizer
+        baseline_quantizer = ComposableQATQuantizer(
+            [
+                Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
+                Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
+            ]
+        )
+        baseline_model = baseline_quantizer.prepare(baseline_model)
+
+        # quantize_ API
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=False,
+        )
+        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        quantize_(
+            m,
+            intx_quantization_aware_training(activation_config, weight_config),
+        )
+        quantize_(
+            m,
+            intx_quantization_aware_training(weight_config=weight_config),
+            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+        )
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_errors(self):
+        """
+        Test that we throw exceptions with helpful error messages if `quantize_`
+        runs into unexpected configurations.
+        """
+        my_config = FakeQuantizeConfig(torch.int8, group_size=32)
+        m = M3()
+
+        # Embedding currently only supports weight-only quantization
+        with self.assertRaisesRegex(
+            ValueError, "Activation fake quantization is not supported for embedding"
+        ):
+            quantize_(
+                m,
+                intx_quantization_aware_training(my_config, my_config),
+                lambda m, _: isinstance(m, torch.nn.Embedding),
+            )
+
+        # Only linear and embedding are supported currently
+        with self.assertRaisesRegex(ValueError, "does not have QAT support"):
+            quantize_(
+                m,
+                intx_quantization_aware_training(my_config, my_config),
+                lambda m, _: isinstance(m, torch.nn.ReLU),
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -53,6 +53,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
     uintx_weight_only,
@@ -103,6 +104,7 @@
     "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
+    "intx_quantization_aware_training",
     "float8_weight_only",
     "float8_dynamic_activation_float8_weight",
     "float8_static_activation_float8_weight",
```

torchao/quantization/qat/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
 from .api import (
     ComposableQATQuantizer,
+    FakeQuantizeConfig,
+    intx_quantization_aware_training,
 )
 from .embedding import (
     Int4WeightOnlyEmbeddingQATQuantizer,
@@ -11,7 +13,9 @@
 
 __all__ = [
     "ComposableQATQuantizer",
+    "FakeQuantizeConfig",
     "Int4WeightOnlyQATQuantizer",
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "intx_quantization_aware_training",
 ]
```

torchao/quantization/qat/api.py

Lines changed: 57 additions & 1 deletion
```diff
@@ -46,7 +46,7 @@ class FakeQuantizeConfig:
         scale_precision: scale dtype (default torch.fp32)
         zero_point_precision: zero point dtype (default torch.int32)
         zero_point_domain: whether zero point is in integer (default) or float domain
-        is_dynamic: whether to use dynamic (defualt) or static scale and zero points
+        is_dynamic: whether to use dynamic (default) or static scale and zero points
         range_learning: whether to learn scale and zero points during training (coming soon)
 
         kwargs (optional):
@@ -239,6 +239,62 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)
 
 
+def intx_quantization_aware_training(
+    activation_config: Optional[FakeQuantizeConfig] = None,
+    weight_config: Optional[FakeQuantizeConfig] = None,
+) -> torch.nn.Module:
+    """
+    Return a function that applies fake quantization to a `torch.nn.Module`,
+    to be used with :func:`~torchao.quantization.quant_api.quantize_`.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.quantization.qat import FakeQuantizeConfig
+        activation_config = FakeQuantizeConfig(
+            torch.int8, "per_token", is_symmetric=False,
+        )
+        weight_config = FakeQuantizeConfig(
+            torch.int4, group_size=32, is_symmetric=True,
+        )
+        quantize_(
+            model,
+            intx_quantization_aware_training(activation_config, weight_config),
+        )
+
+    Note: If the returned function is applied on a module that is not
+    `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
+    `torch.nn.Embedding` with an activation config, then we will raise
+    ValueError as these are not supported.
+    """
+
+    def _insert_fake_quantize(mod: torch.nn.Module):
+        """
+        Swap the given module with its corresponding fake quantized version.
+        """
+        from .embedding import FakeQuantizedEmbedding
+        from .linear import FakeQuantizedLinear
+
+        if isinstance(mod, torch.nn.Linear):
+            return FakeQuantizedLinear.from_linear(
+                mod,
+                activation_config,
+                weight_config,
+            )
+        elif isinstance(mod, torch.nn.Embedding):
+            if activation_config is not None:
+                raise ValueError(
+                    "Activation fake quantization is not supported for embedding"
+                )
+            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+        else:
+            raise ValueError(
+                "Module of type '%s' does not have QAT support" % type(mod)
+            )
+
+    return _insert_fake_quantize
+
+
 class ComposableQATQuantizer(TwoStepQuantizer):
     """
     Composable quantizer that users can use to apply multiple QAT quantizers easily.
```
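Editor's note: because `intx_quantization_aware_training` returns the per-module `_insert_fake_quantize` transform, it can also be exercised on a single layer in isolation. A minimal sketch, assuming torchao with this patch is installed:

```python
# Minimal sketch: the returned transform swaps one module for its fake-quantized version.
import torch
from torchao.quantization.qat import FakeQuantizeConfig, intx_quantization_aware_training

transform = intx_quantization_aware_training(
    activation_config=FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False),
    weight_config=FakeQuantizeConfig(torch.int4, group_size=32),
)
fq_linear = transform(torch.nn.Linear(128, 64, bias=False))
print(type(fq_linear).__name__)  # expected: FakeQuantizedLinear

# Unsupported module types raise ValueError("... does not have QAT support"):
# transform(torch.nn.ReLU())  # would raise
```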

torchao/quantization/qat/embedding.py

Lines changed: 28 additions & 3 deletions
```diff
@@ -9,9 +9,6 @@
 import torch
 import torch.nn.functional as F
 
-from torchao.quantization.quant_api import (
-    _replace_with_custom_fn_if_matches_filter,
-)
 from torchao.quantization.quant_primitives import TorchAODType
 from torchao.quantization.unified import TwoStepQuantizer
 from torchao.quantization.utils import get_group_qparams_symmetric
@@ -85,6 +82,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.sparse,
         )
 
+    @classmethod
+    def from_embedding(
+        cls,
+        mod: torch.nn.Embedding,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+    ):
+        new_embedding = FakeQuantizedEmbedding(
+            mod.num_embeddings,
+            mod.embedding_dim,
+            mod.padding_idx,
+            mod.max_norm,
+            mod.norm_type,
+            mod.scale_grad_by_freq,
+            mod.sparse,
+            weight_config=weight_config,
+            device=mod.weight.device,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_embedding.weight = mod.weight
+        return new_embedding
+
 
 # ======================================
 # |   Embedding int4 weight-only QAT   |
@@ -115,6 +136,10 @@ def prepare(
         """
         Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.
        """
+        # avoid circular imports
+        from torchao.quantization.quant_api import (
+            _replace_with_custom_fn_if_matches_filter,
+        )
 
         def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
             return isinstance(child, torch.nn.Embedding)
```
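Editor's note: a small usage sketch for the new `from_embedding` classmethod; the config values are illustrative and assume torchao with this patch is installed:

```python
# Wrap an existing nn.Embedding in its fake-quantized counterpart (weight-only).
import torch
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding

emb = torch.nn.Embedding(10, 512)
fq_emb = FakeQuantizedEmbedding.from_embedding(
    emb,
    weight_config=FakeQuantizeConfig(torch.int4, group_size=32),
)
# The original weight tensor is reused rather than copied (unless the module is on the meta device).
assert fq_emb.weight is emb.weight
```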

torchao/quantization/qat/linear.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -105,6 +105,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w = self.weight
         return F.linear(x, w)
 
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[FakeQuantizeConfig] = None,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+    ):
+        new_linear = FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+        return new_linear
+
 
 class _LegacyQATQuantizer(TwoStepQuantizer):
     """
```

torchao/quantization/quant_api.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -72,6 +72,9 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
+from .qat import (
+    intx_quantization_aware_training,
+)
 from .quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -101,6 +104,7 @@
     "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
+    "intx_quantization_aware_training",
     "float8_weight_only",
     "uintx_weight_only",
     "fpx_weight_only",
```
