|
| 1 | +""" |
| 2 | +Demo for static quantization flow |
| 3 | +""" |
| 4 | +import torch |
| 5 | +import copy |
| 6 | + |
| 7 | +from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver |
| 8 | +import torch.nn.functional as F |
| 9 | +from torch import Tensor |
| 10 | +from torchao.dtypes import to_affine_quantized_static |
| 11 | +from torchao.quantization.utils import compute_error |
| 12 | + |
| 13 | + |
| 14 | +class CalibratingLinear(torch.nn.Linear): |
| 15 | + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): |
| 16 | + super().__init__(in_features, out_features, bias, device, dtype) |
| 17 | + self.act_obs = act_obs |
| 18 | + self.weight_obs = weight_obs |
| 19 | + |
| 20 | + def forward(self, input: Tensor): |
| 21 | + calibrating_input = self.act_obs(input) |
| 22 | + calibrating_weight = self.weight_obs(self.weight) |
| 23 | + return F.linear(calibrating_input, calibrating_weight, self.bias) |
| 24 | + |
| 25 | + @classmethod |
| 26 | + def from_float(cls, float_linear, act_obs, weight_obs): |
| 27 | + calibrating_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) |
| 28 | + calibrating_linear.weight = float_linear.weight |
| 29 | + calibrating_linear.bias = float_linear.bias |
| 30 | + return calibrating_linear |
| 31 | + |
| 32 | +class QuantizedLinear(torch.nn.Module): |
| 33 | + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor): |
| 34 | + super().__init__() |
| 35 | + self.act_scale, self.act_zero_point = act_obs.calculate_qparams() |
| 36 | + weight_scale, weight_zero_point = weight_obs.calculate_qparams() |
| 37 | + assert weight.dim() == 2 |
| 38 | + block_size = (1, weight.shape[1]) |
| 39 | + target_dtype = torch.uint8 |
| 40 | + self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype) |
| 41 | + self.bias = bias |
| 42 | + |
| 43 | + def forward(self, input: Tensor): |
| 44 | + block_size = input.shape |
| 45 | + target_dtype = torch.uint8 |
| 46 | + qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype) |
| 47 | + return F.linear(qinput, self.qweight, self.bias) |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def from_calibrating(cls, calibrating_linear): |
| 51 | + quantized_linear = cls(calibrating_linear.in_features, calibrating_linear.out_features, calibrating_linear.act_obs, calibrating_linear.weight_obs, calibrating_linear.weight, calibrating_linear.bias) |
| 52 | + return quantized_linear |
| 53 | + |
| 54 | +class ToyLinearModel(torch.nn.Module): |
| 55 | + def __init__(self, m=64, n=32, k=64): |
| 56 | + super().__init__() |
| 57 | + self.linear1 = torch.nn.Linear(m, n, bias=False) |
| 58 | + self.linear2 = torch.nn.Linear(n, k, bias=False) |
| 59 | + |
| 60 | + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): |
| 61 | + return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) |
| 62 | + |
| 63 | + def forward(self, x): |
| 64 | + x = self.linear1(x) |
| 65 | + x = self.linear2(x) |
| 66 | + return x |
| 67 | + |
| 68 | +dtype = torch.bfloat16 |
| 69 | +m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") |
| 70 | +m_bf16 = copy.deepcopy(m) |
| 71 | +example_inputs = m.example_inputs(dtype=dtype, device="cuda") |
| 72 | + |
| 73 | +m_bf16 = torch.compile(m_bf16, mode='max-autotune') |
| 74 | + |
| 75 | +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter |
| 76 | +import copy |
| 77 | + |
| 78 | +def to_calibrating_(model, act_obs, weight_obs): |
| 79 | + _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) |
| 80 | + replacement_fn = lambda m: CalibratingLinear.from_float(m, act_obs, weight_obs) |
| 81 | + act_obs = copy.deepcopy(act_obs) |
| 82 | + weight_obs = copy.deepcopy(weight_obs) |
| 83 | + _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) |
| 84 | + |
| 85 | +def to_quantized_(model): |
| 86 | + _is_calibrating_linear = lambda m, fqn: isinstance(m, CalibratingLinear) |
| 87 | + |
| 88 | + replacement_fn = lambda m: QuantizedLinear.from_calibrating(m) |
| 89 | + _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_calibrating_linear) |
| 90 | + |
| 91 | +act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda") |
| 92 | +weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda") |
| 93 | + |
| 94 | +before_quant = m(*example_inputs) |
| 95 | + |
| 96 | +to_calibrating_(m, act_obs, weight_obs) |
| 97 | +# calibrating / training |
| 98 | +for _ in range(10): |
| 99 | + m(*example_inputs) |
| 100 | + |
| 101 | +after_obs = m(*example_inputs) |
| 102 | +to_quantized_(m) |
| 103 | + |
| 104 | +print("quantized model:", m) |
| 105 | +after_quant = m(*example_inputs) |
| 106 | +assert compute_error(before_quant, after_quant) > 30 |
| 107 | +print("test passed") |
0 commit comments