Commit a753e3f
Adding GPU quantization workflows and APIs
Summary: APIs and workflows used for quantization and pruning in the segment-anything-fast and gpt-fast repos.

Test Plan: python /home/cdhernandez/local/ao/ao/quantization/test.py

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 31191a7
Pull Request resolved: #1
1 parent 7b3330c commit a753e3f

17 files changed: +2132 additions, -0 deletions
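For orientation, the three module-swapping entry points added in quant_api.py each rewrite a model in place by replacing its `nn.Linear` modules. A minimal usage sketch with a toy model, assuming ao/quantization is on sys.path as the flat imports in this commit suggest; whether the int8 kernels run on a given device depends on quant_primitives.py, which this page does not show:

import torch
from quant_api import (
    apply_weight_only_int8_quant,
    apply_dynamic_quant,
    change_linear_weights_to_dqtensors,
)

# Toy model; any module tree containing torch.nn.Linear works.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

# Pick one of the three workflows; each mutates `model` in place.
apply_weight_only_int8_quant(model)          # int8 weights, float activations
# apply_dynamic_quant(model)                 # int8 weights + dynamic per-token activation quant
# change_linear_weights_to_dqtensors(model)  # same math via a tensor subclass on .weight

out = model(torch.randn(2, 16))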

ao/quantization/__init__.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from smoothquant import *  # noqa: F403
from quant_api import *  # noqa: F403
from subclass import *  # noqa: F403
from quant_primitives import *  # noqa: F403
from utils import *  # noqa: F403
from weight_only import *  # noqa: F403

__all__ = [
    "DynamicallyPerAxisQuantizedLinear",
    "replace_with_custom_fn_if_matches_filter",
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_dqtensors",
    "insert_subclass",
    "safe_int_mm",
    "dynamically_quantize_per_tensor",
    "quantize_activation_per_token_absmax",
    "dynamically_quantize_per_channel",
    "dequantize_per_tensor",
    "dequantize_per_channel",
    "quant_int8_dynamic_linear",
    "quant_int8_matmul",
    "quant_int8_dynamic_per_token_linear",
    "quant_int8_per_token_matmul",
    "get_scale",
    "SmoothFakeDynQuantMixin",
    "SmoothFakeDynamicallyQuantizedLinear",
    "swap_linear_with_smooth_fq_linear",
    "smooth_fq_linear_to_inference",
    "set_smooth_fq_attribute",
    "DynamicallyQuantizedLinearWeight",
    "log_with_rank",
    "clear_logs",
    "compute_error",
    "forward_hook",
    "apply_logging_hook",
    "get_model_size_in_bytes",
    "WeightOnlyInt8QuantLinear",
]
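Several of the exported names outline a SmoothQuant-style flow: swap the linears, run calibration batches through the model, then freeze it for inference. A hedged sketch using only names from `__all__`; the toy model, the calibration batches, and the exact calibration protocol are assumptions, since smoothquant.py is not shown on this page:

import torch
from ao.quantization import (
    swap_linear_with_smooth_fq_linear,
    smooth_fq_linear_to_inference,
)

# Hypothetical float model and calibration data.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
)
calibration_batches = [torch.randn(4, 32) for _ in range(8)]

swap_linear_with_smooth_fq_linear(model)  # insert SmoothFakeDynamicallyQuantizedLinear modules
for x in calibration_batches:
    model(x)                              # forward passes observe activation statistics
smooth_fq_linear_to_inference(model)      # freeze scales for the int8 inference path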
Seven binary files changed (2.8 KB, 2.39 KB, 5.99 KB, 6.76 KB, 3.85 KB, 2.73 KB, 1.55 KB); binary contents not shown.

ao/quantization/dynamic_quant.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn

from quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
)

__all__ = ["DynamicallyPerAxisQuantizedLinear"]


class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    This class is a replacement for `torch.nn.Linear`, implementing dynamic
    quantization of the input across all axes except the last axis.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_fused_int_mm: bool = False,
    ) -> None:
        super().__init__(in_features, out_features, bias)
        self.use_fused_int_mm = use_fused_int_mm
        # note: enabling use_fused_int_mm = True has best perf when additionally setting
        # torch._inductor.config.force_fuse_int_mm_with_mul = True

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the quantized linear layer.

        This method applies dynamic quantization to the input tensor across all
        axes except the last axis, using the
        `quant_int8_dynamic_per_token_linear` function.

        Args:
            X (torch.Tensor): The input tensor to the quantized linear layer.

        Returns:
            torch.Tensor: The output tensor after the quantized matmul and rescale.
        """
        # The following line mimics the behavior of SmoothFakeDynamicallyQuantizedLinear.
        if not self.use_fused_int_mm:
            X = X / self.fake_rescale
            # For most transformer models, the inductor fusion that occurs when
            # this module has the additional div op is somehow faster than when
            # it doesn't, although memory usage is slightly higher. fake_rescale
            # is the scalar 1, so it doesn't affect accuracy.
        Y = quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
        )
        return Y

    @classmethod
    def from_float(
        cls, mod: torch.nn.Linear, use_fused_int_mm: bool = False
    ) -> "DynamicallyPerAxisQuantizedLinear":
        """
        Converts a `mod` of class `torch.nn.Linear` to the dynamically quantized
        version of it.

        Note: this class does not require calibration.

        Args:
            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.

        Returns:
            DynamicallyPerAxisQuantizedLinear: The converted quantized linear module.
        """
        # Create the new module with a toy size to ensure initialization is fast.
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            use_fused_int_mm=use_fused_int_mm,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel(
            mod.weight, -128, 127, torch.int8
        )
        new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t())
        new_mod.W_scales = nn.Parameter(W_scales)
        new_mod.bias = mod.bias
        if not use_fused_int_mm:
            new_mod.fake_rescale = torch.tensor(
                [1.0], dtype=mod.weight.dtype, device=mod.weight.device
            )
        del new_mod.weight

        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod
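`from_float` leans on `dynamically_quantize_per_channel`, whose implementation lives in quant_primitives.py and is not shown on this page. For orientation, here is a minimal sketch of the symmetric per-channel int8 scheme that the `(-128, 127, torch.int8)` call suggests; the function name and details below are assumptions, not the repo's code:

import torch

def per_channel_int8_quant_sketch(w: torch.Tensor, qmin: int = -128, qmax: int = 127):
    # One scale per output channel (row), symmetric around zero, so the
    # returned zero_points are all zero. This mirrors the (int_repr, scales,
    # zero_points) triple unpacked by from_float above.
    max_abs = w.abs().amax(dim=1)  # per-row max magnitude
    scales = (max_abs / qmax).clamp(min=torch.finfo(w.dtype).eps)  # avoid div by zero
    w_int = torch.clamp(torch.round(w / scales.unsqueeze(1)), qmin, qmax).to(torch.int8)
    zero_points = torch.zeros_like(scales, dtype=torch.int64)
    return w_int, scales, zero_points

Dequantizing with `w_int * scales.unsqueeze(1)` recovers `w` up to rounding error, which is why this weight-quantization step needs no calibration.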

ao/quantization/quant_api.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
"""
Quantization API functions that are not specific to SmoothQuant.

Note: this is throwaway code for fast results on Blueberry; it is not
intended to be the actual long-term quantization API for server GPUs.
"""

import torch
from dynamic_quant import (
    DynamicallyPerAxisQuantizedLinear,
)
from subclass import (
    DynamicallyQuantizedLinearWeight,
)
from weight_only import (
    WeightOnlyInt8QuantLinear,
)

__all__ = [
    "replace_with_custom_fn_if_matches_filter",
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_dqtensors",
]


def replace_with_custom_fn_if_matches_filter(
    model, replacement_fn, filter_fn, cur_fqn=""
) -> None:
    """
    For each `child` in `model`, replaces it with `replacement_fn(child)`
    if `filter_fn(child, fqn)` is `True`.
    """
    name_to_child = dict(model.named_children())
    for name, child in name_to_child.items():
        if cur_fqn == "":
            new_fqn = name
        else:
            new_fqn = f"{cur_fqn}.{name}"
        if filter_fn(child, new_fqn):
            new_child = replacement_fn(child)
            setattr(model, name, new_child)
        else:
            replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, new_fqn
            )


def apply_weight_only_int8_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        WeightOnlyInt8QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )


def apply_dynamic_quant(model, use_fused_int_mm: bool = False):
    replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod, use_fused_int_mm),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )


def change_linear_weights_to_dqtensors(model):
    def insert_subclass(lin):
        lin.weight = torch.nn.Parameter(
            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
        )
        return lin

    replace_with_custom_fn_if_matches_filter(
        model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
    )
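The filter receives both the module and its fully qualified name, so replacements can be scoped to a subtree rather than applied model-wide. A small sketch using the functions above; the toy model and the fqn prefix are illustrative assumptions:

import torch
from quant_api import replace_with_custom_fn_if_matches_filter
from weight_only import WeightOnlyInt8QuantLinear

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),                   # child "0": stays float
    torch.nn.Sequential(torch.nn.Linear(16, 4)),  # child "1": gets quantized
)

# Quantize only the linears living under the second submodule ("1.*").
replace_with_custom_fn_if_matches_filter(
    model,
    WeightOnlyInt8QuantLinear.from_float,
    lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn.startswith("1"),
)

Because the recursion only descends into children that the filter rejects, a replaced module's own submodules are never revisited; that is what lets `apply_dynamic_quant` and friends swap whole `nn.Linear` modules safely.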
