From c825009639ca8e0c29be05c951bd8f38bf71cb12 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 11:48:29 -0800 Subject: [PATCH 1/7] Add convert path for quantize_ QAT API Summary: https://github.com/pytorch/ao/pull/1415 added a quantize_ QAT API for the prepare path. This commit adds the remaining convert path for users to actually perform end-to-end QAT using the quantize_ API. The new flow will look like: ``` from torchao.quantization import ( quantize_, int8_dynamic_activation_int4_weight, ) from torchao.quantization.qat import ( FakeQuantizeConfig, from_intx_quantization_aware_training, intx_quantization_aware_training, ) activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=32) quantize_( my_model, intx_quantization_aware_training(activation_config, weight_config), ) quantize_(my_model, from_intx_quantization_aware_training()) quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32)) ``` Test Plan: python test/quantization/test_qat.py -k test_quantize_api_convert_path [ghstack-poisoned] --- test/quantization/test_qat.py | 65 +++++++++++++++++++++++++++ torchao/quantization/qat/__init__.py | 2 + torchao/quantization/qat/api.py | 40 ++++++++++++++++- torchao/quantization/qat/embedding.py | 18 ++++++++ torchao/quantization/qat/linear.py | 14 ++++++ 5 files changed, 137 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 42900c54f1..642f0bd4ad 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -25,6 +25,7 @@ from torchao.quantization.qat.api import ( ComposableQATQuantizer, FakeQuantizeConfig, + from_intx_quantization_aware_training, intx_quantization_aware_training, ) from torchao.quantization.qat.embedding import ( @@ -42,6 +43,9 @@ _GenericFakeQuantize, _get_qmin_qmax, ) +from torchao.quantization.quant_api import ( + int8_dynamic_activation_int4_weight, +) from torchao.quantization.quant_primitives import ( MappingType, TorchAODType, @@ -1262,6 +1266,67 @@ def test_quantize_api_errors(self): lambda m, _: isinstance(m, torch.nn.ReLU), ) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_quantize_api_convert_path(self): + """ + Test that the following: + + quantize_(model, intx_quantization_aware_training(...)) + quantize_(model, from_intx_quantization_aware_training(...)) + quantize_(model, int8_dynamic_activation_int4_weight()) + + can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert. 
+ """ + from torchao.quantization.qat import ( + Int8DynActInt4WeightQATQuantizer, + ) + + group_size = 16 + torch.manual_seed(self.SEED) + m = M() + baseline_model = copy.deepcopy(m) + + # Baseline prepare + baseline_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + baseline_model = baseline_quantizer.prepare(baseline_model) + + # quantize_ prepare + activation_config = FakeQuantizeConfig( + torch.int8, + "per_token", + is_symmetric=False, + ) + weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + quantize_( + m, + intx_quantization_aware_training(activation_config, weight_config), + ) + + # Compare prepared values + torch.manual_seed(self.SEED) + x = m.example_inputs() + x2 = copy.deepcopy(x) + out = m(*x) + baseline_out = baseline_model(*x2) + torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) + + # Baseline convert + baseline_model = baseline_quantizer.convert(baseline_model) + + # quantize_ convert + quantize_(m, from_intx_quantization_aware_training()) + quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) + + # Compare converted values + torch.manual_seed(self.SEED) + x = m.example_inputs() + x2 = copy.deepcopy(x) + out = m(*x) + baseline_out = baseline_model(*x2) + torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 75ba6f22db..9b1f43a2bc 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -2,6 +2,7 @@ ComposableQATQuantizer, FakeQuantizeConfig, intx_quantization_aware_training, + from_intx_quantization_aware_training, ) from .embedding import ( Int4WeightOnlyEmbeddingQATQuantizer, @@ -18,4 +19,5 @@ "Int4WeightOnlyEmbeddingQATQuantizer", "Int8DynActInt4WeightQATQuantizer", "intx_quantization_aware_training", + "from_intx_quantization_aware_training", ] diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index 8f0244a858..cd3813291f 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch @@ -242,7 +242,7 @@ def __setattr__(self, name: str, value: Any): def intx_quantization_aware_training( activation_config: Optional[FakeQuantizeConfig] = None, weight_config: Optional[FakeQuantizeConfig] = None, -) -> torch.nn.Module: +) -> Callable: """ Return a function that applies fake quantization to a `torch.nn.Module`. to be used with :func:`~torchao.quantization.quant_api.quantize_`. @@ -295,6 +295,42 @@ def _insert_fake_quantize(mod: torch.nn.Module): return _insert_fake_quantize +def from_intx_quantization_aware_training() -> Callable: + """ + Return a function that converts a model with fake quantized modules, + such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear` + and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`, + back to model with the original, corresponding modules without + fake quantization. This should be used with + :func:`~torchao.quantization.quant_api.quantize_`. 
+ + Example usage:: + + from torchao.quantization import quantize_ + quantize_( + model_with_fake_quantized_linears, + from_intx_quantization_aware_training(), + ) + """ + + def _remove_fake_quantize(mod: torch.nn.Module): + """ + If the given module is a fake quantized module, return the original + corresponding version of the module without fake quantization. + """ + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + if isinstance(mod, FakeQuantizedLinear): + return mod.to_linear() + elif isinstance(mod, FakeQuantizedEmbedding): + return mod.to_embedding() + else: + return mod + + return _remove_fake_quantize + + class ComposableQATQuantizer(TwoStepQuantizer): """ Composable quantizer that users can use to apply multiple QAT quantizers easily. diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index ff580ac1d3..cc63c5181d 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -82,6 +82,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.sparse, ) + def to_embedding(self) -> torch.nn.Embedding: + new_embedding = torch.nn.Embedding( + self.num_embeddings, + self.embedding_dim, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + device=self.weight.device, + ) + # In distributed training, the model may be instantiated + # on the meta device, in which case there is no need to + # copy the weights, and doing so will result in an error + if self.weight.device != torch.device("meta"): + new_embedding.weight = self.weight + return new_embedding + @classmethod def from_embedding( cls, diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 153e324838..34f7bba2f8 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -105,6 +105,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.weight return F.linear(x, w) + def to_linear(self) -> torch.nn.Linear: + new_linear = torch.nn.Linear( + self.in_features, + self.out_features, + self.bias, + device=self.weight.device + ) + # In distributed training, the model may be instantiated + # on the meta device, in which case there is no need to + # copy the weights, and doing so will result in an error + if self.weight.device != torch.device("meta"): + new_linear.weight = self.weight + return new_linear + @classmethod def from_linear( cls, From 50a8355d4521625bb9291fbfca3dea910767e96b Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 11:48:31 -0800 Subject: [PATCH 2/7] Update QAT READMEs using new APIs Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune. 
[ghstack-poisoned] --- README.md | 33 ++++--- torchao/quantization/qat/README.md | 146 ++++++++++++++++++++++------- 2 files changed, 135 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 6ba0e3be4c..298b31a739 100644 --- a/README.md +++ b/README.md @@ -54,27 +54,38 @@ We've added kv cache quantization and other features in order to enable long con In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md) +## Training + ### Quantization Aware Training Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/) ```python -from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer - -qat_quantizer = Int8DynActInt4WeightQATQuantizer() +from torchao.quantization import ( + quantize_, + int8_dynamic_activation_int4_weight, +) +from torchao.quantization.qat import ( + FakeQuantizeConfig, + from_intx_quantization_aware_training, + intx_quantization_aware_training, +) -# Insert "fake quantize" operations into linear layers. -# These operations simulate quantization numerics -model = qat_quantizer.prepare(model) +# Insert fake quantization +activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +quantize_( + my_model, + intx_quantization_aware_training(activation_config, weight_config), +) -# Run Training... +# Run training... (not shown) -# Convert fake quantize to actual quantize operations -model = qat_quantizer.convert(model) +# Convert fake quantization to actual quantized operations +quantize_(my_model, from_intx_quantization_aware_training()) +quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32)) ``` -## Training - ### Float8 [torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md index 6ecccd2b18..16c8cba5d9 100644 --- a/torchao/quantization/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -19,12 +19,6 @@ x_fq = (x_float / scale + zp).round().clamp(qmin, qmax) x_fq = (x_fq - zp) * scale ``` -## API - -torchao currently supports two QAT schemes for linear layers: -- int8 per token dynamic activations + int4 per group weights -- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training) - QAT typically involves applying a transformation to your model before and after training. 
In torchao, these are represented as the prepare and convert steps: (1) prepare inserts fake quantize operations into linear layers, and (2) convert transforms the fake quantize @@ -34,16 +28,24 @@ Between these two steps, training can proceed exactly as before. ![qat](images/qat_diagram.png) -To use QAT in torchao, apply the prepare step using the appropriate Quantizer before -training, then apply the convert step after training for inference or generation. -For example, on a single GPU: + +## API + +torchao currently supports two QAT APIs, one through the [`quantize_`](https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_) +API (recommended) and one through the Quantizer classes (legacy). The `quantize_` API +allows flexible configuration of quantization settings for both activations and weights, +while the Quantizer classes each hardcode a specific quantization setting. + +Here's an example of running QAT using the following quantization setting on a single GPU: +- int8 per-token dynamic asymmetric activation (for linears) +- int4 per-group symmetric weight (for linears) ```python import torch from torchtune.models.llama3 import llama3 from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer -# Smaller version of llama3 to fit in a single GPU +# Set up smaller version of llama3 to fit in a single GPU model = llama3( vocab_size=4096, num_layers=16, @@ -53,36 +55,107 @@ model = llama3( max_seq_len=2048, ).cuda() -# Quantizer for int8 dynamic per token activations + -# int4 grouped per channel weights, only for linear layers -qat_quantizer = Int8DynActInt4WeightQATQuantizer() +# Example training loop +def train(m: torch.nn.Module): + optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + loss_fn = torch.nn.CrossEntropyLoss() + for i in range(10): + example = torch.randint(0, 4096, (2, 16)).cuda() + target = torch.randn((2, 16, 4096)).cuda() + output = m(example) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() +``` + +### quantize_ + +```python +from torchao.quantization import ( + quantize_, + int8_dynamic_activation_int4_weight, +) +from torchao.quantization.qat import ( + FakeQuantizeConfig, + from_intx_quantization_aware_training, + intx_quantization_aware_training, +) + +# prepare: insert fake quantization ops +# Model consists of `FakeQuantizedLinear` afterwards +activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +quantize_( + my_model, + intx_quantization_aware_training(activation_config, weight_config), +) + +# train (not shown) + +# convert: transform fake quantization ops into actual quantized ops +# Model consists of `torch.nn.Linear` with quantized activation and weight tensors afterwards +quantize_(my_model, from_intx_quantization_aware_training()) +quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32)) + +# inference or generate +``` + +To fake quantize embedding in addition to linear, you can additionally call +the following with a filter function during the prepare step. + +``` +quantize_( + m, + intx_quantization_aware_training(weight_config=weight_config), + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), +) +``` -# Insert "fake quantize" operations into linear layers. 
-# These operations simulate quantization numerics during
-# training without performing any dtype casting
+
+### Quantizer
+
+```python
+from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+
+qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
+
+# prepare: insert fake quantization ops
+# Model consists of `Int8DynActInt4WeightQATLinear` afterwards
 model = qat_quantizer.prepare(model)

-# Standard training loop
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
-loss_fn = torch.nn.CrossEntropyLoss()
-for i in range(10):
-    example = torch.randint(0, 4096, (2, 16)).cuda()
-    target = torch.randn((2, 16, 4096)).cuda()
-    output = model(example)
-    loss = loss_fn(output, target)
-    loss.backward()
-    optimizer.step()
-    optimizer.zero_grad()
-
-# Convert fake quantize to actual quantize operations
-# The quantized model has the exact same structure as the
-# quantized model produced in the corresponding PTQ flow
-# through `Int8DynActInt4WeightQuantizer`
+# train (not shown)
+
+# convert: transform fake quantization ops into actual quantized ops
+# Model consists of `Int8DynActInt4WeightLinear` afterwards
 model = qat_quantizer.convert(model)

 # inference or generate
 ```

+torchao currently supports the following Quantizers:
+- Linear: [Int8DynActInt4WeightQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
+- Linear: [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training
+- Embedding: [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94), targeting int4 per-group symmetric weight
+- [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242), which allows users to compose multiple Quantizers (one for each layer), for example:
+
+```python
+from torchao.quantization.qat import (
+    ComposableQATQuantizer,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+
+quantizer = ComposableQATQuantizer([
+    Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
+    Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
+])
+
+# prepare + train + convert as before
+```
+
+## torchtune integration
+
 Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
 and apply quantized-aware fine-tuning as follows:

@@ -90,8 +163,15 @@ and apply quantized-aware fine-tuning as follows:
 tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
 ```

-For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
+torchtune also supports a QAT + LoRA distributed training recipe that is 1.89x faster
+and uses 36.1% less memory compared to vanilla QAT in our early experiments. You can read
+more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700).
+
+```
+tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
+```
+
+For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
## Evaluation Results From ec947972797dda797ac7a987b5251d7a5238026d Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 11:52:01 -0800 Subject: [PATCH 3/7] Update base for Update on "Update QAT READMEs using new APIs" Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune. [ghstack-poisoned] --- torchao/quantization/qat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 9b1f43a2bc..15008e03ea 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -1,8 +1,8 @@ from .api import ( ComposableQATQuantizer, FakeQuantizeConfig, - intx_quantization_aware_training, from_intx_quantization_aware_training, + intx_quantization_aware_training, ) from .embedding import ( Int4WeightOnlyEmbeddingQATQuantizer, From 507a9e6e63aacf40966b051399a1b36f58a7ef39 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 11:53:37 -0800 Subject: [PATCH 4/7] Update base for Update on "Update QAT READMEs using new APIs" Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune. [ghstack-poisoned] --- torchao/quantization/qat/linear.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 34f7bba2f8..fafda68d58 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -107,10 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def to_linear(self) -> torch.nn.Linear: new_linear = torch.nn.Linear( - self.in_features, - self.out_features, - self.bias, - device=self.weight.device + self.in_features, self.out_features, self.bias, device=self.weight.device ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to From 72000a2674c2c53e68fea9f01c67e5fe0758438a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 11:55:10 -0800 Subject: [PATCH 5/7] Update base for Update on "Update QAT READMEs using new APIs" Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune. [ghstack-poisoned] From 87bb7db9533f2f734225b20c77c4c68150de6527 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 12:41:23 -0800 Subject: [PATCH 6/7] Update base for Update on "Update QAT READMEs using new APIs" Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune. [ghstack-poisoned] From 2dea2760744fc463edf5d92298e84ba6c8b592d8 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Jan 2025 12:53:33 -0800 Subject: [PATCH 7/7] Update base for Update on "Update QAT READMEs using new APIs" Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe in torchtune. [ghstack-poisoned]