
Commit 050dc44

Authored by tgafni and Asaf Karnieli
fp8 aware gptq (hybrid gptq) (#154)
* fp8 aware gptq (hybrid gptq)
* review1
* loading bias to mixed low precision
* fixing tests for fp8 aware quantization and hybrid re-ordering
* Addressed second review round comments
* Addressed review 3 comments

---------

Co-authored-by: Asaf Karnieli <[email protected]>
1 parent 8219b5a commit 050dc44

File tree

10 files changed: +375 -60 lines changed


examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py

Lines changed: 8 additions & 0 deletions
@@ -101,6 +101,12 @@
 parser.add_argument("--gptq_blockwise", action="store_true",
                     help="Whether to quantize blockwise.")
 parser.add_argument("--blockwise_load_folder", default=None, type=str, help="Directory to load blockwise checkpoints from.")
+parser.add_argument("--fp8_aware", action="store_true", help="Enable an FP8-aware GPTQ quantization flow, "
+                    "where an intermediate FP8 quantization step is applied.")
+parser.add_argument("--hybrid_act_order", action="store_true", help="Enable constrained activation reordering: "
+                    "elements can be reordered within each group "
+                    "and the groups themselves can also be reordered, "
+                    "but elements cannot move between groups.")

 # =============AWQ configs====================
 parser.add_argument("--use_auto_scale", action="store_true",
@@ -458,6 +464,8 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
         use_mse_search=args.woq_use_mse_search,
         percdamp=args.gptq_percdamp,
         act_order=args.gptq_actorder,
+        hybrid_order=args.hybrid_act_order,
+        fp8_aware=args.fp8_aware,
         block_size=args.gptq_block_size,
         static_groups=args.gptq_static_groups,
         use_double_quant=False,
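For orientation, here is a minimal, self-contained sketch of how the two new flags are meant to flow from the command line into the GPTQ keyword arguments shown in the second hunk. Only the flag names and the hybrid_order/fp8_aware kwargs come from the diff; the surrounding plumbing is illustrative.

import argparse

parser = argparse.ArgumentParser()
# Flags added by this commit (names taken from the diff above).
parser.add_argument("--fp8_aware", action="store_true")
parser.add_argument("--hybrid_act_order", action="store_true")
args = parser.parse_args(["--fp8_aware", "--hybrid_act_order"])

# The example script forwards the parsed flags to the GPTQ config as
# hybrid_order / fp8_aware, alongside the existing act_order option.
gptq_kwargs = dict(
    act_order=False,
    hybrid_order=args.hybrid_act_order,
    fp8_aware=args.fp8_aware,
)
print(gptq_kwargs)  # {'act_order': False, 'hybrid_order': True, 'fp8_aware': True}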

neural_compressor/torch/algorithms/mixed_low_precision/modules.py

Lines changed: 9 additions & 4 deletions
@@ -9,26 +9,30 @@
 from ..weight_only.modules import HPUWeightOnlyLinear
 from neural_compressor.torch.utils import accelerator, logger

+cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]

 class HPUMixedPrecisionLinear(HPUWeightOnlyLinear):
     """Weight and Activations quant (W4A8 gptq) Linear for HPU device."""

     def __init__(
-        self, in_features, out_features,
+        self, in_features, out_features, bias,
         **kwargs,
     ):
         """Init the HPUMixedPrecisionLinear object.
         """
-        super(HPUMixedPrecisionLinear, self).__init__(in_features, out_features)
+        super(HPUMixedPrecisionLinear, self).__init__(in_features, out_features, bias=bias)

     def forward(self, input):
         """The forward function of HPUMixedPrecisionLinear."""
         input_dtype = input.dtype
         output_shape = input.shape[:-1] + (self.out_features,)
         scales = self.scales
+        scale_bf16_to_fp8 = self.scale_bf16_to_fp8
         qweight = self.qweight
         zeros = self.qzeros
-        weight = torch.ops.hpu.convert_from_uint4(qweight, scales/self.matmul_internal.scale_other, zeros, torch.float8_e4m3fn)  # todo: div scales in init
+        self.matmul_internal.scale_other = torch.nn.Parameter(scale_bf16_to_fp8)
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, torch.bfloat16)  # the uint4->fp8 is currently slower and with bugs. Jira ticket: https://jira.habana-labs.com/browse/SW-218009
+        weight = cast_to_fp8_fcn(weight, torch.float8_e4m3fn)
         output = self.matmul_internal(input, weight)
         output = output.to(dtype=input_dtype).reshape(
             output_shape
@@ -38,7 +42,8 @@ def forward(self, input):

     @staticmethod
     def convert_from_weight_only(obj):
-        new_self = HPUMixedPrecisionLinear(obj.in_features, obj.out_features)
+        bias = obj.bias is not None
+        new_self = HPUMixedPrecisionLinear(obj.in_features, obj.out_features, bias)
         for attr, value in vars(obj).items():
             setattr(new_self, attr, value)
         new_self.matmul_internal.no_input_quant = True  # flag for 8bit input, which shouldn't be quantized in matmul
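A rough, device-agnostic emulation of the W4A8 weight path above may help read the change: the uint4 weights are first decoded to bf16 with the per-group scales, then cast straight to FP8, while the per-tensor scale_bf16_to_fp8 is handed to the FP8 matmul as scale_other. This is only a numeric sketch; the real module depends on the HPU ops torch.ops.hpu.convert_from_uint4 and cast_to_fp8_v2.

import torch

def mixed_precision_weight_path(qweight_unpacked, scales, zeros, scale_bf16_to_fp8):
    # Sketch of the forward weight handling in HPUMixedPrecisionLinear.
    # 1) uint4 -> bf16 using per-group scales/zero-points. Because GPTQ ran
    #    FP8-aware, the decoded values already fit the float8_e4m3fn range.
    weight_bf16 = ((qweight_unpacked.float() - zeros) * scales).to(torch.bfloat16)
    # 2) bf16 -> fp8 without rescaling; the per-tensor factor is applied later
    #    inside the FP8 matmul (scale_other), not here.
    weight_fp8 = weight_bf16.to(torch.float8_e4m3fn)
    return weight_fp8, scale_bf16_to_fp8  # the matmul rescales its output by this factor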

neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 190 additions & 13 deletions (large diff not rendered by default)
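Since the gptq.py diff is not rendered, the following is only a sketch of the ordering constraint that --hybrid_act_order describes: columns may be permuted inside their quantization group, and whole groups may be permuted, but no column crosses a group boundary, which is why the original order can be restored without a g_idx lookup at inference. Using the Hessian diagonal as the sort key mirrors regular act_order and is an assumption about the actual gptq.py implementation.

import torch

def hybrid_act_order_perm(hessian_diag, group_size):
    # Build a permutation that sorts columns only inside their own group and
    # then sorts the groups themselves, so group membership never changes.
    perm = torch.arange(hessian_diag.numel())
    groups = list(perm.split(group_size))
    # Sort within each group by descending Hessian diagonal.
    sorted_groups = [g[torch.argsort(hessian_diag[g], descending=True)] for g in groups]
    # Sort the groups by their largest diagonal value.
    group_keys = torch.stack([hessian_diag[g].max() for g in sorted_groups])
    group_order = torch.argsort(group_keys, descending=True)
    return torch.cat([sorted_groups[i] for i in group_order])

# Example: 8 columns, group_size=4 -> indices stay within {0..3} and {4..7}.
diag = torch.tensor([0.1, 0.9, 0.5, 0.3, 2.0, 0.2, 0.7, 0.4])
print(hybrid_act_order_perm(diag, 4))  # tensor([4, 6, 7, 5, 1, 2, 3, 0])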

neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 38 additions & 7 deletions
@@ -62,9 +62,9 @@ def forward(self, X):
 class UnpackedWeightOnlyLinearParams(dict):
     """Contains all unpacked weight values."""

-    def __init__(self, unpack_weight, scales, unpack_zp, **kwargs):
+    def __init__(self, unpack_weight, scales, scale_bf16_to_fp8, unpack_zp, **kwargs):
         """Create dict."""
-        super().__init__(int_weight=unpack_weight, scales=scales, zp=unpack_zp, **kwargs)
+        super().__init__(int_weight=unpack_weight, scales=scales, scale_bf16_to_fp8=scale_bf16_to_fp8, zp=unpack_zp, **kwargs)

     def to(self, device):
         """Change device for all values."""
@@ -209,6 +209,14 @@ def __init__(
                 dtype=self.float_type,
             ).to(device),
         )
+        self.register_buffer(
+            "scale_bf16_to_fp8",
+            torch.zeros(
+                1,
+                dtype=self.float_type,
+            ).to(device),
+        )
+        # scale_bf16_to_fp8 is only used in w4a8 measurement mode and currently supports only per-tensor scaling
         self.register_buffer(
             "qweight",
             torch.zeros(
@@ -234,6 +242,13 @@
                 dtype=self.float_type,
             ).to(device),
         )
+        self.register_buffer(
+            "scale_bf16_to_fp8",
+            torch.zeros(
+                1,
+                dtype=self.float_type,
+            ).to(device),
+        )
         if compression_dim == 1:
             self.register_buffer(
                 "qweight",
@@ -275,7 +290,7 @@
         else:
             self.g_idx = None

-    def pack(self, int_weight, scales, zp, bias=None, g_idx=None, **kwargs):
+    def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=None, **kwargs):
         """Pack int weight."""
         if self.use_optimum_format:
             self.scales = self.scales.T.contiguous()
@@ -301,6 +316,8 @@ def pack(self, int_weight, scales, zp, bias=None, g_idx=None, **kwargs):
             self.g_idx = self.g_idx.type(torch.int32).to(self.device)
         assert scales.shape == self.scales.shape, f"{scales.shape} != {self.scales.shape} Scale shape is mismatched."
         self.scales = scales.type(self.float_type).to(self.device)
+        if scale_bf16_to_fp8 is not None:
+            self.scale_bf16_to_fp8 = scale_bf16_to_fp8.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
             int_weight = int_weight.T.contiguous()
             self.qweight = self.qweight.T.contiguous()
@@ -332,6 +349,7 @@ def pack(self, int_weight, scales, zp, bias=None, g_idx=None, **kwargs):
     def unpack(self):
         """Unpack weight and zero point."""
         scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
+        scale_bf16_to_fp8 = self.scale_bf16_to_fp8
         qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight

         device = scales.device
@@ -367,14 +385,15 @@
         # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
         zp += 1
         zp = torch.where(zp > (2**self.bits - 1), 0, zp)
-        return UnpackedWeightOnlyLinearParams(weight, scales, zp, g_idx=self.g_idx, bias=self.bias)
+        return UnpackedWeightOnlyLinearParams(weight, scales, scale_bf16_to_fp8, zp, g_idx=self.g_idx, bias=self.bias)

     def recover(self):
         """Recover fp32 weight from packed weight."""
         logger.debug(f"Recovering {self} weight")
         unpack_params_dict = self.unpack()
         weight = unpack_params_dict.get("int_weight")
         scales = unpack_params_dict.get("scales")
+        scale_bf16_to_fp8 = unpack_params_dict.get("scale_bf16_to_fp8")
         zp = unpack_params_dict.get("zp")

         device = scales.device
@@ -668,7 +687,13 @@
                 dtype=self.float_type,
             ),
         )
-
+        self.register_buffer(
+            "scale_bf16_to_fp8",
+            torch.zeros(
+                1,
+                dtype=self.float_type,
+            ),
+        )
         if g_idx:
             self.register_buffer(
                 "g_idx",
@@ -687,17 +712,22 @@ def forward(self, input):
         input_dtype = input.dtype
         output_shape = input.shape[:-1] + (self.out_features,)
         scales = self.scales
+        scale_bf16_to_fp8 = self.scale_bf16_to_fp8  # Added by Tomer, per tensor scale.
         qweight = self.qweight
         zeros = self.qzeros
-        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
+        if scale_bf16_to_fp8 > 0:  # this means we are at w4a8 mode.
+            weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, torch.float8_e4m3fn)
+            weight = weight.to(input_dtype) * scale_bf16_to_fp8
+        else:
+            weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
         output = self.matmul_internal(input, weight)
         output = output.to(dtype=input_dtype).reshape(
             output_shape
         )  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocate a float32 output.
         output = output + self.bias if self.bias is not None else output
         return output

-    def pack(self, int_weight, scales, zp, bias=None, g_idx=None):
+    def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=None):
         """Pack weight and zero point."""
         logger.debug("Packing for HPU")

@@ -706,6 +736,7 @@ def pack(self, int_weight, scales, zp, bias=None, g_idx=None):
         qweight = int_weight.T.contiguous()

         self.scales = scales.to(dtype=torch.bfloat16)
+        self.scale_bf16_to_fp8 = scale_bf16_to_fp8.to(dtype=torch.bfloat16)

         # weights and zp are on device from unpack, need to load to cpu for packing
         self.qweight = qweight.cpu()
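Read together, the new buffer acts as a mode flag in HPUWeightOnlyLinear.forward: it is zero-initialized, so a positive value (set by pack) selects the W4A8 decode path. Below is a minimal emulation of that branch, with torch.ops.hpu.convert_from_uint4 replaced by plain dequantization for illustration only.

import torch

def hpu_woq_decode_weight(qweight_unpacked, scales, zeros, scale_bf16_to_fp8,
                          input_dtype=torch.bfloat16):
    # Emulated dequantization standing in for torch.ops.hpu.convert_from_uint4.
    dequant = (qweight_unpacked.float() - zeros) * scales
    if scale_bf16_to_fp8 > 0:
        # w4a8 mode: decode through fp8, then undo the per-tensor scale in the
        # higher-precision input dtype.
        return dequant.to(torch.float8_e4m3fn).to(input_dtype) * scale_bf16_to_fp8
    # regular w4a16 path: decode directly to the input dtype.
    return dequant.to(input_dtype)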

neural_compressor/torch/algorithms/weight_only/save_load.py

Lines changed: 1 addition & 1 deletion
@@ -450,7 +450,7 @@ def _replace_woqlinear_modules(self, name, linear_module, module_quantization_co

     def _load_data_to_new_module(self, new_module, module_name):
         new_module_state_dict = {}
-        for key in [".qweight", ".scales", ".qzeros", ".bias", ".g_idx"]:
+        for key in [".qweight", ".scales", ".scale_bf16_to_fp8", ".qzeros", ".bias", ".g_idx"]:
             full_name = module_name + key
             if full_name in self.loaded_state_dict:
                 new_module_state_dict[key[1:]] = self.loaded_state_dict.pop(full_name)

neural_compressor/torch/algorithms/weight_only/utility.py

Lines changed: 6 additions & 1 deletion
@@ -479,7 +479,7 @@ def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_ful
     return best_clip_ratio


-def quant_weight_w_scale(weight, scale, zp=None, group_size=-1, dtype="int"):
+def quant_weight_w_scale(weight, scale, scale_bf16_to_fp8, zp=None, group_size=-1, dtype="int", fp8_aware=False):
     """Quant and dequant tensor with group size. It's an in-place function.

     Args:
@@ -494,6 +494,11 @@ def quant_weight_w_scale(weight, scale, zp=None, group_size=-1, dtype="int"):
     """
     device = weight.device
     scale = scale.to(device)
+    if fp8_aware:
+        weight.mul_(1 / scale_bf16_to_fp8)
+        weight = torch.clamp(weight, min=-torch.finfo(torch.float8_e4m3fnuz).max, max=torch.finfo(torch.float8_e4m3fnuz).max)
+        weight = weight.to(torch.float8_e4m3fn)
+        weight = weight.to(torch.float32)
     if zp is not None:
         zp = zp.to(device)
     # group_size = -1
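The fp8_aware branch pre-rounds the weight through the FP8 grid before the int4 quantization, so the GPTQ error compensation accounts for the later FP8 cast. Below is a standalone restatement of that branch as a sketch; note that the diff clamps to the float8_e4m3fnuz range while casting to float8_e4m3fn, which is reproduced here as-is.

import torch

def fp8_aware_prequant(weight, scale_bf16_to_fp8):
    # Scale into FP8 range, round through float8_e4m3fn, and return to float32
    # so the subsequent int4 quantization sees values that survive an FP8 cast.
    w = weight / scale_bf16_to_fp8
    fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # clamp bound used in the diff
    w = torch.clamp(w, min=-fnuz_max, max=fnuz_max)
    return w.to(torch.float8_e4m3fn).to(torch.float32)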

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,8 @@ def gptq_entry(
         "double_quant_sym": quant_config.double_quant_use_sym,
         "double_quant_group_size": quant_config.double_quant_group_size,
         "act_order": quant_config.act_order,
+        "hybrid_order": quant_config.hybrid_order,
+        "fp8_aware": quant_config.fp8_aware,
         "percdamp": quant_config.percdamp,
         "block_size": quant_config.block_size,
         "static_groups": quant_config.static_groups,

neural_compressor/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions
@@ -352,6 +352,8 @@ class GPTQConfig(TorchBaseConfig):
         "quant_lm_head",
         # gptq params
         "act_order",
+        "hybrid_order",
+        "fp8_aware",
         "percdamp",
         "block_size",
         "static_groups",
@@ -379,6 +381,8 @@ def __init__(
         quant_lm_head: bool = False,
         # gptq params
         act_order: bool = False,
+        hybrid_order: bool = False,
+        fp8_aware: bool = False,
         percdamp: float = 0.01,
         block_size: int = 2048,
         static_groups: bool = False,
@@ -406,6 +410,10 @@ def __init__(
             quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers. Default is False.
             act_order (bool): Whether to sort Hessian's diagonal values to rearrange channel-wise
                               quantization order. Default is False.
+            hybrid_order (bool): Enables activation re-ordering with no inference overhead.
+                                 Weights are re-ordered within their groups without cross-group mixing.
+            fp8_aware (bool): Whether to include an FP8 quantization step in the GPTQ process.
+                              This improves accuracy when using the W4A8 quantization scheme.
             percdamp (float): Percentage of Hessian's diagonal values' average, which will be added to
                               Hessian's diagonal to increase numerical stability. Default is 0.01.
             block_size (int): Execute GPTQ quantization per block, block shape = [C_out, block_size].
@@ -438,6 +446,8 @@ def __init__(
         self.double_quant_group_size = double_quant_group_size
         # gptq
         self.act_order = act_order
+        self.hybrid_order = hybrid_order
+        self.fp8_aware = fp8_aware
         self.percdamp = percdamp
         self.block_size = block_size
         self.static_groups = static_groups
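For reference, the new options can be set directly on GPTQConfig. The sketch below uses only the act_order, hybrid_order, and fp8_aware parameters taken from this diff; the import path assumes GPTQConfig is re-exported from neural_compressor.torch.quantization as in the 3.x API.

from neural_compressor.torch.quantization import GPTQConfig

quant_config = GPTQConfig(
    act_order=False,     # classic Hessian-based reordering, unchanged
    hybrid_order=True,   # new: reorder within groups and between groups only
    fp8_aware=True,      # new: add the intermediate FP8 step used by W4A8
)
print(quant_config.hybrid_order, quant_config.fp8_aware)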
Lines changed: 22 additions & 0 deletions (new file; path not shown in this view)
@@ -0,0 +1,22 @@
+{
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "maxabs_pow2",
+    "blacklist": {
+        "types": [],
+        "names": [
+            "matmul_qk",
+            "matmul_av",
+            "k_cache",
+            "v_cache",
+            "fused_scaled_dot_product_attention",
+            "lm_head"
+        ]
+    },
+    "scale_params": {
+        "input_backoff": 1,
+        "weight_backoff": 1
+    },
+    "dump_stats_path": "./test_outputs/unit_test",
+    "int4_weights": "True"
+}

0 commit comments