
Commit 05d4086

Move QAT out of prototype
Summary: Move QAT out of prototype so we can provide stronger BC guarantees moving forward.

**BC-breaking notes**

Before:

```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:

```
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```

Test Plan: python test/quantization/test_qat.py

ghstack-source-id: 9ebc45d
Pull Request resolved: #1091
1 parent 0b71b8d commit 05d4086
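For downstream code that must run against torchao releases from both before and after this commit, a guarded import is one way to bridge the two layouts. This is an illustrative sketch, not part of the commit; it assumes only the two import paths shown in the BC-breaking notes above.

```python
# Hypothetical compatibility shim: prefer the new stable location, and
# fall back to the old prototype path on releases that predate this move.
try:
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
```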

File tree

13 files changed: +21 -42 lines changed


README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to *
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
 
 ```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 qat_quantizer = Int8DynActInt4WeightQATQuantizer()
 
````
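The README shows only the changed import. Below is a minimal sketch of the surrounding QAT flow, assuming the prepare/convert API described in the linked recipe, with a placeholder model and toy training loop standing in for Llama3.

```python
import torch
import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Placeholder model standing in for Llama3 in the recipe.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# prepare() swaps in fake-quantized linears so training sees quantization error.
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# Fine-tune as usual (toy loop shown here).
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    loss = model(torch.randn(8, 512)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# convert() replaces the fake-quantized modules with actually quantized ones.
model = qat_quantizer.convert(model)
```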

test/quantization/test_qat.py

Lines changed: 16 additions & 16 deletions

```diff
@@ -22,20 +22,20 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.embedding import (
+from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -175,7 +175,7 @@ def _set_ptq_weight(
             Int8DynActInt4WeightLinear,
             WeightOnlyInt4Linear,
         )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATLinear,
             Int4WeightOnlyQATLinear,
         )
@@ -207,7 +207,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -232,7 +232,7 @@ def test_qat_8da4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -266,7 +266,7 @@ def test_qat_8da4w_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -281,7 +281,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -340,7 +340,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.quantization.qat.linear import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -422,7 +422,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
@@ -512,7 +512,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -539,14 +539,14 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
@@ -624,7 +624,7 @@ def test_composable_qat_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
        model = M2()
        x = model.example_inputs()
        out = model(*x)
```
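The tests above exercise ComposableQATQuantizer alongside the linear and embedding quantizers. A rough sketch of how they compose, assuming prepare/convert apply each child quantizer in order; the toy model is a hypothetical stand-in for the test fixtures M and M2.

```python
import torch
import torch.nn as nn
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

# Toy model with both an embedding and a linear layer.
model = nn.Sequential(nn.Embedding(128, 256), nn.Linear(256, 256))

# Assumed behavior: the composable quantizer runs each child's
# prepare()/convert() in turn, covering linears and embeddings in one pass.
quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(groupsize=16),
    Int4WeightOnlyEmbeddingQATQuantizer(),
])
model = quantizer.prepare(model)
out = model(torch.randint(0, 128, (8, 16)))  # QAT forward with fake quant
model = quantizer.convert(model)
```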

torchao/quantization/prototype/qat/_module_swap_api.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

torchao/quantization/prototype/qat/README.md renamed to torchao/quantization/qat/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -41,7 +41,7 @@ For example, on a single GPU:
 ```python
 import torch
 from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 # Smaller version of llama3 to fit in a single GPU
 model = llama3(
````

torchao/quantization/prototype/qat/__init__.py renamed to torchao/quantization/qat/__init__.py

Lines changed: 0 additions & 10 deletions

```diff
@@ -2,26 +2,16 @@
     ComposableQATQuantizer,
 )
 from .linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
     Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATLinear,
     Int8DynActInt4WeightQATQuantizer,
 )
 from .embedding import (
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 
 __all__ = [
-    "disable_4w_fake_quant",
-    "disable_8da4w_fake_quant",
-    "enable_4w_fake_quant",
-    "enable_8da4w_fake_quant",
     "ComposableQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int4WeightOnlyEmbeddingQATQuantizer"
     "Int8DynActInt4WeightQATQuantizer",
-    "Int8DynActInt4WeightQATLinear",
 ]
```
File renamed without changes.
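Since the enable/disable toggles now live in torchao.quantization.qat.linear rather than the package root, call sites need a one-line import change. The sketch below assumes the helpers are meant to be applied per-module via nn.Module.apply, as the disable_fake_quant tests above suggest.

```python
import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import (
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

model = nn.Sequential(nn.Linear(256, 256))
model = Int8DynActInt4WeightQATQuantizer(groupsize=16).prepare(model)

# Assumption: each helper toggles fake quantization on a QAT linear when
# walked over the module tree with nn.Module.apply.
model.apply(disable_8da4w_fake_quant)  # e.g. evaluate without fake quant
model.apply(enable_8da4w_fake_quant)   # re-enable before resuming QAT
```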
