
Commit ba9475d

Revert "fp8 aware gptq (hybrid gptq) (#154)" (#184)

This reverts commit 050dc44.

1 parent b591068 · commit ba9475d

File tree: 10 files changed, +60 −375 lines

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py

Lines changed: 0 additions & 8 deletions
@@ -101,12 +101,6 @@
 parser.add_argument("--gptq_blockwise", action="store_true",
                     help="Whether to quantize blockwise.")
 parser.add_argument("--blockwise_load_folder", default=None, type=str, help="Directory to load blockwise checkpoints from.")
-parser.add_argument("--fp8_aware", action="store_true", help="Enable an FP8-aware GPTQ quantization flow, "
-                    "where an intermediate FP8 quantization step is applied.")
-parser.add_argument("--hybrid_act_order", action="store_true", help="Enable constrained activation reordering: "
-                    "elements can be reordered within each group "
-                    "and the groups themselves can also be reordered, "
-                    "but elements cannot move between groups.")
 
 # =============AWQ configs====================
 parser.add_argument("--use_auto_scale", action="store_true",
@@ -464,8 +458,6 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
             use_mse_search=args.woq_use_mse_search,
             percdamp=args.gptq_percdamp,
             act_order=args.gptq_actorder,
-            hybrid_order = args.hybrid_act_order,
-            fp8_aware = args.fp8_aware,
             block_size=args.gptq_block_size,
             static_groups=args.gptq_static_groups,
             use_double_quant=False,

neural_compressor/torch/algorithms/mixed_low_precision/modules.py

Lines changed: 4 additions & 9 deletions
@@ -9,30 +9,26 @@
 from ..weight_only.modules import HPUWeightOnlyLinear
 from neural_compressor.torch.utils import accelerator, logger
 
-cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
 
 class HPUMixedPrecisionLinear(HPUWeightOnlyLinear):
     """Weight and Activations quant (W4A8 gptq) Linear for HPU device."""
 
     def __init__(
-        self, in_features, out_features, bias,
+        self, in_features, out_features,
         **kwargs,
     ):
         """Init the HPUMixedPrecisionLinear object.
         """
-        super(HPUMixedPrecisionLinear, self).__init__(in_features, out_features, bias=bias)
+        super(HPUMixedPrecisionLinear, self).__init__(in_features, out_features)
 
     def forward(self, input):
         """The forward function of HPUMixedPrecisionLinear."""
         input_dtype = input.dtype
         output_shape = input.shape[:-1] + (self.out_features,)
         scales = self.scales
-        scale_bf16_to_fp8 = self.scale_bf16_to_fp8
         qweight = self.qweight
         zeros = self.qzeros
-        self.matmul_internal.scale_other = torch.nn.Parameter(scale_bf16_to_fp8)
-        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, torch.bfloat16)  # the uint4->fp8 is currently slower and with bugs. Jira ticket: https://jira.habana-labs.com/browse/SW-218009
-        weight = cast_to_fp8_fcn(weight, torch.float8_e4m3fn)
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales/self.matmul_internal.scale_other, zeros, torch.float8_e4m3fn)  # todo: div scales in init
         output = self.matmul_internal(input, weight)
         output = output.to(dtype=input_dtype).reshape(
             output_shape
@@ -42,8 +38,7 @@ def forward(self, input):
 
     @staticmethod
     def convert_from_weight_only(obj):
-        bias = obj.bias is not None
-        new_self = HPUMixedPrecisionLinear(obj.in_features, obj.out_features, bias)
+        new_self = HPUMixedPrecisionLinear(obj.in_features, obj.out_features)
         for attr, value in vars(obj).items():
             setattr(new_self, attr, value)
         new_self.matmul_internal.no_input_quant = True  # flag for 8bit input, which shouldn't be quantized in matmul
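
The restored forward() divides the int4 scales by the matmul's fp8 scale before dequantization instead of going through a bf16 intermediate and a separate fp8 cast. The short sketch below is plain PyTorch, not the HPU kernel: dequant_int4 is a hypothetical stand-in for torch.ops.hpu.convert_from_uint4 and the shapes are illustrative only. It shows why the folding is valid: scaling the per-group scales is algebraically the same as scaling the dequantized weight.

import torch

def dequant_int4(q, scales, zeros, dtype):
    # per-group affine dequant: (q - zp) * scale, emitted in the requested dtype
    return ((q.to(torch.float32) - zeros.to(torch.float32)) * scales).to(dtype)

q = torch.randint(0, 16, (4, 8), dtype=torch.int32)   # unpacked uint4 values (illustrative layout)
zeros = torch.full((4, 1), 8)                          # zero points
scales = torch.rand(4, 1) * 0.1                        # per-group int4 scales
scale_other = torch.tensor(0.05)                       # per-tensor fp8 scale used by the matmul

folded = dequant_int4(q, scales / scale_other, zeros, torch.float32)
unfolded = dequant_int4(q, scales, zeros, torch.float32) / scale_other
assert torch.allclose(folded, unfolded)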

neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 13 additions & 190 deletions
Large diffs are not rendered by default.

neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 7 additions & 38 deletions
@@ -62,9 +62,9 @@ def forward(self, X):
 class UnpackedWeightOnlyLinearParams(dict):
     """Contains all unpacked weight values."""
 
-    def __init__(self, unpack_weight, scales, scale_bf16_to_fp8, unpack_zp, **kwargs):
+    def __init__(self, unpack_weight, scales, unpack_zp, **kwargs):
         """Create dict."""
-        super().__init__(int_weight=unpack_weight, scales=scales, scale_bf16_to_fp8 = scale_bf16_to_fp8, zp=unpack_zp, **kwargs)
+        super().__init__(int_weight=unpack_weight, scales=scales, zp=unpack_zp, **kwargs)
 
     def to(self, device):
         """Change device for all values."""
@@ -209,14 +209,6 @@ def __init__(
                 dtype=self.float_type,
             ).to(device),
         )
-        self.register_buffer(
-            "scale_bf16_to_fp8",
-            torch.zeros(
-                1,
-                dtype=self.float_type,
-            ).to(device),
-        )
-        # scale_bf16_to_fp8 is only used in w4a8 measurement mode and currently supports only per-tensor scaling
         self.register_buffer(
             "qweight",
             torch.zeros(
@@ -242,13 +234,6 @@ def __init__(
                 dtype=self.float_type,
             ).to(device),
         )
-        self.register_buffer(
-            "scale_bf16_to_fp8",
-            torch.zeros(
-                1,
-                dtype=self.float_type,
-            ).to(device),
-        )
         if compression_dim == 1:
             self.register_buffer(
                 "qweight",
@@ -290,7 +275,7 @@ def __init__(
         else:
             self.g_idx = None
 
-    def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=None, **kwargs):
+    def pack(self, int_weight, scales, zp, bias=None, g_idx=None, **kwargs):
         """Pack int weight."""
         if self.use_optimum_format:
             self.scales = self.scales.T.contiguous()
@@ -316,8 +301,6 @@ def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=
             self.g_idx = self.g_idx.type(torch.int32).to(self.device)
         assert scales.shape == self.scales.shape, f"{scales.shape} != {self.scales.shape} Scale shape is mismatched."
         self.scales = scales.type(self.float_type).to(self.device)
-        if scale_bf16_to_fp8 is not None:
-            self.scale_bf16_to_fp8 = scale_bf16_to_fp8.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
             int_weight = int_weight.T.contiguous()
             self.qweight = self.qweight.T.contiguous()
@@ -349,7 +332,6 @@ def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=
     def unpack(self):
         """Unpack weight and zero point."""
         scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
-        scale_bf16_to_fp8 = self.scale_bf16_to_fp8
         qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight
 
         device = scales.device
@@ -385,15 +367,14 @@ def unpack(self):
             # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
             zp += 1
             zp = torch.where(zp > (2**self.bits - 1), 0, zp)
-        return UnpackedWeightOnlyLinearParams(weight, scales, scale_bf16_to_fp8, zp, g_idx=self.g_idx, bias=self.bias)
+        return UnpackedWeightOnlyLinearParams(weight, scales, zp, g_idx=self.g_idx, bias=self.bias)
 
     def recover(self):
         """Recover fp32 weight from packed weight."""
         logger.debug(f"Recovering {self} weight")
         unpack_params_dict = self.unpack()
         weight = unpack_params_dict.get("int_weight")
         scales = unpack_params_dict.get("scales")
-        scale_bf16_to_fp8 = unpack_params_dict.get("scale_bf16_to_fp8")
         zp = unpack_params_dict.get("zp")
 
         device = scales.device
@@ -687,13 +668,7 @@ def __init__(
                 dtype=self.float_type,
             ),
         )
-        self.register_buffer(
-            "scale_bf16_to_fp8",
-            torch.zeros(
-                1,
-                dtype=self.float_type,
-            ),
-        )
+
         if g_idx:
             self.register_buffer(
                 "g_idx",
@@ -712,22 +687,17 @@ def forward(self, input):
         input_dtype = input.dtype
         output_shape = input.shape[:-1] + (self.out_features,)
         scales = self.scales
-        scale_bf16_to_fp8 = self.scale_bf16_to_fp8  # Added by Tomer, per tensor scale.
         qweight = self.qweight
         zeros = self.qzeros
-        if scale_bf16_to_fp8 > 0:  # this means we are at w4a8 mode.
-            weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, torch.float8_e4m3fn)
-            weight = weight.to(input_dtype) * scale_bf16_to_fp8
-        else:
-            weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
         output = self.matmul_internal(input, weight)
         output = output.to(dtype=input_dtype).reshape(
             output_shape
         )  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocate a float32 output.
         output = output + self.bias if self.bias is not None else output
         return output
 
-    def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=None):
+    def pack(self, int_weight, scales, zp, bias=None, g_idx=None):
         """Pack weight and zero point."""
         logger.debug("Packing for HPU")
 
@@ -736,7 +706,6 @@ def pack(self, int_weight, scales, zp, scale_bf16_to_fp8=None, bias=None, g_idx=
         qweight = int_weight.T.contiguous()
 
         self.scales = scales.to(dtype=torch.bfloat16)
-        self.scale_bf16_to_fp8 = scale_bf16_to_fp8.to(dtype=torch.bfloat16)
 
         # weights and zp are on device from unpack, need to load to cpu for packing
         self.qweight = qweight.cpu()
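
A non-obvious context line in unpack() above is the zero-point repair: the optimum format stores zp - 1, so an original zero point of 0 underflows to -1 and is packed as 2**bits - 1; the `zp += 1` followed by torch.where() folds the resulting overflow back to 0. A tiny self-contained illustration, with made-up values, of that round trip:

import torch

bits = 4
original_zp = torch.tensor([0, 1, 8, 15])
stored = (original_zp - 1) & (2**bits - 1)    # what ends up packed in qzeros (unsigned wrap)
recovered = stored + 1                        # zp += 1
recovered = torch.where(recovered > (2**bits - 1), torch.zeros_like(recovered), recovered)
assert torch.equal(recovered, original_zp)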

neural_compressor/torch/algorithms/weight_only/save_load.py

Lines changed: 1 addition & 1 deletion
@@ -450,7 +450,7 @@ def _replace_woqlinear_modules(self, name, linear_module, module_quantization_co
 
     def _load_data_to_new_module(self, new_module, module_name):
         new_module_state_dict = {}
-        for key in [".qweight", ".scales", ".scale_bf16_to_fp8", ".qzeros", ".bias", ".g_idx"]:
+        for key in [".qweight", ".scales", ".qzeros", ".bias", ".g_idx"]:
             full_name = module_name + key
             if full_name in self.loaded_state_dict:
                 new_module_state_dict[key[1:]] = self.loaded_state_dict.pop(full_name)

neural_compressor/torch/algorithms/weight_only/utility.py

Lines changed: 1 addition & 6 deletions
@@ -479,7 +479,7 @@ def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_ful
     return best_clip_ratio
 
 
-def quant_weight_w_scale(weight, scale, scale_bf16_to_fp8, zp=None, group_size=-1, dtype="int", fp8_aware=False):
+def quant_weight_w_scale(weight, scale, zp=None, group_size=-1, dtype="int"):
     """Quant and dequant tensor with group size. It's an in-place function.
 
     Args:
@@ -494,11 +494,6 @@ def quant_weight_w_scale(weight, scale, scale_bf16_to_fp8, zp=None, group_size=-
     """
     device = weight.device
     scale = scale.to(device)
-    if fp8_aware:
-        weight.mul_(1 / scale_bf16_to_fp8)
-        weight = torch.clamp(weight, min=-torch.finfo(torch.float8_e4m3fnuz).max, max=torch.finfo(torch.float8_e4m3fnuz).max)
-        weight = weight.to(torch.float8_e4m3fn)
-        weight = weight.to(torch.float32)
     if zp is not None:
         zp = zp.to(device)
         # group_size = -1
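
For reference, the fp8_aware branch removed from quant_weight_w_scale() above amounted to a pre-step before the usual grouped int quantization: scale the weight into the fp8 range, clamp it, and round-trip it through float8_e4m3fn so the int quantizer sees fp8-rounded values. A standalone sketch of just that step, in plain PyTorch (requires a build with the float8 dtypes; tensor values are illustrative, and the sketch is out-of-place rather than in-place):

import torch

def fp8_pre_step(weight: torch.Tensor, scale_bf16_to_fp8: torch.Tensor) -> torch.Tensor:
    # scale into the fp8 range, clamp to the e4m3fnuz max, then round-trip through fp8
    w = weight * (1.0 / scale_bf16_to_fp8)
    fp8_max = torch.finfo(torch.float8_e4m3fnuz).max
    w = torch.clamp(w, min=-fp8_max, max=fp8_max)
    return w.to(torch.float8_e4m3fn).to(torch.float32)

w = torch.randn(8, 16)
w_fp8 = fp8_pre_step(w, torch.tensor(0.1))
print((w_fp8 - w / 0.1).abs().max())   # the fp8 rounding error introduced by the pre-step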

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 0 additions & 2 deletions
@@ -158,8 +158,6 @@ def gptq_entry(
         "double_quant_sym": quant_config.double_quant_use_sym,
         "double_quant_group_size": quant_config.double_quant_group_size,
         "act_order": quant_config.act_order,
-        "hybrid_order": quant_config.hybrid_order,
-        "fp8_aware": quant_config.fp8_aware,
         "percdamp": quant_config.percdamp,
         "block_size": quant_config.block_size,
         "static_groups": quant_config.static_groups,

neural_compressor/torch/quantization/config.py

Lines changed: 0 additions & 10 deletions
@@ -352,8 +352,6 @@ class GPTQConfig(TorchBaseConfig):
         "quant_lm_head",
         # gptq params
         "act_order",
-        "hybrid_order",
-        "fp8_aware",
         "percdamp",
         "block_size",
         "static_groups",
@@ -381,8 +379,6 @@ def __init__(
         quant_lm_head: bool = False,
         # gptq params
         act_order: bool = False,
-        hybrid_order: bool = False,
-        fp8_aware: bool = False,
         percdamp: float = 0.01,
         block_size: int = 2048,
         static_groups: bool = False,
@@ -410,10 +406,6 @@ def __init__(
             quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers. Default is False.
             act_order (bool): Whether to sort Hessian's diagonal values to rearrange channel-wise
                               quantization order. Default is False.
-            hybrid_order (bool): Enables activation re-ordering with no inference overhead.
-                                 Weights are re-ordered within their groups without cross-group mixing.
-            fp8_aware (bool): Whether to include an FP8 quantization step in the GPTQ process.
-                              This improves accuracy when using the W4A8 quantization scheme.
             percdamp (float): Percentage of Hessian's diagonal values' average, which will be added to
                               Hessian's diagonal to increase numerical stability. Default is 0.01.
             block_size (int): Execute GPTQ quantization per block, block shape = [C_out, block_size].
@@ -446,8 +438,6 @@ def __init__(
         self.double_quant_group_size = double_quant_group_size
         # gptq
         self.act_order = act_order
-        self.hybrid_order = hybrid_order
-        self.fp8_aware = fp8_aware
         self.percdamp = percdamp
         self.block_size = block_size
         self.static_groups = static_groups
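
After the revert, GPTQConfig no longer accepts hybrid_order or fp8_aware. A minimal, hedged construction example using only the GPTQ knobs visible in the diff above (the import path assumes the 3.x torch API; all other parameters keep their defaults):

from neural_compressor.torch.quantization import GPTQConfig

config = GPTQConfig(
    act_order=True,       # reorder channels by the Hessian's diagonal
    percdamp=0.01,        # damping added to the Hessian diagonal
    block_size=2048,      # run GPTQ per [C_out, block_size] block
    static_groups=False,
)
# hybrid_order=... and fp8_aware=... are no longer part of the signature.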

test/3x/torch/quantization/fp8_quant/test_fp8_jsons/test_pow2_w4a8_quant.json

Lines changed: 0 additions & 22 deletions
This file was deleted.
