Commit 00bbf84

add WeightOnlyLinear for low memory inference (#1076)
Signed-off-by: Xin He <[email protected]>
1 parent: 72f079b

6 files changed: +318 -45 lines


neural_compressor/adaptor/pytorch.py

Lines changed: 10 additions & 9 deletions
@@ -24,6 +24,7 @@
 from packaging.version import Version
 import yaml
 from functools import partial
+from neural_compressor.adaptor.torch_utils.util import set_module
 from neural_compressor.utils.utility import dump_elapsed_time
 from .adaptor import adaptor_registry, Adaptor
 from ..utils.utility import LazyImport, CpuInfo, GLOBAL_STATE, MODE
@@ -4548,7 +4549,8 @@ def rtn_quantize(self, model, tune_cfg):
                 if algorithm != 'RTN':
                     continue
                 m = fetch_module(model, op_name)
-                rtn_quantize(m, num_bits, group_size, scheme)
+                m = rtn_quantize(m, num_bits, group_size, scheme, return_int=False)
+                set_module(model, op_name, m)
         return model

     def gptq_quantize(self, model, tune_cfg, dataloader):
@@ -4591,6 +4593,7 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
                 flipped_dict[m] = {'absorb_layer': k}

         # check tune_cfg to skip layers without AWQ config
+        weight_config = {}
         skipped_op_name_set = set()
         for key, config in tune_cfg['op'].items():
             op_name, op_type = key
@@ -4599,29 +4602,26 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
                     absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer'])
                 continue
             else:
+                weight_config[op_name] = {}
+                weight_config[op_name]['bits'] = config['weight']['bits']
+                weight_config[op_name]['group_size'] = config['weight']['group_size']
+                weight_config[op_name]['scheme'] = config['weight']['scheme']
                 if op_name in flipped_dict:
-                    flipped_dict[op_name]['bits'] = config['weight']['bits']
-                    flipped_dict[op_name]['group_size'] = config['weight']['group_size']
-                    flipped_dict[op_name]['scheme'] = config['weight']['scheme']
                     algorithm = config['weight']['algorithm']
                     if algorithm != 'AWQ':
-                        if op_name in flipped_dict:
-                            absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer'])
+                        absorb_to_layer.pop(weight_config[op_name]['absorb_layer'])
                 else:
                     skipped_op_name_set.add(op_name)
         if skipped_op_name_set:
             logger.info("{} is skipped by AWQ algorithm".format(skipped_op_name_set))

         # collect AWQ config from tune_cfg for quantization.
-        weight_config = {}
         if len(absorb_to_layer) == 0:
             logger.warning('No absorb layer needs AWQ algorithim, skip it')
         else:
             logger.debug("**absorb layer**: **absorbed layers**")
             for k, v in absorb_to_layer.items():
                 logger.debug(f"{k}: {v}")
-                for m in v:
-                    weight_config[m] = flipped_dict[m]
             logger.info("Absorbed layers with the same absorb layer use the same config")

         if 'awq_args' in self.recipes:
@@ -4641,6 +4641,7 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
                 mse_range=mse_range,
                 calib_func=calib_func,
                 n_blocks=n_blocks,
+                return_int=False,
             )
         return model
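
Note: in the RTN path above, the helper is now called with return_int=False and its return value replaces the original layer via set_module instead of the layer being modified in place; the AWQ call gains the same return_int=False argument. A minimal sketch of the fetch/replace pattern the RTN change relies on, assuming sub-modules are addressed by dotted names such as 'decoder.layers.0.fc1'; the real fetch_module/set_module live in neural_compressor.adaptor.torch_utils.util and may be implemented differently.

# Illustrative sketch only, not the library implementation.
import torch

def fetch_module(model: torch.nn.Module, op_name: str) -> torch.nn.Module:
    # Walk a dotted sub-module path and return the leaf module.
    module = model
    for attr in op_name.split('.'):
        module = getattr(module, attr)
    return module

def set_module(model: torch.nn.Module, op_name: str, new_module: torch.nn.Module) -> None:
    # Replace the sub-module at a dotted path with a new module.
    *parents, last = op_name.split('.')
    parent = model
    for attr in parents:
        parent = getattr(parent, attr)
    setattr(parent, last, new_module)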

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 155 additions & 0 deletions
@@ -18,7 +18,9 @@
 """Torch.nn.Module Class Defination."""
 # Note: Do not import this file unless you have already imported torch,
 # since the model classes inherit torch.nn.Module.
+import math
 import torch
+from torch.nn import functional as F
 from packaging.version import Version


@@ -146,3 +148,156 @@ def _wrapper_qdq_linear(tmp_model, module_name_list=[]):
         new_module = QDQLinear(module)
         set_module(tmp_model, name, new_module)
     return tmp_model
+
+
+class WeightOnlyLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features, bits, groupsize):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bits = bits
+        self.groupsize = groupsize if groupsize != -1 else in_features
+        self.n_pack = 32 // self.bits
+
+        self.register_buffer(
+            'packed_weight',
+            torch.zeros(
+                (out_features, math.ceil(in_features / self.n_pack)),
+                dtype=torch.int32,
+            )
+        )
+        self.register_buffer(
+            'scale',
+            torch.zeros(
+                (out_features, math.ceil(in_features / self.groupsize)),
+                dtype=torch.float,
+            )
+        )
+
+    def pack(self, int_weight, scale, zp, bias):
+        if bias is not None:
+            self.register_buffer('bias', torch.zeros(self.out_features, dtype=torch.float))
+        else:
+            self.bias = None
+        self.bias = bias
+        assert scale.shape == self.scale.shape, "Scale shape is mismatched."
+        self.scale = scale
+        origin_shape = int_weight.shape
+        target_shape = self.packed_weight.shape
+        assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
+        mask = torch.tensor(2**self.bits - 1, dtype=torch.int32)
+
+        # pack weight
+        for i in range(target_shape[0]):
+            for j in range(target_shape[1]):
+                start = self.n_pack * j
+                end = self.n_pack * (j + 1)
+                tmp = int_weight[i][start: end].type(torch.int32)
+                for e in range(len(tmp)):
+                    tmp[e] &= mask
+                    tmp[e] = tmp[e] << self.bits * (self.n_pack - 1 - e)
+                    self.packed_weight[i][j] |= tmp[e]
+
+        if zp is not None:
+            # pack zero_points
+            self.register_buffer(
+                'packed_zp',
+                torch.zeros(
+                    (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
+                    dtype=torch.int32,
+                )
+            )
+            target_shape = self.packed_zp.shape
+            for i in range(target_shape[0]):
+                for j in range(target_shape[1]):
+                    start = self.n_pack * j
+                    end = self.n_pack * (j + 1)
+                    tmp = zp[i][start: end].type(torch.int32)
+                    for e in range(len(tmp)):
+                        tmp[e] &= mask
+                        tmp[e] = tmp[e] << self.bits * (self.n_pack - 1 - e)
+                        self.packed_zp[i][j] |= tmp[e]
+
+    def recover(self):
+        mask = torch.tensor(2**self.bits - 1, dtype=torch.int32)
+        if hasattr(self, 'packed_zp'):
+            weight_dtype = torch.uint8
+        else:
+            weight_dtype = torch.int8
+        # unpack weight
+        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype)
+        origin_shape = weight.shape
+        target_shape = self.packed_weight.shape
+        for i in range(target_shape[0]):
+            for j in range(target_shape[1]):
+                for e in range(self.n_pack):
+                    index = j * self.n_pack + e
+                    if index >= origin_shape[1]:
+                        continue
+                    tmp = self.packed_weight[i][j]
+                    tmp = tmp << 32 - self.bits * (self.n_pack - e)
+                    tmp = tmp >> 32 - self.bits
+                    if weight_dtype == torch.uint8:
+                        tmp &= mask  # remove sign bit
+                    weight[i][index] = tmp.type(weight_dtype)
+        # unpack zero_point
+        if hasattr(self, 'packed_zp'):
+            zp_dtype = torch.int32  # to avoid overflow when weight-zp
+            zp = torch.zeros(self.scale.shape, dtype=zp_dtype)
+            origin_shape = zp.shape
+            target_shape = self.packed_zp.shape
+            for i in range(target_shape[0]):
+                for j in range(target_shape[1]):
+                    for e in range(self.n_pack):
+                        index = j * self.n_pack + e
+                        if index >= origin_shape[1]:
+                            continue
+                        tmp = self.packed_zp[i][j]
+                        tmp = tmp << 32 - self.bits * (self.n_pack - e)
+                        tmp = tmp >> 32 - self.bits
+                        tmp &= mask
+                        zp[i][index] = tmp.type(zp_dtype)
+            # recover fp32 weight with int_weight, scale, and zero_point
+            left_element = self.in_features % self.groupsize
+            if left_element != 0:
+                split_index = self.in_features // self.groupsize * self.groupsize
+                weight1 = weight[:, :-split_index].reshape(-1, self.groupsize)
+                scale1 = self.scale[:, :-1].reshape(-1, 1)
+                zp1 = zp[:, :-1].reshape(-1, 1)
+                weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1)
+                weight2 = weight[:, -split_index:]
+                scale2 = self.scale[:, -1:]
+                zp2 = zp[:, -1].reshape(-1, 1)
+                weight2 = ((weight2 - zp2) * scale2)
+                fp32_weight = torch.cat((weight1, weight2), dim=1)
+            else:
+                weight = weight.reshape(-1, self.groupsize)
+                scale = self.scale.reshape(-1, 1)
+                zp = zp.reshape(-1, 1)
+                fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1)
+        else:
+            # recover fp32 weight with int_weight, scale
+            left_element = self.in_features % self.groupsize
+            if left_element != 0:
+                split_index = self.in_features // self.groupsize * self.groupsize
+                weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
+                scale1 = self.scale[:, :-1].reshape(-1, 1)
+                weight1 = (weight1 * scale1).reshape(self.out_features, -1)
+                weight2 = weight[:, split_index:]
+                scale2 = self.scale[:, -1:]
+                weight2 = (weight2 * scale2)
+                fp32_weight = torch.cat((weight1, weight2), dim=1)
+            else:
+                weight = weight.reshape(-1, self.groupsize)
+                scale = self.scale.reshape(-1, 1)
+                fp32_weight = (weight * scale).reshape(self.out_features, -1)
+        return fp32_weight
+
+    def forward(self, input):
+        weight = self.recover()
+        return F.linear(input, weight, self.bias)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bits={}, group_size={}, bias={}'.format(
+            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
+        )
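
For reference on what pack() and recover() above compute: with bits=4, n_pack = 32 // 4 = 8 integer weights share one int32 element. pack() masks each value to 4 bits and shifts it into its slot starting from the most significant end; recover() shifts a slot up to the top of the 32-bit word and then arithmetic-shifts it back down, which sign-extends int8 weights when no zero point is stored. A self-contained toy round trip with hypothetical values (not part of the commit):

# Toy illustration of the WeightOnlyLinear bit-packing scheme (hypothetical values).
import torch

bits, n_pack = 4, 8                     # 8 x 4-bit values per int32
mask = torch.tensor(2**bits - 1, dtype=torch.int32)
int_weight = torch.tensor([-2, -1, 0, 1, 2, 3, -4, 7], dtype=torch.int8)

# pack: mask to 4 bits, shift element e into bits [28-4e, 31-4e], OR into one word
packed = torch.tensor(0, dtype=torch.int32)
for e in range(n_pack):
    v = int_weight[e].type(torch.int32) & mask
    packed |= v << bits * (n_pack - 1 - e)

# recover: move the slot to the top of the word, then arithmetic right-shift
# so negative 4-bit values are sign-extended
recovered = torch.zeros(n_pack, dtype=torch.int8)
for e in range(n_pack):
    tmp = packed << 32 - bits * (n_pack - e)
    tmp = tmp >> 32 - bits
    recovered[e] = tmp.type(torch.int8)

assert torch.equal(recovered, int_weight)

A value such as -4 survives the round trip because the left shift places its low 4 bits (0xC) at the top of the word, making the int32 negative, and the arithmetic right shift then restores -4.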

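A possible end-to-end use of the new module for low-memory inference, assuming a 4-bit symmetric quantization with group size 32 has already produced an integer weight tensor and per-group scales. All names and values below are placeholders for illustration, not taken from the commit:

# Hypothetical usage sketch: swap an fp32 Linear for a packed WeightOnlyLinear.
import torch
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear

in_features, out_features, bits, group_size = 64, 16, 4, 32
fp32_linear = torch.nn.Linear(in_features, out_features)

# Placeholder quantization results; a real flow would derive these from
# fp32_linear.weight (e.g. via the RTN helper called in pytorch.py above).
int_weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
scale = torch.rand(out_features, in_features // group_size) * 0.01

q_linear = WeightOnlyLinear(in_features, out_features, bits, group_size)
q_linear.pack(int_weight, scale, zp=None, bias=fp32_linear.bias.detach())

x = torch.randn(2, in_features)
y = q_linear(x)   # recovers fp32 weights on the fly, then applies F.linear
print(q_linear)   # extra_repr: in_features=64, out_features=16, bits=4, ...

The packed module keeps the weights in int32 buffers and only materializes the fp32 weight inside forward(), which is where the memory saving during inference comes from.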