Description
Smoothquant implements input-weight equalization. The current implementation in torchao uses module swap, but it can be refactored to use tensor subclasses, and specifically AffineQuantizedTensor, so that we can consolidate the performance optimizations in one place. We can use the static quantization flow as an example: #487.
The main benefits of the refactor would be: (1) aligning the model level APIs, and (2) an easier deserialization story (https://pytorch.org/ao/stable/serialization.html#what-happens-when-deserializing-an-optimized-model): you can load the quantized state dict into the original model directly and get a model that is ready for inference.
Overview
Here is the top level API for smoothquant: https://github.com/pytorch/ao/tree/main/torchao/quantization#to-be-moved-to-prototype-a8w8-dynamic-quantization-with-smoothquant
It follows our calibration flow (static quant flow) pretty closely:
ao/tutorials/calibration_flow/static_quant.py
Lines 121 to 134 in afde175
    insert_observers_(m, act_obs, weight_obs)
    # calibrating / training
    for _ in range(10):
        m(*example_inputs)
    after_obs = m(*example_inputs)
    m2 = copy.deepcopy(m)
    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
    # quantized linear represented as an nn.Linear with modified tensor subclass weights
    # for both activation and weight quantization
    quantize_(m, apply_static_quant, is_observed_linear)
How to implement it in torchao
Similar to the static quantization flow, at a high level we can have two steps.
Step 1. Inserting Observers
The first step is to insert observers that record the running absolute max value of the input:
ao/torchao/quantization/smoothquant.py
Lines 146 to 147 in afde175
    self.update_x_running_abs_max(X)
    Y = F.linear(X, self.weight, self.bias)
We can create a function insert_smoothquant_observers_ similar to:
    def insert_observers_(model, act_obs, weight_obs):
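A minimal sketch of what Step 1 could look like, using only core PyTorch. The class name `SmoothQuantObservedLinear`, its `from_float` constructor, and the body of `insert_smoothquant_observers_` are hypothetical. Only the `update_x_running_abs_max` call followed by `F.linear` comes from the existing code quoted above.

```python
import torch
import torch.nn.functional as F


class SmoothQuantObservedLinear(torch.nn.Linear):
    """Hypothetical observed linear: tracks the per-input-channel running abs max of X."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        self.register_buffer(
            "x_running_abs_max",
            torch.zeros(in_features, device=device, dtype=dtype),
        )

    def update_x_running_abs_max(self, X):
        # Reduce over every dim except the last one (the input-channel dim).
        cur = X.detach().abs().reshape(-1, X.shape[-1]).amax(dim=0)
        self.x_running_abs_max = torch.maximum(self.x_running_abs_max, cur)

    def forward(self, X):
        self.update_x_running_abs_max(X)
        return F.linear(X, self.weight, self.bias)

    @classmethod
    def from_float(cls, linear):
        observed = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )
        observed.weight = linear.weight
        observed.bias = linear.bias
        return observed


def insert_smoothquant_observers_(model):
    """Hypothetical helper: swap every nn.Linear for an observed linear, in place."""
    for name, child in model.named_children():
        if type(child) is torch.nn.Linear:
            setattr(model, name, SmoothQuantObservedLinear.from_float(child))
        else:
            insert_smoothquant_observers_(child)
    return model


model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
insert_smoothquant_observers_(model)
for _ in range(10):  # calibration
    model(torch.randn(2, 8))
```

After calibration, each observed linear holds one statistic per input channel, which is what the conversion step needs to compute the equalization scale.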
Step 2. Convert to AffineQuantizedTensor with a new layout
After collecting the stats, we can convert the floating point weight to an AffineQuantizedTensor with a new LayoutType and AQTLayout that carry an extra equalization_scale Tensor. I think this can share the same implementation as AWQ, although with a different dtype (int8). Example conversion code:
ao/tutorials/calibration_flow/static_quant.py
Lines 46 to 63 in afde175
    def apply_static_quant(observed_linear):
        target_dtype = torch.uint8
        # weight quantization
        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
        def weight_quant_func(weight):
            block_size = (1, weight.shape[1])
            return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
        linear.weight = observed_linear.weight
        linear.bias = observed_linear.bias
        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)
        # activation quantization
        act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
        input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
        linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)
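To make the role of the equalization_scale tensor concrete, here is a sketch of the underlying arithmetic in plain PyTorch, without AffineQuantizedTensor. It assumes the scale formula from the SmoothQuant paper, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha). The helper names `smoothquant_scale` and `quantize_int8_per_channel` and the default `alpha=0.5` are choices made for this illustration.

```python
import torch


def smoothquant_scale(x_abs_max, weight, alpha=0.5, eps=1e-5):
    """Per-input-channel equalization scale (hypothetical helper).

    s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), as in the SmoothQuant paper.
    """
    w_abs_max = weight.abs().amax(dim=0)  # weight is (out_features, in_features)
    return x_abs_max.clamp(min=eps).pow(alpha) / w_abs_max.clamp(min=eps).pow(1 - alpha)


def quantize_int8_per_channel(weight):
    """Symmetric int8 quantization with one scale per output channel (hypothetical helper)."""
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return q, scale


torch.manual_seed(0)
X = torch.randn(32, 16)
X[:, 3] *= 50.0  # simulate an activation outlier channel
W = torch.randn(8, 16)

s = smoothquant_scale(X.abs().amax(dim=0), W)
X_smooth, W_smooth = X / s, W * s

# Equalization alone leaves the floating point result unchanged.
y_ref = X @ W.t()
y_eq = X_smooth @ W_smooth.t()

# The scale must be stored next to the int8 weight, because the input
# has to be divided by it at inference time.
W_q, W_scale = quantize_int8_per_channel(W_smooth)
y_q = X_smooth @ (W_q.to(torch.float32) * W_scale).t()
```

In the proposed design, `s` would live in the new layout as equalization_scale, so that the input division and the int8 matmul can be dispatched together.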
For the model level API, we can implement a helper function like:
ao/torchao/quantization/quant_api.py
Line 363 in afde175
    def int4_weight_only(group_size=128, inner_k_tiles=8):
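One possible shape for such a helper, sketched as a factory that captures the configuration and returns a per-module conversion function. Everything here is hypothetical: the name `smooth_quant`, the `alpha` parameter, and the `x_running_abs_max` attribute on the observed module are assumptions. To stay self-contained, the sketch applies the input scaling with a forward pre-hook instead of a tensor subclass and omits the int8 conversion.

```python
import torch


def smooth_quant(alpha=0.5):
    """Hypothetical factory: returns a function that converts one observed linear."""

    def apply_smooth_quant(observed_linear):
        weight = observed_linear.weight.detach()
        x_abs_max = observed_linear.x_running_abs_max  # assumed observer statistic
        w_abs_max = weight.abs().amax(dim=0)
        scale = x_abs_max.clamp(min=1e-5).pow(alpha) / w_abs_max.clamp(min=1e-5).pow(1 - alpha)

        linear = torch.nn.Linear(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.bias is not None,
            device=weight.device,
            dtype=weight.dtype,
        )
        linear.weight = torch.nn.Parameter(weight * scale, requires_grad=False)
        linear.bias = observed_linear.bias
        linear.register_buffer("equalization_scale", scale)
        # Stand-in for what the tensor subclass would do at dispatch time:
        # divide the input by the equalization scale before the matmul.
        linear.register_forward_pre_hook(
            lambda mod, args: (args[0] / mod.equalization_scale,)
        )
        return linear

    return apply_smooth_quant


observed = torch.nn.Linear(16, 8)
X = torch.randn(4, 16)
observed.x_running_abs_max = X.abs().amax(dim=0)  # stand-in for calibration

converted = smooth_quant(alpha=0.5)(observed)
with torch.no_grad():
    y_before = observed(X)
    y_after = converted(X)
```

The returned function has the same one-module-in, one-module-out shape as apply_static_quant above, so it could be passed to quantize_ together with a filter for observed linears.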
Logistics (Code Location, Test and Benchmarks)
Please create a smoothquant folder under https://github.com/pytorch/ao/tree/main/torchao/prototype
The flow and layout implementation can be in separate files, e.g. flow.py, layout.py (there might be some missing extension points of AffineQuantizedTensor, but we'll work on these at the same time)
For testing, please create a test_smoothquant.py in https://github.com/pytorch/ao/tree/main/test/prototype and move the tests from
ao/test/integration/test_integration.py
Line 159 in afde175
    class SmoothquantUnitTest(unittest.TestCase):
For the e2e flow demo, please add a smoothquant.py in https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow, following the static quant example.
Please show the benchmarking results as well (since we are using optimized kernels), following https://github.com/pytorch/ao/tree/main/torchao/quantization#quantization-flow-example
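As a rough illustration of how the benchmarking numbers could be collected, the sketch below times a forward pass with `torch.utils.benchmark`. This is a generic stand-in and not the helper used in the linked README example. The function name `benchmark_ms` and the toy model are assumptions.

```python
import torch
from torch.utils import benchmark


def benchmark_ms(model, example_inputs, min_run_time=0.2):
    """Median forward latency in milliseconds (hypothetical helper)."""
    timer = benchmark.Timer(
        stmt="model(*example_inputs)",
        globals={"model": model, "example_inputs": example_inputs},
    )
    # Measurement.median is reported in seconds.
    return timer.blocked_autorange(min_run_time=min_run_time).median * 1e3


model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
example_inputs = (torch.randn(8, 64),)
with torch.no_grad():
    baseline_ms = benchmark_ms(model, example_inputs)
print(f"baseline: {baseline_ms:.3f} ms")
```

Running the same function on the model before and after quantization gives a directly comparable pair of latencies for the tutorial.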
The last step is to test this with llama2/llama3, following the instructions in https://github.com/pytorch/ao/tree/main/torchao/_models/llama, and to measure the metrics in https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks if you have GPU machines. For smoothquant, you can also test on CPU machines and add the results to the quantization README.
References
- General tensor subclass based API doc: [RFC] torchao Contributor Guide #391
- smoothquant paper: https://arxiv.org/abs/2211.10438