
Commit 5179da1

add use_HF_format for export_compressed_model (#1379)

Signed-off-by: Xin He <[email protected]>
1 parent 0a20016 commit 5179da1

File tree: 7 files changed, +257 -120 lines changed

docs/source/quantization_weight_only.md

Lines changed: 15 additions & 5 deletions

````diff
@@ -96,6 +96,14 @@ To support low memory inference, Neural Compressor implemented WeightOnlyLinear,
 | compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64] |
 | compression_dim | 1 | 0 means output channel while 1 means input channel |
 | scale_dtype | torch.float32 | Data type for scale and bias |
+| use_hf_format | False | Whether to use the popular format present on the HuggingFace hub |
+
+**Note:** The HuggingFace format is quite special; the main differences are as follows:
+
+> 1: Compression dimension: weight = 1, zero = 0, and both are transposed.
+> 2: Zero point: zero_point -= 1 before compression; zero_point is always required, even for sym.
+> 3: Group index: use the same number for a group instead of recording the channel order.
+
 
 ### **User Code Example**
 ```python
@@ -119,12 +127,14 @@ conf = PostTrainingQuantConfig(
 )
 q_model = quantization.fit(model, conf, eval_func=eval_func)
 q_model.save("saved_results")
-compressed_model = q_model.export_compressed_model(
-    compression_dtype=torch.int32,
-    compression_dim=1,
-    scale_dtype=torch.float16,
-)
+compressed_model = q_model.export_compressed_model()
 torch.save(compressed_model.state_dict(), "compressed_model.pt")
+# or
+model = Model()
+compressed_model = export_compressed_model(
+    model,
+    saved_dir="saved_results",
+)
 ```
 
 The saved_results folder contains two files: `best_model.pt` and `qconfig.json`, and the generated q_model is a fake quantized model.
````

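The three differences called out in the new note are easier to see with concrete tensors. Below is a small standalone sketch (not part of this commit; the toy layer shapes and values are illustrative only) that mirrors them for a 4-bit layer:

```python
import torch

# Toy asymmetric 4-bit layer: 2 output channels, 8 input channels, group_size = 4.
bits, group_size = 4, 4
out_features, in_features = 2, 8
n_groups = in_features // group_size
int_weight = torch.randint(0, 2**bits, (out_features, in_features))   # unsigned 4-bit values
scales = torch.rand(out_features, n_groups) + 0.1                     # one scale per group
zeros = torch.randint(1, 2**bits, (out_features, n_groups))           # asym zero points

# 1: Compression dimension -- the HF layout stores weight and zero tensors transposed,
#    i.e. with the input-channel/group dimension leading.
hf_scales = scales.T                                                   # (n_groups, out_features)

# 2: Zero point -- shifted down by one before packing, and kept even for sym schemes.
hf_zeros = (zeros - 1).T

# 3: Group index -- one group id per input channel instead of a recorded channel order.
g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.long)

# Dequantization is the same arithmetic in either layout:
fp32 = (int_weight - zeros[:, g_idx]) * scales[:, g_idx]
print(hf_scales.shape, hf_zeros.shape, g_idx.tolist(), fp32.shape)
```
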

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 136 additions & 100 deletions

```diff
@@ -215,10 +215,12 @@ def __init__(
         scale_dtype=torch.float32,
         compression_dtype=torch.int32,
         compression_dim=1,
-        gptq_perm=False,
+        g_idx=False,
         device="cpu",
+        use_hf_format=False,
     ):
         super().__init__()
+        self.use_hf_format = use_hf_format
         self.dtype = dtype
         if "int" not in self.dtype:  # for nf4, fp4
             from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
```

```diff
@@ -249,69 +251,105 @@ def __init__(
         assert compression_dim in [0, 1], (
             "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel."
         )
-        self.register_buffer(
-            "scale",
-            torch.zeros(
-                (out_features, math.ceil(in_features / self.groupsize)),
-                dtype=self.float_type,
-            ).to(device),
-        )
-        if compression_dim == 1:
+        if self.use_hf_format:
             self.register_buffer(
-                "packed_weight",
+                "scales",
                 torch.zeros(
-                    (out_features, math.ceil(in_features / self.n_pack)),
+                    (math.ceil(in_features / self.groupsize), out_features),
+                    dtype=self.float_type,
+                ).to(device),
+            )
+            self.scales = self.scales.T
+            self.register_buffer(
+                "qweight",
+                torch.zeros(
+                    (math.ceil(in_features / self.n_pack), out_features),
                     dtype=self.compressed_dtype,
                 ).to(device),
             )
-            if zp:
-                self.register_buffer(
-                    "packed_zp",
-                    torch.zeros(
-                        (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
-                        dtype=self.compressed_dtype,
-                    ).to(device),
-                )
-        else:
+            self.qweight = self.qweight.T
             self.register_buffer(
-                "packed_weight",
+                "qzeros",
                 torch.zeros(
-                    (math.ceil(out_features / self.n_pack), in_features),
+                    (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
                     dtype=self.compressed_dtype,
                 ).to(device),
             )
-            if zp:
+            self.qzeros = self.qzeros.T
+        else:
+            self.register_buffer(
+                "scales",
+                torch.zeros(
+                    (out_features, math.ceil(in_features / self.groupsize)),
+                    dtype=self.float_type,
+                ).to(device),
+            )
+            if compression_dim == 1:
                 self.register_buffer(
-                    "packed_zp",
+                    "qweight",
                     torch.zeros(
-                        (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
+                        (out_features, math.ceil(in_features / self.n_pack)),
                         dtype=self.compressed_dtype,
                     ).to(device),
                 )
+                if zp:
+                    self.register_buffer(
+                        "qzeros",
+                        torch.zeros(
+                            (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
+                            dtype=self.compressed_dtype,
+                        ).to(device),
+                    )
+            else:
+                self.register_buffer(
+                    "qweight",
+                    torch.zeros(
+                        (math.ceil(out_features / self.n_pack), in_features),
+                        dtype=self.compressed_dtype,
+                    ).to(device),
+                )
+                if zp:
+                    self.register_buffer(
+                        "qzeros",
+                        torch.zeros(
+                            (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
+                            dtype=self.compressed_dtype,
+                        ).to(device),
+                    )
+        if g_idx:
+            self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
+        else:
+            self.g_idx = None
         if bias:
             self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
         else:
             self.bias = None
-        if gptq_perm:
-            self.register_buffer("gptq_perm", torch.zeros(in_features, dtype=torch.int32).to(device))
-        else:
-            self.gptq_perm = None
 
-    def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
+    def pack(self, int_weight, scale, zp, bias, g_idx=None):
         int_weight = int_weight.to(self.device)
+        if self.use_hf_format and zp is None:
+            # to avoid overflow
+            int_weight = int_weight.type(torch.int32)
+            shift_bias = 2 ** (self.bits - 1)
+            int_weight += shift_bias
+            zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias
         if bias is not None:
             assert hasattr(self, "bias"), "bias is not set when initializing."
             self.bias = bias.type(self.float_type).to(self.device)
-        if gptq_perm is not None:
-            assert hasattr(self, "gptq_perm"), "gptq_perm is not set when initializing."
-            self.gptq_perm = gptq_perm.type(torch.int32).to(self.device)
-        assert scale.shape == self.scale.shape, "Scale shape is mismatched."
-        self.scale = scale.type(self.float_type).to(self.device)
-        if self.compression_dim == 0:
+        if g_idx is not None:
+            assert hasattr(self, "g_idx"), "g_idx is not set when initializing."
+            self.g_idx = g_idx.type(torch.int32).to(self.device)
+            if self.use_hf_format:
+                invperm = torch.argsort(self.g_idx)
+                self.g_idx = invperm // self.groupsize
+                self.g_idx = self.g_idx.type(torch.int32).to(self.device)
+        assert scale.shape == self.scales.shape, "Scale shape is mismatched."
+        self.scales = scale.type(self.float_type).to(self.device)
+        if not self.use_hf_format and self.compression_dim == 0:
             int_weight = int_weight.T
-            self.packed_weight = self.packed_weight.T
+            self.qweight = self.qweight.T
         origin_shape = int_weight.shape
-        target_shape = self.packed_weight.shape
+        target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
         mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(self.device)
 
```

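The mask/shift loop continued in the next hunk packs n_pack low-bit values into each compressed integer. A minimal standalone illustration of that scheme, and of how recover() later reads a slot back (not from the commit; toy values only):

```python
import torch

# Pack eight 4-bit values into one int32 word, mirroring pack()'s mask/shift loop.
bits, compress_bits = 4, 32
n_pack = compress_bits // bits                                  # 8 values per int32 word
values = torch.randint(0, 2**bits, (1, n_pack))                 # one row, one packed column
mask = torch.tensor(2**bits - 1, dtype=torch.int32)

packed = torch.zeros(1, 1, dtype=torch.int32)
tmp = values.type(torch.int32)
for e in range(tmp.shape[1]):
    tmp[:, e] &= mask                                           # keep only the low `bits` bits
    tmp[:, e] = tmp[:, e] << (bits * e)                         # move them into slot e
    packed[:, 0] |= tmp[:, e]                                   # OR the slot into the packed word

# Read slot e back the way recover() does: shift the slot up against the sign bit,
# arithmetic-shift it back down, then mask off any sign extension.
e = 3
slot = (packed[:, 0] << (compress_bits - bits * (e + 1))) >> (compress_bits - bits)
print(values[0, e].item(), (slot & mask).item())                # both print the same value
```
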
```diff
@@ -323,121 +361,112 @@ def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
             for e in range(tmp.shape[1]):
                 tmp[:, e] &= mask
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.packed_weight[:, j] |= tmp[:, e]
-        if self.compression_dim == 0:
-            self.packed_weight = self.packed_weight.T
+                self.qweight[:, j] |= tmp[:, e]
+        if not self.use_hf_format and self.compression_dim == 0:
+            self.qweight = self.qweight.T
 
         if zp is not None:
             zp = zp.to(self.device)
-            if self.compression_dim == 0:
+            if self.use_hf_format:
+                zp -= 1
+            if self.use_hf_format or self.compression_dim == 0:
                 zp = zp.T
-                self.packed_zp = self.packed_zp.T
-            assert hasattr(self, "packed_zp"), "zp is not set when initializing."
-            target_shape = self.packed_zp.shape
+                self.qzeros = self.qzeros.T
+            assert hasattr(self, "qzeros"), "zp is not set when initializing."
+            target_shape = self.qzeros.shape
             for j in range(target_shape[1]):
                 start = self.n_pack * j
                 end = self.n_pack * (j + 1)
                 tmp = zp[:, start:end].type(self.compressed_dtype)
                 for e in range(tmp.shape[1]):
                     tmp[:, e] &= mask
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.packed_zp[:, j] |= tmp[:, e]
-            if self.compression_dim == 0:
-                self.packed_zp = self.packed_zp.T
+                    self.qzeros[:, j] |= tmp[:, e]
+            if self.use_hf_format or self.compression_dim == 0:
+                self.qzeros = self.qzeros.T
+        if self.use_hf_format:
+            self.scales = self.scales.T
+            self.qweight = self.qweight.T
+            self.g_idx = self.g_idx
+            self.qzeros = self.qzeros.T
 
     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        device = self.scale.device
+        if self.use_hf_format:
+            # Prevent broken id links of self.scales, self.qweight and self.qzeros
+            self.scales = self.scales.T
+            self.qweight = self.qweight.T
+            self.g_idx = self.g_idx
+            self.qzeros = self.qzeros.T
+        device = self.scales.device
+        fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
+        if self.g_idx is None:
+            # used for recovering fp32_weight
+            self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
         mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(device)
-        if hasattr(self, "packed_zp"):
+        if hasattr(self, "qzeros"):
             weight_dtype = torch.uint8
         else:
             weight_dtype = torch.int8
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
-        packed_weight = self.packed_weight
-        if self.compression_dim == 0:
+        qweight = self.qweight
+        if not self.use_hf_format and self.compression_dim == 0:
             weight = weight.T
-            packed_weight = packed_weight.T
+            qweight = qweight.T
         origin_shape = weight.shape
-        target_shape = packed_weight.shape
+        target_shape = qweight.shape
         for j in range(target_shape[1]):
             for e in range(self.n_pack):
                 index = j * self.n_pack + e
                 if index >= origin_shape[1]:
                     continue
-                tmp = packed_weight[:, j]
+                tmp = qweight[:, j]
                 tmp = tmp << (self.compress_bits - self.bits * (e + 1))
                 tmp = tmp >> self.compress_bits - self.bits
                 if weight_dtype == torch.uint8:
                     tmp &= mask  # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
-        if self.compression_dim == 0:
+        if not self.use_hf_format and self.compression_dim == 0:
             weight = weight.T
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
                 new_weight += torch.where(weight == k, v, 0)
             weight = new_weight
         # unpack zero_point
-        if hasattr(self, "packed_zp"):
+        if hasattr(self, "qzeros"):
             zp_dtype = self.compressed_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(self.scale.shape, dtype=zp_dtype).to(device)
-            packed_zp = self.packed_zp
-            if self.compression_dim == 0:
+            zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device)
+            qzeros = self.qzeros
+            if self.use_hf_format or self.compression_dim == 0:
                 zp = zp.T
-                packed_zp = packed_zp.T
+                qzeros = qzeros.T
             origin_shape = zp.shape
-            target_shape = packed_zp.shape
+            target_shape = qzeros.shape
             for j in range(target_shape[1]):
                 for e in range(self.n_pack):
                     index = j * self.n_pack + e
                     if index >= origin_shape[1]:
                         continue
-                    tmp = packed_zp[:, j]
+                    tmp = qzeros[:, j]
                     tmp = tmp << (self.compress_bits - self.bits * (e + 1))
                     tmp = tmp >> self.compress_bits - self.bits
                     tmp &= mask
                     zp[:, index] = tmp.type(zp_dtype)
-            if self.compression_dim == 0:
+            if self.use_hf_format or self.compression_dim == 0:
                 zp = zp.T
+            if self.use_hf_format:
+                # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
+                zp += 1
+                zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
-            left_element = self.in_features % self.groupsize
-            if left_element != 0:
-                split_index = self.in_features // self.groupsize * self.groupsize
-                weight1 = weight[:, :-split_index].reshape(-1, self.groupsize)
-                scale1 = self.scale[:, :-1].reshape(-1, 1)
-                zp1 = zp[:, :-1].reshape(-1, 1)
-                weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1)
-                weight2 = weight[:, -split_index:]
-                scale2 = self.scale[:, -1:]
-                zp2 = zp[:, -1].reshape(-1, 1)
-                weight2 = (weight2 - zp2) * scale2
-                fp32_weight = torch.cat((weight1, weight2), dim=1)
-            else:
-                weight = weight.reshape(-1, self.groupsize)
-                scale = self.scale.reshape(-1, 1)
-                zp = zp.reshape(-1, 1)
-                fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1)
+            for idx in range(self.in_features):
+                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]]
         else:
             # recover fp32 weight with int_weight, scale
-            left_element = self.in_features % self.groupsize
-            if left_element != 0:
-                split_index = self.in_features // self.groupsize * self.groupsize
-                weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
-                scale1 = self.scale[:, :-1].reshape(-1, 1)
-                weight1 = (weight1 * scale1).reshape(self.out_features, -1)
-                weight2 = weight[:, split_index:]
-                scale2 = self.scale[:, -1:]
-                weight2 = weight2 * scale2
-                fp32_weight = torch.cat((weight1, weight2), dim=1)
-            else:
-                weight = weight.reshape(-1, self.groupsize)
-                scale = self.scale.reshape(-1, 1)
-                fp32_weight = (weight * scale).reshape(self.out_features, -1)
-            if self.gptq_perm is not None:
-                invperm = torch.argsort(self.gptq_perm)
-                fp32_weight = fp32_weight[:, invperm]
+            for idx in range(self.in_features):
+                fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]]
         return fp32_weight
 
     def forward(self, input):
```

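With g_idx in place, recover() dequantizes column by column, so a permuted (activation-order) group layout no longer needs the old gptq_perm post-pass. A standalone sketch (not from the commit; toy shapes) showing the per-column loop and the equivalent vectorized form:

```python
import torch

# Per-input-channel dequantization as in the new recover(): each column idx looks up
# its quantization group through g_idx.
out_features, in_features, group_size, bits = 2, 8, 4, 4
weight = torch.randint(0, 2**bits, (out_features, in_features), dtype=torch.int32)  # unpacked uint4
scales = torch.rand(out_features, in_features // group_size)
zp = torch.randint(0, 2**bits, (out_features, in_features // group_size), dtype=torch.int32)
g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.long)

fp32_weight = torch.zeros(out_features, in_features)
for idx in range(in_features):
    fp32_weight[:, idx] = (weight[:, idx] - zp[:, g_idx[idx]]) * scales[:, g_idx[idx]]

# Sanity check against the equivalent vectorized expression (not in the commit):
assert torch.allclose(fp32_weight, (weight - zp[:, g_idx]) * scales[:, g_idx])
```
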
```diff
@@ -453,9 +482,16 @@ def forward(self, input):
         return F.linear(input, weight, self.bias)
 
     def extra_repr(self) -> str:
-        return "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
-            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
+        tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
+            self.in_features,
+            self.out_features,
+            self.bits,
+            self.groupsize,
+            self.bias is not None,
         )
+        if self.use_hf_format:
+            tmp_str += ", use_hf_format=True"
+        return tmp_str
 
 
 class FakeAffineTensorQuantFunction(Function):
```

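For orientation, the shapes passed to torch.zeros(...) in the two __init__ branches above work out as follows. This is a standalone sketch, not part of the diff; the layer dimensions are an arbitrary example, and g_idx is only registered when a group index is supplied:

```python
import math

# Buffer shapes requested by register_buffer(torch.zeros(...)) in __init__ for a 4-bit
# layer packed into int32 (example dimensions only).
in_features, out_features, group_size, bits, compress_bits = 4096, 4096, 128, 4, 32
n_pack = compress_bits // bits

hf_layout = {  # use_hf_format=True: tensors are created transposed, input dimension leading
    "qweight": (math.ceil(in_features / n_pack), out_features),
    "scales": (math.ceil(in_features / group_size), out_features),
    "qzeros": (math.ceil(in_features / group_size), math.ceil(out_features / n_pack)),
    "g_idx": (in_features,),
}
default_layout = {  # use_hf_format=False with compression_dim=1 (the default)
    "qweight": (out_features, math.ceil(in_features / n_pack)),
    "scales": (out_features, math.ceil(in_features / group_size)),
    "qzeros": (out_features, math.ceil(in_features / group_size / n_pack)),
    "g_idx": (in_features,),
}
print(hf_layout)
print(default_layout)
```
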

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -396,6 +396,7 @@ def rtn_quantize(
         compression_dim = kwargs.get("compression_dim", 1)
         scale_dtype = kwargs.get("scale_dtype", torch.float32)
         device = kwargs.get("device", "cpu")
+        use_hf_format = kwargs.get("use_hf_format", False)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue
@@ -451,6 +452,7 @@ def rtn_quantize(
                 compression_dim=compression_dim,
                 scale_dtype=scale_dtype,
                 device=device,
+                use_hf_format=use_hf_format,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
             if name == "":
```

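The new kwarg only matters when rtn_quantize replaces layers with packed WeightOnlyLinear modules. A hedged usage sketch follows; the toy model and the num_bits/group_size/return_int parameter names are assumptions about rtn_quantize's wider API and are not shown in this diff:

```python
import torch

from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

# Toy model; only use_hf_format is taken from this diff, the other arguments are assumed.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
q_model = rtn_quantize(
    model,
    num_bits=4,
    group_size=32,
    return_int=True,       # keep packed WeightOnlyLinear modules instead of fake-quant fp32
    use_hf_format=True,    # new kwarg from this commit, forwarded to WeightOnlyLinear(...)
)
print(q_model[0])          # extra_repr should now include use_hf_format=True
```
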
