Commit 3dce6a3

Update on "Update QAT READMEs using new APIs"
Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to the new QAT + LoRA recipe in torchtune. [ghstack-poisoned]
2 parents be538d6 + 2dea276 commit 3dce6a3

2 files changed: +12 -12 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ In practice these features alongside int4 weight only quantization allow us to *
 
 ### Quantization Aware Training
 
-Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
+Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).
 
 ```python
 from torchao.quantization import (
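
The Python snippet this paragraph introduces is cut off at the hunk boundary above. Here is a minimal sketch of the quantize_-based QAT flow it refers to, assembled from the APIs visible in the torchao/quantization/qat/README.md diff below (`intx_quantization_aware_training` is assumed here as the prepare-side counterpart of `from_intx_quantization_aware_training`, and `get_model()` and `train_loop()` are placeholders, so the README's actual snippet may differ):

```python
import torch
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

model = get_model()  # placeholder: any model with torch.nn.Linear layers

# prepare: swap linear layers for fake-quantized versions
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# train with fake quantization in the loop
train_loop(model)  # placeholder: your fine-tuning loop

# convert: remove fake quantization, then apply real int8/int4 quantization
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
```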

torchao/quantization/qat/README.md

Lines changed: 11 additions & 11 deletions
@@ -91,7 +91,7 @@ from torchao.quantization.qat import (
 model = get_model()
 
 # prepare: insert fake quantization ops
-# Swap `torch.nn.Linear` with `FakeQuantizedLinear`
+# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
 activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
 weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
@@ -103,7 +103,7 @@ quantize_(
 train_loop(model)
 
 # convert: transform fake quantization ops into actual quantized ops
-# Swap `FakeQuantizedLinear` back to `torch.nn.Linear` and insert
+# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
 # quantized activation and weight tensor subclasses
 quantize_(model, from_intx_quantization_aware_training())
 quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
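
Between the prepare and convert steps above, you can sanity-check that the swap described in these comments actually happened. A small hedged snippet, assuming `FakeQuantizedLinear` is importable from `torchao.quantization.qat.linear` and `model` has been prepared as shown:

```python
from torchao.quantization.qat.linear import FakeQuantizedLinear

# after the prepare step (and before convert), linear layers should have been
# swapped for their fake-quantized counterparts
fake_quantized = [
    name
    for name, module in model.named_modules()
    if isinstance(module, FakeQuantizedLinear)
]
print(f"{len(fake_quantized)} linear layers are fake quantized")
```
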
@@ -112,7 +112,7 @@ quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
 ```
 
 To fake quantize embedding in addition to linear, you can additionally call
-the following with a filter function during the prepare step.
+the following with a filter function during the prepare step:
 
 ```
 quantize_(
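
The quantize_ call here is also truncated by the hunk. A hedged sketch of what a filter-function call targeting embeddings can look like (the lambda and the weight-only config values are illustrative rather than the README's exact code; `model` is the model being prepared):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)

# weight-only fake quantization for embedding tables (illustrative values)
embedding_weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# `filter_fn` receives (module, fully_qualified_name); only matching modules
# are swapped, so this call targets `torch.nn.Embedding` and nothing else
quantize_(
    model,
    intx_quantization_aware_training(weight_config=embedding_weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```
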
@@ -138,14 +138,14 @@ qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)
 model = get_model()
 
 # prepare: insert fake quantization ops
-# Swap `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
+# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
 model = qat_quantizer.prepare(model)
 
 # train
 train_loop(model)
 
 # convert: transform fake quantization ops into actual quantized ops
-# Swap `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
+# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
 model = qat_quantizer.convert(model)
 
 # inference or generate
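
For reference, a self-contained version of this Quantizer-based flow. The constructor keyword follows the hunk header above (some torchao versions spell it `groupsize`), and `get_model()` and `train_loop()` are placeholders:

```python
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# keyword spelling copied from the hunk header above; check your torchao version
qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)
model = get_model()  # placeholder: any model with torch.nn.Linear layers

# prepare: swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
model = qat_quantizer.prepare(model)

# train with fake quantization in the loop
train_loop(model)  # placeholder: your fine-tuning loop

# convert: swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
model = qat_quantizer.convert(model)
```
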
@@ -155,7 +155,7 @@ To use multiple Quantizers in the same model for different layer types,
 users can also leverage the [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242)
 as follows:
 
-```
+```python
 from torchao.quantization.qat import (
     ComposableQATQuantizer,
     Int4WeightOnlyEmbeddingQATQuantizer,
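
The import list and the composed prepare/convert calls fall outside this hunk. A hedged sketch of how composing the two Quantizers can look (constructor arguments are left at their defaults and are illustrative; `get_model()` and `train_loop()` are placeholders):

```python
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

# one Quantizer per layer type: linears get int8 dynamic activation + int4 weight
# fake quantization, embeddings get int4 weight-only fake quantization
qat_quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(),
    Int4WeightOnlyEmbeddingQATQuantizer(),
])

model = get_model()  # placeholder
model = qat_quantizer.prepare(model)  # each composed Quantizer prepares its layers
train_loop(model)                     # placeholder: your fine-tuning loop
model = qat_quantizer.convert(model)  # each composed Quantizer converts its layers
```
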
@@ -175,16 +175,16 @@ model = qat_quantizer.convert(model)
 
 ## torchtune integration
 
-Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
-and apply quantized-aware fine-tuning as follows:
+torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune)
+to allow users to run quantized-aware fine-tuning as follows:
 
 ```
 tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
 ```
 
-torchtune also supports a QAT + LoRA distributed training recipe that is 1.89x faster
-and uses 36.1% memory compared to vanilla QAT in our early experiments. You can read
-more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700).
+torchtune also supports a [QAT + LoRA distributed training recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py)
+that is 1.89x faster and uses 36.1% memory compared to vanilla QAT in our early experiments.
+You can read more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700):
 
 ```
 tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
