@@ -18,7 +18,9 @@
"""Torch.nn.Module Class Definition."""
# Note: Do not import this file unless you have already imported torch,
# since the model classes inherit torch.nn.Module.
+import math
import torch
+from torch.nn import functional as F
from packaging.version import Version


@@ -146,3 +148,156 @@ def _wrapper_qdq_linear(tmp_model, module_name_list=[]):
        new_module = QDQLinear(module)
        set_module(tmp_model, name, new_module)
    return tmp_model
+
+
+class WeightOnlyLinear(torch.nn.Module):
+    """Linear layer whose low-bit integer weights are packed into int32 buffers.
+
+    The fp32 weight is recovered on the fly in forward() from the packed weight,
+    the per-group scales, and (for asymmetric quantization) the packed zero points.
+    pack() must be called before the module is used.
+    """
+
+    def __init__(self, in_features, out_features, bits, groupsize):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bits = bits
+        # groupsize == -1 means one group spanning the whole input dimension (per-channel scale)
+        self.groupsize = groupsize if groupsize != -1 else in_features
+        # number of low-bit values packed into one int32 word
+        self.n_pack = 32 // self.bits
+
+        self.register_buffer(
+            'packed_weight',
+            torch.zeros(
+                (out_features, math.ceil(in_features / self.n_pack)),
+                dtype=torch.int32,
+            )
+        )
+        self.register_buffer(
+            'scale',
+            torch.zeros(
+                (out_features, math.ceil(in_features / self.groupsize)),
+                dtype=torch.float,
+            )
+        )
+
+    def pack(self, int_weight, scale, zp, bias):
+        if bias is not None:
+            self.register_buffer('bias', bias)
+        else:
+            self.bias = None
+        assert scale.shape == self.scale.shape, "Scale shape is mismatched."
+        self.scale = scale
+        origin_shape = int_weight.shape
+        target_shape = self.packed_weight.shape
+        assert origin_shape[0] == target_shape[0], "Output channels mismatch, please check."
+        mask = torch.tensor(2**self.bits - 1, dtype=torch.int32)
+
+        # pack weight: each int32 word holds n_pack values, most significant slot first
+        for i in range(target_shape[0]):
+            for j in range(target_shape[1]):
+                start = self.n_pack * j
+                end = self.n_pack * (j + 1)
+                tmp = int_weight[i][start: end].type(torch.int32)
+                for e in range(len(tmp)):
+                    tmp[e] &= mask
+                    tmp[e] = tmp[e] << self.bits * (self.n_pack - 1 - e)
+                    self.packed_weight[i][j] |= tmp[e]
+
+        if zp is not None:
+            # pack zero_points
+            self.register_buffer(
+                'packed_zp',
+                torch.zeros(
+                    (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
+                    dtype=torch.int32,
+                )
+            )
+            target_shape = self.packed_zp.shape
+            for i in range(target_shape[0]):
+                for j in range(target_shape[1]):
+                    start = self.n_pack * j
+                    end = self.n_pack * (j + 1)
+                    tmp = zp[i][start: end].type(torch.int32)
+                    for e in range(len(tmp)):
+                        tmp[e] &= mask
+                        tmp[e] = tmp[e] << self.bits * (self.n_pack - 1 - e)
+                        self.packed_zp[i][j] |= tmp[e]
+
+    def recover(self):
+        mask = torch.tensor(2**self.bits - 1, dtype=torch.int32)
+        if hasattr(self, 'packed_zp'):
+            weight_dtype = torch.uint8
+        else:
+            weight_dtype = torch.int8
+        # unpack weight
+        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype)
+        origin_shape = weight.shape
+        target_shape = self.packed_weight.shape
+        for i in range(target_shape[0]):
+            for j in range(target_shape[1]):
+                for e in range(self.n_pack):
+                    index = j * self.n_pack + e
+                    if index >= origin_shape[1]:
+                        continue
+                    tmp = self.packed_weight[i][j]
+                    # move the target slot to the top bits, then arithmetic-shift it
+                    # back down so signed (int8) values are sign-extended
+                    tmp = tmp << 32 - self.bits * (self.n_pack - e)
+                    tmp = tmp >> 32 - self.bits
+                    if weight_dtype == torch.uint8:
+                        tmp &= mask  # remove sign bits
+                    weight[i][index] = tmp.type(weight_dtype)
+        # unpack zero_point
+        if hasattr(self, 'packed_zp'):
+            zp_dtype = torch.int32  # to avoid overflow when computing weight - zp
+            zp = torch.zeros(self.scale.shape, dtype=zp_dtype)
+            origin_shape = zp.shape
+            target_shape = self.packed_zp.shape
+            for i in range(target_shape[0]):
+                for j in range(target_shape[1]):
+                    for e in range(self.n_pack):
+                        index = j * self.n_pack + e
+                        if index >= origin_shape[1]:
+                            continue
+                        tmp = self.packed_zp[i][j]
+                        tmp = tmp << 32 - self.bits * (self.n_pack - e)
+                        tmp = tmp >> 32 - self.bits
+                        tmp &= mask
+                        zp[i][index] = tmp.type(zp_dtype)
+            # recover fp32 weight from int_weight, scale, and zero_point
+            left_element = self.in_features % self.groupsize
+            if left_element != 0:
+                # the last group is smaller than groupsize and is dequantized separately
+                split_index = self.in_features // self.groupsize * self.groupsize
+                weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
+                scale1 = self.scale[:, :-1].reshape(-1, 1)
+                zp1 = zp[:, :-1].reshape(-1, 1)
+                weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1)
+                weight2 = weight[:, split_index:]
+                scale2 = self.scale[:, -1:]
+                zp2 = zp[:, -1:]
+                weight2 = (weight2 - zp2) * scale2
+                fp32_weight = torch.cat((weight1, weight2), dim=1)
+            else:
+                weight = weight.reshape(-1, self.groupsize)
+                scale = self.scale.reshape(-1, 1)
+                zp = zp.reshape(-1, 1)
+                fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1)
+        else:
+            # recover fp32 weight from int_weight and scale (symmetric quantization)
+            left_element = self.in_features % self.groupsize
+            if left_element != 0:
+                split_index = self.in_features // self.groupsize * self.groupsize
+                weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
+                scale1 = self.scale[:, :-1].reshape(-1, 1)
+                weight1 = (weight1 * scale1).reshape(self.out_features, -1)
+                weight2 = weight[:, split_index:]
+                scale2 = self.scale[:, -1:]
+                weight2 = weight2 * scale2
+                fp32_weight = torch.cat((weight1, weight2), dim=1)
+            else:
+                weight = weight.reshape(-1, self.groupsize)
+                scale = self.scale.reshape(-1, 1)
+                fp32_weight = (weight * scale).reshape(self.out_features, -1)
+        return fp32_weight
+
+    def forward(self, input):
+        weight = self.recover()
+        return F.linear(input, weight, self.bias)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bits={}, group_size={}, bias={}'.format(
+            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
+        )
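
Below is a minimal usage sketch (editor's illustration, not part of the patch) of how pack() and forward() fit together for 4-bit asymmetric per-group quantization. The layer sizes, group size, and the min/max quantization code are assumptions chosen so the shapes line up with what pack() expects; the recovered weight from forward() is checked against an explicit dequantization.

import math
import torch
from torch.nn import functional as F

# Illustrative sizes; in_features is a multiple of group_size to keep the sketch short.
out_features, in_features, bits, group_size = 16, 64, 4, 32
n_groups = in_features // group_size

fp32_weight = torch.randn(out_features, in_features)

# Per-group asymmetric min/max quantization to unsigned 4-bit values (0..15).
grouped = fp32_weight.reshape(out_features * n_groups, group_size)
w_min = grouped.min(dim=1, keepdim=True).values
w_max = grouped.max(dim=1, keepdim=True).values
scale = ((w_max - w_min) / (2**bits - 1)).clamp(min=1e-9)
zp = torch.clamp(torch.round(-w_min / scale), 0, 2**bits - 1)
int_weight = torch.clamp(torch.round(grouped / scale) + zp, 0, 2**bits - 1)

# Layouts expected by pack(): int_weight is (out, in); scale and zp are (out, n_groups).
int_weight = int_weight.reshape(out_features, in_features).to(torch.uint8)
scale = scale.reshape(out_features, n_groups)
zp = zp.reshape(out_features, n_groups).to(torch.uint8)

module = WeightOnlyLinear(in_features, out_features, bits, group_size)
module.pack(int_weight, scale, zp, bias=None)

# forward() dequantizes on the fly; compare it with an explicit (w - zp) * scale reference.
x = torch.randn(2, in_features)
deq = (int_weight.float() - zp.float().repeat_interleave(group_size, dim=1)) \
    * scale.repeat_interleave(group_size, dim=1)
assert torch.allclose(module(x), F.linear(x, deq), atol=1e-4)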
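
The per-element Python loops in pack() are easy to follow but touch every weight individually. A loop-free packing routine with the same bit layout (most significant slot first, low bits masked) might look like the sketch below; it is the editor's illustration, not part of the patch, and it assumes torch's int32 left shift wraps into the sign bit the same way the loop above relies on. pack_int32(int_weight, bits) should then match module.packed_weight produced by pack().

import math
import torch

def pack_int32(int_weight, bits):
    """Pack low-bit integer values row-wise into int32 words, most significant slot first."""
    n_pack = 32 // bits
    out_features, in_features = int_weight.shape
    n_words = math.ceil(in_features / n_pack)
    # Zero-pad the input dimension so every word gets exactly n_pack slots.
    padded = torch.zeros(out_features, n_words * n_pack, dtype=torch.int32)
    padded[:, :in_features] = int_weight.to(torch.int32) & (2**bits - 1)
    padded = padded.reshape(out_features, n_words, n_pack)
    packed = torch.zeros(out_features, n_words, dtype=torch.int32)
    for e in range(n_pack):  # n_pack iterations instead of one per weight element
        packed |= padded[:, :, e] << (bits * (n_pack - 1 - e))
    return packed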