Commit e4f3e74
Add static quantization as an example for calibration flow
Summary:
So far, the quantization flow API we provide (`quantize_`) does not require calibration (calibrating a model with sample data). This PR adds a static quantization example that serves as an example of the calibration flow:
1. first prepare the model for calibration
2. calibrate the prepared model with sample data
3. convert the calibrated model to a quantized model

Test Plan:
python torchao/prototype/calibration_flow/static_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent d1e15b4 commit e4f3e74
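
For orientation, the three steps above map onto the helpers defined in the new demo file further down. A minimal sketch of the flow, assuming a `model` and some `sample_inputs` exist and using the `to_calibrating_`/`to_quantized_` helpers and observers from that demo:

    # Sketch only; helper names come from the demo file added in this commit.
    to_calibrating_(model, act_obs, weight_obs)   # 1. prepare: swap nn.Linear -> CalibratingLinear
    for inputs in sample_inputs:                  # 2. calibrate: observers record value ranges
        model(*inputs)
    to_quantized_(model)                          # 3. convert: swap CalibratingLinear -> QuantizedLinear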

File tree

torchao/dtypes/__init__.py
torchao/dtypes/affine_quantized_tensor.py
torchao/prototype/calibration_flow/static_quant.py

3 files changed: +156 −1 lines

torchao/dtypes/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,12 +1,13 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized, to_affine_quantized_static

 __all__ = [
     "NF4Tensor",
     "to_nf4",
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "to_affine_quantized_static",
 ]
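
As a quick illustration of the constructor exported here (not part of the diff): `to_affine_quantized_static` takes precomputed scale/zero_point values rather than deriving them from the input, and a `block_size` equal to the full tensor shape yields per-tensor quantization, which is how the demo below handles activations. The scale and zero_point values are made up for illustration:

    import torch
    from torchao.dtypes import to_affine_quantized_static

    x = torch.randn(16, 64)
    scale = torch.tensor([0.02])      # illustrative precomputed qparams, e.g. from an observer
    zero_point = torch.tensor([128])
    # block_size == x.shape -> a single (scale, zero_point) pair for the whole tensor
    qx = to_affine_quantized_static(x, scale, zero_point, x.shape, torch.uint8)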

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 44 additions & 0 deletions
@@ -228,6 +228,49 @@ def from_float(
             dtype=input_float.dtype
         )

+    @classmethod
+    def from_float_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        extended_layout: str = "plain",
+        # TODO: this is only for "tensor_core_tiled", need to figure out
+        # the proper API for this arg
+        inner_k_tiles: Optional[int] = None,
+    ):
+        original_shape = input_float.shape
+        if extended_layout == "tensor_core_tiled":
+            orig_out_features, orig_in_features = input_float.shape
+            in_features = find_multiple(orig_in_features, 1024)
+            out_features = find_multiple(orig_out_features, 8)
+            input_float = torch.nn.functional.pad(
+                input_float,
+                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+            )
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
+        # TODO: this is temporary, need to come up with the proper UX
+        if extended_layout == "tensor_core_tiled":
+            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
+        else:
+            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        return cls(
+            layout_tensor,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
     @property
     def extended_layout(self) -> str:
         return self.layout_tensor.extended_layout
@@ -764,3 +807,4 @@ def t(func, *args, **kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)

 to_affine_quantized = AffineQuantizedTensor.from_float
+to_affine_quantized_static = AffineQuantizedTensor.from_float_static
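
For context on how `from_float_static` is meant to be fed (a hedged sketch, not part of the diff): the scale/zero_point typically come from an observer's `calculate_qparams()`, and `block_size=(1, in_features)` gives one quantization parameter pair per output channel, matching how the demo below quantizes weights. Shapes and dtypes here are illustrative:

    import torch
    from torch.ao.quantization.observer import PerChannelMinMaxObserver
    from torchao.dtypes import to_affine_quantized_static

    weight = torch.randn(32, 64)
    obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
    obs(weight)                                  # record per-channel min/max ("calibration")
    scale, zero_point = obs.calculate_qparams()  # one (scale, zero_point) per output channel
    # block_size=(1, weight.shape[1]) -> per-output-channel quantization
    qweight = to_affine_quantized_static(weight, scale, zero_point, (1, weight.shape[1]), torch.uint8)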
torchao/prototype/calibration_flow/static_quant.py (new file)

Lines changed: 110 additions & 0 deletions
"""
Demo for static quantization flow
"""
import torch
import copy

# TODO: use the generalized observer for affine quantization in the future
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
import torch.nn.functional as F
from torch import Tensor
from torchao.dtypes import to_affine_quantized_static
from torchao.quantization.utils import compute_error


class CalibratingLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: Tensor):
        calibrating_input = self.act_obs(input)
        calibrating_weight = self.weight_obs(self.weight)
        return F.linear(calibrating_input, calibrating_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        calibrating_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
        calibrating_linear.weight = float_linear.weight
        calibrating_linear.bias = float_linear.bias
        return calibrating_linear

class QuantizedLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
        target_dtype = torch.uint8
        self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
        self.bias = bias

    def forward(self, input: Tensor):
        block_size = input.shape
        target_dtype = torch.uint8
        qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_calibrating(cls, calibrating_linear):
        quantized_linear = cls(calibrating_linear.in_features, calibrating_linear.out_features, calibrating_linear.act_obs, calibrating_linear.weight_obs, calibrating_linear.weight, calibrating_linear.bias)
        return quantized_linear

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
import copy

def to_calibrating_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    replacement_fn = lambda m: CalibratingLinear.from_float(m, act_obs, weight_obs)
    act_obs = copy.deepcopy(act_obs)
    weight_obs = copy.deepcopy(weight_obs)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

def to_quantized_(model):
    _is_calibrating_linear = lambda m, fqn: isinstance(m, CalibratingLinear)

    replacement_fn = lambda m: QuantizedLinear.from_calibrating(m)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_calibrating_linear)


# TODO: use the generalized observer for affine quantization in the future
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")

before_quant = m(*example_inputs)

to_calibrating_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
    m(*example_inputs)

after_obs = m(*example_inputs)
to_quantized_(m)

print("quantized model:", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")
