@@ -14,6 +14,7 @@
 import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
+from torchao import quantize_
 from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4
 from torchao.quantization.granularity import (
     PerAxis,
@@ -24,6 +25,7 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
@@ -104,6 +106,25 @@ def forward(self, x):
         return self.embedding(x)
 
 
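+# Toy model with both embedding and linear layers, so a single test can
+# exercise the linear and embedding QAT paths together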
+class M3(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.embedding = torch.nn.Embedding(10, 512)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
+        self.relu = torch.nn.ReLU()
+
+    def example_inputs(self):
+        return (torch.randint(1, 10, (1, 512)),)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.relu(x)
+        return x
+
+
 class TestQAT(unittest.TestCase):
     SEED = 123
 
@@ -1156,6 +1177,91 @@ def test_qat_prototype_bc(self):
             Int8DynActInt4WeightQATQuantizer,
         )
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4"
+    )
+    def test_quantize_api(self):
+        """
+        Test that the following:
+
+            quantize_(model, intx_quantization_aware_training(...))
+
+        produces the same results as `ComposableQATQuantizer`.
+        """
+        from torchao.quantization.qat import (
+            ComposableQATQuantizer,
+            Int4WeightOnlyEmbeddingQATQuantizer,
+            Int8DynActInt4WeightQATQuantizer,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M3()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline quantizer
+        baseline_quantizer = ComposableQATQuantizer(
+            [
+                Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
+                Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
+            ]
+        )
+        baseline_model = baseline_quantizer.prepare(baseline_model)
+
+        # quantize_ API
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=False,
+        )
+        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        quantize_(
+            m,
+            intx_quantization_aware_training(activation_config, weight_config),
+        )
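+        # The default filter_fn of quantize_ only matches linear layers, so
+        # fake-quantize the embedding weights in a separate, weight-only pass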
+        quantize_(
+            m,
+            intx_quantization_aware_training(weight_config=weight_config),
+            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+        )
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4"
+    )
+    def test_quantize_api_errors(self):
+        """
+        Test that `quantize_` raises exceptions with helpful error messages
+        when it runs into unexpected configurations.
+        """
+        my_config = FakeQuantizeConfig(torch.int8, group_size=32)
+        m = M3()
+
+        # Embedding currently only supports weight-only quantization
+        with self.assertRaisesRegex(
+            ValueError, "Activation fake quantization is not supported for embedding"
+        ):
+            quantize_(
+                m,
+                intx_quantization_aware_training(my_config, my_config),
+                lambda m, _: isinstance(m, torch.nn.Embedding),
+            )
+
+        # Only linear and embedding are supported currently
+        with self.assertRaisesRegex(ValueError, "does not have QAT support"):
+            quantize_(
+                m,
+                intx_quantization_aware_training(my_config, my_config),
+                lambda m, _: isinstance(m, torch.nn.ReLU),
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
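For orientation, the flow exercised by `test_quantize_api` maps onto a typical QAT setup: prepare the model with `quantize_`, then fine-tune with fake quantization active. Below is a minimal sketch, not part of this diff; it assumes `M3` as defined above, the `TorchAODType` import path (`torchao.quantization.quant_primitives`) is an assumption, and the optimizer and loop settings are illustrative only:

import torch
from torchao import quantize_
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
from torchao.quantization.quant_primitives import TorchAODType  # assumed path

# Prepare: int8 per-token asymmetric activations, int4 grouped weights
model = M3()
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=16)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))
# Embeddings need a separate weight-only pass, as in the test above
quantize_(
    model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)

# Fine-tune with fake quantization in the loop (illustrative settings)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    loss = model(*model.example_inputs()).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Note that this diff only covers the prepare path; converting fake-quantized modules into actually quantized ones is done separately.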