diff --git a/_posts/2024-09-26-pytorch-native-architecture-optimization.md b/_posts/2024-09-26-pytorch-native-architecture-optimization.md index dd901a5a2517..0bcbb7c6a7d7 100644 --- a/_posts/2024-09-26-pytorch-native-architecture-optimization.md +++ b/_posts/2024-09-26-pytorch-native-architecture-optimization.md @@ -10,19 +10,12 @@ We’re happy to officially launch torchao, a PyTorch native library that makes We benchmarked our techniques on popular GenAI models like LLama 3 and Diffusion models and saw minimal drops in accuracy. Unless otherwise noted the baselines are bf16 run on A100 80GB GPU. Our topline metrics for llama 3 are - -For inference - -* 97% speedup for Llama 3 8B using autoquant with int4 weight only quantization and hqq -* 73% peak VRAM reduction for Llama 3.1 8B at 128K context length with a quantized KV cache - -For training - +* 97% speedup for Llama 3 8B inference using autoquant with int4 weight only quantization and hqq +* 73% peak VRAM reduction for Llama 3.1 8B inference at 128K context length with a quantized KV cache * 50% speedup for Llama 3 70B pretraining using float8 training on H100 * 30% peak VRAM reduction for Llama 3 8B using 4 bit quantized optimizers. Our topline metrics for diffusion model inference - * 53% speedup using float8 dynamic quantization inference with float8 row-wise scaling on flux1.dev onH100 * 50% reduction in model VRAM for CogVideoX using int8 dynamic quantization @@ -30,37 +23,37 @@ Below we'll walk through some of the techniques available in torchao you can app ## Inference -[Our inference quantization algorithms](https://github.com/pytorch/ao/tree/main/torchao/quantization) work over arbitrary PyTorch models that contain nn.Linear layers. Weight only and dynamic activation quantization for various dtypes and sparse layouts can be chosen using our top level quantize\_ api +[Our inference quantization algorithms](https://github.com/pytorch/ao/tree/main/torchao/quantization) work over arbitrary PyTorch models that contain nn.Linear layers. Weight only and dynamic activation quantization for various dtypes and sparse layouts can be chosen using our top level `quantize_` api -```py +```python from torchao.quantization import ( - quantize\_, - int4\_weight\_only, + quantize_, + int4_weight_only, ) -quantize\_(model, int4\_weight\_only()) +quantize_(model, int4_weight_only()) ``` Sometimes quantizing a layer can make it slower because of overhead so if you’d rather we just pick how to quantize each layer in a model for you then you can instead run -```py -model \= torchao.autoquant(torch.compile(model, mode='max-autotune')) +```python +model = torchao.autoquant(torch.compile(model, mode='max-autotune')) ``` -quantize\_ API has a few different options depending on whether your model is compute bound or memory bound. +`quantize_` API has a few different options depending on whether your model is compute bound or memory bound. -```py +```python from torchao.quantization import ( # Memory bound models - int4\_weight\_only, - int8\_weight\_only, + int4_weight_only, + int8_weight_only, # Compute bound models - int8\_dynamic\_activation\_int8\_semi\_sparse\_weight, - int8\_dynamic\_activation\_int8\_weight, + int8_dynamic_activation_int8_semi_sparse_weight, + int8_dynamic_activation_int8_weight, # Device capability 8.9+ - float8\_weight\_only, - float8\_dynamic\_activation\_float8\_weight, + float8_weight_only, + float8_dynamic_activation_float8_weight, ) ``` @@ -86,8 +79,10 @@ Post training quantization, especially at less than 4 bit can suffer from seriou torchao provides easy to use e2e workflows for reducing the precision of training compute and distributed communications, starting with float8 for \`torch.nn.Linear\` layers.Here is a one-liner to convert the compute gemms of your training run to float8: -from torchao.float8 import convert\_to\_float8\_training -convert\_to\_float8\_training(model) +```python +from torchao.float8 import convert_to_float8_training +convert_to_float8_training(model) +``` For an e2e example of how to speed up LLaMa 3 70B pretraining by up to **1.5x** with float8, see our [README](https://github.com/pytorch/ao/tree/main/torchao/float8), and torchtitan's [blog](https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359) and [float8 recipe](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md). @@ -106,8 +101,11 @@ We are expanding our training workflows to more dtypes and layouts Inspired by Bits and Bytes we’ve also added prototype support for 8 and 4 bit optimizers as a drop in replacement for AdamW. -from torchao.prototype.low\_bit\_optim import AdamW8bit, AdamW4bit -optim \= AdamW8bit(model.parameters()) +```python +from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit +optim = AdamW8bit(model.parameters()) +``` + ![](/assets/images/Figure_5.png){:style="width:100%"} # Integrations @@ -129,11 +127,11 @@ pip install torchao There are a lot of things we’re excited about next ranging from going lower than 4 bit, performant kernels for high-throughput inference, expanding to more layers, scaling types or granularities, MX hardware support and supporting more hardware backends. If any of the above sounds exciting you can follow our progress at: [https://github.com/pytorch/ao](https://github.com/pytorch/ao) -If you’re interested in working on torchao, we’ve created a [contributors guide](https://github.com/pytorch/ao/issues/391), and if you have any questions we hang out on the \#torchao channel on [discord.gg/cudamode](http://discord.gg/cudamode) +If you’re interested in working on torchao, we’ve created a [contributors guide](https://github.com/pytorch/ao/issues/391), and if you have any questions we hang out on the `#torchao` channel on [discord.gg/gpumode](http://discord.gg/gpumode) ## Acknowledgements -We are fortunate to stand on the shoulders of giants and collaborate with some of the best people in open source. Thank you\! +We are fortunate to stand on the shoulders of giants and collaborate with some of the best people in open source. Thank you! 1. Bits and Bytes for pioneering work in low bit optimizers and QLoRA 2. Answer.ai for their engineering work to get FSDP and QLoRA composing @@ -141,4 +139,4 @@ We are fortunate to stand on the shoulders of giants and collaborate with some o 4. HuggingFace transformers for their help in battle testing and integrating our work 5. HuggingFace diffusers for our collaboration on extensive benchmarks and best practices 6. torch.compile so we could write our algorithms in pure PyTorch -7. CUDA MODE for most of our early contributors +7. GPU MODE for most of our early contributors