60 changes: 29 additions & 31 deletions _posts/2024-09-26-pytorch-native-architecture-optimization.md
@@ -10,57 +10,50 @@ We’re happy to officially launch torchao, a PyTorch native library that makes
We benchmarked our techniques on popular GenAI models like Llama 3 and diffusion models and saw minimal drops in accuracy. Unless otherwise noted, the baselines are bf16 runs on an A100 80GB GPU.

Our topline metrics for Llama 3 are

* 97% speedup for Llama 3 8B inference using autoquant with int4 weight only quantization and hqq
* 73% peak VRAM reduction for Llama 3.1 8B inference at 128K context length with a quantized KV cache
* 50% speedup for Llama 3 70B pretraining using float8 training on H100
* 30% peak VRAM reduction for Llama 3 8B using 4 bit quantized optimizers

Our topline metrics for diffusion model inference are

* 53% speedup using float8 dynamic quantization inference with float8 row-wise scaling on flux1.dev on H100
* 50% reduction in model VRAM for CogVideoX using int8 dynamic quantization

Below we'll walk through some of the techniques available in torchao you can apply to your models for inference and training.

## Inference

[Our inference quantization algorithms](https://github.com/pytorch/ao/tree/main/torchao/quantization) work over arbitrary PyTorch models that contain nn.Linear layers. Weight only and dynamic activation quantization for various dtypes and sparse layouts can be chosen using our top level `quantize_` API:

```python
from torchao.quantization import (
    quantize_,
    int4_weight_only,
)
quantize_(model, int4_weight_only())
```
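
To make the idea behind weight-only int4 quantization concrete, here is a minimal pure-Python sketch of group-wise 4-bit quantization. It is not torchao's implementation: the group size, the asymmetric (scale plus minimum) scheme, and the function names are assumptions chosen for illustration, and real kernels also pack two 4-bit codes per byte.

```python
GROUP_SIZE = 4  # illustrative assumption; real group sizes are much larger


def quantize_int4_groupwise(weights, group_size=GROUP_SIZE):
    """Map each group of weights to 4-bit codes (0..15) plus a (scale, minimum) pair."""
    codes, params = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        # 16 levels give 15 steps; fall back to 1.0 for a constant group.
        scale = (hi - lo) / 15 or 1.0
        codes.extend(round((w - lo) / scale) for w in group)
        params.append((scale, lo))
    return codes, params


def dequantize_int4_groupwise(codes, params, group_size=GROUP_SIZE):
    """Reconstruct approximate weights from the codes and per-group parameters."""
    restored = []
    for i, code in enumerate(codes):
        scale, lo = params[i // group_size]
        restored.append(lo + code * scale)
    return restored


weights = [0.0, 1.5, -1.5, 0.75, 10.0, 10.3, 9.7, 10.0]
codes, params = quantize_int4_groupwise(weights)
restored = dequantize_int4_groupwise(codes, params)
print(codes)
print([round(r, 3) for r in restored])
```

Each weight is reconstructed to within half a quantization step of its group, which is why keeping a separate scale per group helps when weight magnitudes vary across a layer.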

Quantizing a layer can sometimes make it slower because of overhead. If you’d rather we just pick how to quantize each layer in a model for you, you can instead run

```python
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
```

The `quantize_` API has a few different options depending on whether your model is compute bound or memory bound.

```python
from torchao.quantization import (
    # Memory bound models
    int4_weight_only,
    int8_weight_only,

    # Compute bound models
    int8_dynamic_activation_int8_semi_sparse_weight,
    int8_dynamic_activation_int8_weight,

    # Device capability 8.9+
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
```
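
As a rough way to tell which regime a linear layer is in, the sketch below compares its arithmetic intensity with the hardware's compute-to-bandwidth ratio. The peak numbers are nominal A100 80GB (SXM) spec values used as assumptions, activation traffic is ignored, and the function names are hypothetical, so treat the result as a back-of-envelope estimate rather than a profiler.

```python
# Nominal A100 80GB (SXM) figures, used here as assumptions.
PEAK_BF16_FLOPS = 312e12   # dense bf16 tensor-core FLOP/s
PEAK_BANDWIDTH = 2.039e12  # HBM bytes/s
RIDGE = PEAK_BF16_FLOPS / PEAK_BANDWIDTH  # FLOPs per byte where both limits meet


def linear_arithmetic_intensity(num_tokens, bytes_per_weight):
    """FLOPs per byte of weight traffic for y = x @ W.

    FLOPs = 2 * tokens * d_in * d_out and bytes = d_in * d_out * bytes_per_weight,
    so the layer dimensions cancel out.
    """
    return 2 * num_tokens / bytes_per_weight


def is_memory_bound(num_tokens, bytes_per_weight=2.0):
    """True when weight loading, not math, is the estimated bottleneck."""
    return linear_arithmetic_intensity(num_tokens, bytes_per_weight) < RIDGE


print(f"ridge point: {RIDGE:.0f} FLOPs/byte")
print("batch of 1 token, bf16 weights:", is_memory_bound(1))
print("batch of 4096 tokens, bf16 weights:", is_memory_bound(4096))
print("intensity gain from int4 weights:",
      linear_arithmetic_intensity(1, 0.5) / linear_arithmetic_intensity(1, 2.0))
```

Under these assumptions, small-batch decoding sits far below the ridge point, so shrinking the weights with the weight-only options reduces the bytes that dominate runtime, while large-batch workloads benefit more from the dynamic activation options that speed up the matmul itself.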

@@ -86,8 +79,10 @@ Post training quantization, especially at less than 4 bit can suffer from seriou

torchao provides easy to use e2e workflows for reducing the precision of training compute and distributed communications, starting with float8 for `torch.nn.Linear` layers. Here is a one-liner to convert the compute gemms of your training run to float8:

```python
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)
```
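
For intuition about what casting to float8 involves, here is a pure-Python sketch of tensor-wise dynamic scaling into the e4m3 format (3 mantissa bits, largest finite value 448). It assumes one scale per tensor and emulates the rounding by hand; the helper names are hypothetical, and torchao itself relies on PyTorch float8 dtypes and scaled matmul kernels rather than anything like this.

```python
import math

E4M3_MAX = 448.0  # largest finite value in the float8 e4m3fn format


def round_to_e4m3(value):
    """Round a Python float to the nearest e4m3-representable value, saturating at the max."""
    if value == 0.0:
        return 0.0
    magnitude = min(abs(value), E4M3_MAX)
    _, exponent = math.frexp(magnitude)  # magnitude = m * 2**exponent, 0.5 <= m < 1
    leading = max(exponent - 1, -6)      # -6 is the smallest normal exponent
    step = 2.0 ** (leading - 3)          # 3 mantissa bits: 8 steps per power of two
    rounded = min(round(magnitude / step) * step, E4M3_MAX)
    return math.copysign(rounded, value)


def to_float8_dynamic(values):
    """Scale so the largest magnitude lands on E4M3_MAX, then round each value."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in values], scale


def from_float8(quantized, scale):
    """Undo the scaling to get back approximate original values."""
    return [q / scale for q in quantized]


activations = [0.5, -2.0, 0.25, 0.3]
quantized, scale = to_float8_dynamic(activations)
restored = from_float8(quantized, scale)
print(scale, quantized, restored)
```

The scale is recomputed from the current maximum on every call, which is what makes the scaling dynamic: the tensor always uses the full float8 range, at the cost of an extra reduction over the data.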

For an e2e example of how to speed up Llama 3 70B pretraining by up to **1.5x** with float8, see our [README](https://github.com/pytorch/ao/tree/main/torchao/float8), and torchtitan's [blog](https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359) and [float8 recipe](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md).

@@ -106,8 +101,11 @@ We are expanding our training workflows to more dtypes and layouts

Inspired by Bits and Bytes, we’ve also added prototype support for 8 and 4 bit optimizers as a drop-in replacement for AdamW.

```python
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit
optim = AdamW8bit(model.parameters())
```

![](/assets/images/Figure_5.png){:style="width:100%"}
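
To see why low-bit optimizer state matters, the arithmetic below estimates the memory taken by AdamW's two per-parameter states at different precisions. The 8-billion parameter count is a round figure assumed for illustration, and the estimate ignores the small overhead of quantization scales as well as weights, gradients, and activations, so it does not reproduce a peak VRAM measurement.

```python
def adamw_state_bytes(num_params, bits_per_state):
    """Bytes for AdamW's two per-parameter states (exp_avg and exp_avg_sq)."""
    return 2 * num_params * bits_per_state // 8


NUM_PARAMS = 8_000_000_000  # round figure for an 8B-parameter model (assumption)

for name, bits in [("fp32 AdamW", 32), ("AdamW8bit", 8), ("AdamW4bit", 4)]:
    gigabytes = adamw_state_bytes(NUM_PARAMS, bits) / 1e9
    print(f"{name}: {gigabytes:.0f} GB of optimizer state")
```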

# Integrations
@@ -129,16 +127,16 @@ pip install torchao

There are a lot of things we’re excited about next: going lower than 4 bit, performant kernels for high-throughput inference, expanding to more layers, scaling types or granularities, MX hardware support, and supporting more hardware backends. If any of the above sounds exciting, you can follow our progress at: [https://github.com/pytorch/ao](https://github.com/pytorch/ao)

If you’re interested in working on torchao, we’ve created a [contributors guide](https://github.com/pytorch/ao/issues/391), and if you have any questions we hang out on the `#torchao` channel on [discord.gg/gpumode](http://discord.gg/gpumode)

## Acknowledgements

We are fortunate to stand on the shoulders of giants and collaborate with some of the best people in open source. Thank you!

1. Bits and Bytes for pioneering work in low bit optimizers and QLoRA
2. Answer.ai for their engineering work to get FSDP and QLoRA composing
3. Mobius Labs for the lovely back and forths on quantization algorithms and low bit kernels
4. HuggingFace transformers for their help in battle testing and integrating our work
5. HuggingFace diffusers for our collaboration on extensive benchmarks and best practices
6. torch.compile so we could write our algorithms in pure PyTorch
7. GPU MODE for most of our early contributors