
Commit 5eddaa3

Add static quantization as an example for calibration flow
Summary:

So far the quantization flow API that we provide (`quantize_`) does not require calibration (calibrating a model with sample data). This PR adds a static quantization example that demonstrates the calibration flow:

1. prepare the model for calibration
2. calibrate the prepared model with sample data
3. convert the calibrated model to a quantized model

Test Plan:
python torchao/prototype/calibration_flow/static_quant.py
1 parent d1e15b4 commit 5eddaa3
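
In code, the three steps map onto helpers this commit defines in static_quant.py below; a minimal sketch (`ToyLinearModel`, `to_calibrating_`, and `to_quantized_` all come from that file):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
# ToyLinearModel, to_calibrating_, to_quantized_ are defined in static_quant.py below

model = ToyLinearModel().eval()
data = model.example_inputs()

# 1. prepare: swap each nn.Linear for a CalibratingLinear with attached observers
to_calibrating_(
    model,
    MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine),
    PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine),
)

# 2. calibrate: run representative data through the prepared model
with torch.no_grad():
    model(*data)

# 3. convert: freeze the observed qparams into QuantizedLinear modules
to_quantized_(model)
```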

3 files changed: +153 −1 lines changed

torchao/dtypes/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,12 +1,13 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized, to_affine_quantized_static

 __all__ = [
     "NF4Tensor",
     "to_nf4",
     "UInt4Tensor",
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "to_affine_quantized_static",
 ]
```
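
The newly exported `to_affine_quantized_static` takes precomputed scale/zero_point instead of deriving them from the data. A minimal sketch of a direct call; the hand-computed qparams here stand in for what an observer's `calculate_qparams()` would produce in the calibration flow:

```python
import torch
from torchao.dtypes import to_affine_quantized_static

w = torch.randn(32, 64)

# One (scale, zero_point) pair per output row, computed by hand for this sketch.
w_min = w.amin(dim=1, keepdim=True)
w_max = w.amax(dim=1, keepdim=True)
scale = (w_max - w_min) / 255.0
zero_point = (-w_min / scale).round().clamp(0, 255).to(torch.int64)

# block_size (1, 64): each row of the (32, 64) tensor is one quantization block,
# so the qparam shapes above must provide one pair per row.
qw = to_affine_quantized_static(w, scale, zero_point, (1, 64), torch.uint8)
print(qw.shape, qw.dtype)  # external dtype stays float; data is stored as uint8
```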

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -228,6 +228,49 @@ def from_float(
             dtype=input_float.dtype
         )

+    @classmethod
+    def from_float_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        extended_layout: str = "plain",
+        # TODO: this is only for "tensor_core_tiled", need to figure out
+        # the proper API for this arg
+        inner_k_tiles: Optional[int] = None,
+    ):
+        original_shape = input_float.shape
+        if extended_layout == "tensor_core_tiled":
+            orig_out_features, orig_in_features = input_float.shape
+            in_features = find_multiple(orig_in_features, 1024)
+            out_features = find_multiple(orig_out_features, 8)
+            input_float = torch.nn.functional.pad(
+                input_float,
+                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+            )
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
+        # TODO: this is temporary, need to come up with the proper UX
+        if extended_layout == "tensor_core_tiled":
+            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
+        else:
+            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        return cls(
+            layout_tensor,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
     @property
     def extended_layout(self) -> str:
         return self.layout_tensor.extended_layout

@@ -764,3 +807,4 @@ def t(func, *args, **kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)

 to_affine_quantized = AffineQuantizedTensor.from_float
+to_affine_quantized_static = AffineQuantizedTensor.from_float_static
```
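
`from_float_static` shares everything downstream of qparam selection (`quantize_affine`, layout construction) with `from_float`; the only difference is that scale and zero_point are arguments rather than derived from `input_float`. A sketch of that relationship, assuming `choose_qparams_affine` and `MappingType` from `torchao.quantization.quant_primitives` (the helpers `from_float` uses internally) are available at this commit:

```python
import torch
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

x = torch.randn(16, 32)
block_size = (1, 32)  # one block per row, i.e. per-channel qparams

# Compute qparams the same way from_float would, then feed them back in:
# the static path should reproduce the dynamic path's result.
scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.uint8)
q_static = AffineQuantizedTensor.from_float_static(x, scale, zero_point, block_size, torch.uint8)
q_dynamic = AffineQuantizedTensor.from_float(x, MappingType.ASYMMETRIC, block_size, torch.uint8)
assert torch.equal(q_static.dequantize(), q_dynamic.dequantize())
```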
torchao/prototype/calibration_flow/static_quant.py (new file; path taken from the Test Plan above)

Lines changed: 107 additions & 0 deletions
```python
"""
Demo for static quantization flow
"""
import copy

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

from torchao.dtypes import to_affine_quantized_static
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.utils import compute_error


class CalibratingLinear(torch.nn.Linear):
    """Step 1: a Linear that records activation/weight statistics on every forward."""
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: Tensor):
        # observers update their min/max statistics and return the input unchanged
        calibrating_input = self.act_obs(input)
        calibrating_weight = self.weight_obs(self.weight)
        return F.linear(calibrating_input, calibrating_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        calibrating_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
        calibrating_linear.weight = float_linear.weight
        calibrating_linear.bias = float_linear.bias
        return calibrating_linear


class QuantizedLinear(torch.nn.Module):
    """Step 3: a Linear whose weight and activations are statically quantized
    with the qparams collected during calibration."""
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        # per-channel weight quantization: one block per output row
        block_size = (1, weight.shape[1])
        target_dtype = torch.uint8
        self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
        self.bias = bias

    def forward(self, input: Tensor):
        # per-tensor activation quantization: the whole input is a single block
        block_size = input.shape
        target_dtype = torch.uint8
        qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_calibrating(cls, calibrating_linear):
        return cls(calibrating_linear.in_features, calibrating_linear.out_features, calibrating_linear.act_obs, calibrating_linear.weight_obs, calibrating_linear.weight, calibrating_linear.bias)


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def to_calibrating_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    # give every linear its own observer copies so statistics are not shared
    # across layers
    replacement_fn = lambda m: CalibratingLinear.from_float(m, copy.deepcopy(act_obs), copy.deepcopy(weight_obs))
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)


def to_quantized_(model):
    _is_calibrating_linear = lambda m, fqn: isinstance(m, CalibratingLinear)
    replacement_fn = lambda m: QuantizedLinear.from_calibrating(m)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_calibrating_linear)


dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode="max-autotune")

act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")

before_quant = m(*example_inputs)

# step 1: prepare the model for calibration
to_calibrating_(m, act_obs, weight_obs)
# step 2: calibrate with sample data
for _ in range(10):
    m(*example_inputs)

after_obs = m(*example_inputs)
# step 3: convert the calibrated model to a quantized model
to_quantized_(m)

print("quantized model:", m)
after_quant = m(*example_inputs)
# SQNR > 30 dB: the quantized output should stay close to the bf16 baseline
assert compute_error(before_quant, after_quant) > 30
print("test passed")
```
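
The final assertion uses `compute_error` from `torchao.quantization.utils`, which reports the signal-to-quantized-noise ratio (SQNR) in dB; 30 dB means the quantization noise norm is roughly 3% of the signal norm. A minimal reference implementation of that metric, sketching the standard SQNR definition rather than torchao's exact code:

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||ref|| / ||ref - test||); higher means smaller quantization error
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))

# The 30 dB threshold corresponds to 10 ** (-30 / 20) ≈ 0.032,
# i.e. about 3% relative error in the quantized output.
```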
