
Commit ab265ef

Asaf Karnieli and tgafni authored

[ALGO-808] add support for int4 weights + fp8 activations - phase 1 (#43)

* [ALGO-808] add support for int4 weights + fp8 activations - phase 1
* Add code for quantizing only single input to PatchedMatmul
* w4a8 new kernel

Co-authored-by: Tomer Gafni <[email protected]>
1 parent 68faf23 commit ab265ef

File tree: 15 files changed (+306, -3 lines)

neural_compressor/common/utils/constants.py

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@
 TEQ = "teq" # pragma: no cover
 AUTOROUND = "autoround"
 FP8_QUANT = "fp8_quant"
+HYBRID_GPTQ = "hybrid_gptq"
 MX_QUANT = "mx_quant"
 MIXED_PRECISION = "mixed_precision"

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 5 additions & 0 deletions

@@ -197,6 +197,11 @@ def scales_module_config_to_q_and_dq(self, module):
             use_qdq,
             fake_quant,
         )
+
+        # 4bit->8bit inputs, no need to quant
+        if hasattr(self.mod, "no_input_quant"):
+            input_config[1] = QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)
+
         # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here
         output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)]
         return ModuleConfig(input_config, output_config)
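The new branch is the activation-only half of the W4A8 scheme: when a matmul module carries the no_input_quant flag (set by HPUMixedPrecisionLinear.convert_from_weight_only later in this commit), its second input gets a pass-through QuantDequantNone instead of an fp8 quant op. A minimal illustrative sketch of the idea; apart from no_input_quant and the QuantDequantNone role, all names below are hypothetical:

# Illustrative sketch only, not the library code.
class PassThrough:
    """Stand-in for QuantDequantNone: hands the tensor back unchanged."""
    def __call__(self, x):
        return x

def build_matmul_input_config(matmul_mod, fp8_quant_op):
    # Default: both matmul inputs are quantized to fp8.
    input_config = [fp8_quant_op, fp8_quant_op]
    # W4A8 case: the weight input already arrives as fp8 (dequantized from uint4), so skip it.
    if hasattr(matmul_mod, "no_input_quant"):
        input_config[1] = PassThrough()
    return input_config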

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 3 additions & 0 deletions

@@ -99,6 +99,7 @@ class DeviceForScalesType(Enum):
     "ignore_modules_wo_measures": TrueFalse,
     "use_qdq": TrueFalse,
     "fake_quant": TrueFalse,
+    "int4_weights": TrueFalse,
     "scale_format": ScaleFormat,
     "device_for_scales": DeviceForScalesType,
     "measure_on_hpu": TrueFalse,

@@ -111,6 +112,7 @@ class DeviceForScalesType(Enum):
     "ignore_modules_wo_measures",
     "recalc_scales",
     "fake_quant",
+    "int4_weights",
     "use_qdq",
     "device_for_scales",
     "measure_on_hpu",

@@ -189,6 +191,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         }, # types and names to be quantized. Allowlist by names is not yet implemented
         "mode": QuantMode.QUANTIZE, # Quantize or Measure
         "fake_quant": False, # Fake or Real Quant, fake_quant only works for linear(PatchedLinear) and matmul(PatchedMatmul), usually used for training.
+        "int4_weights": False,
         "use_qdq": False, # QDQ or Real Quant, QDQ works for operators in helper_modules.py, usually used for inference.
         "scale_method": ScaleMethod.MAXABS_HW, # Method to quantize with
         "scale_params": {}, # scaling parameters that are different then the default ones

neural_compressor/torch/algorithms/mixed_low_precision/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from neural_compressor.torch.algorithms.mixed_low_precision.quantizer import HybridGPTQQuantizer
New file (quantization config JSON; path not shown in this view)

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+{
+    "method": "HOOKS",
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "maxabs_hw",
+    "dump_stats_path": "./calib_output/measure",
+    "int4_weights": "True"
+}
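For orientation, a hedged usage sketch: assuming HybridGPTQConfig follows the FP8Config convention of being loaded from such a JSON file and handed to the standard convert API (from_json_file, the file name, and the surrounding call are assumptions, not shown in this diff):

# Hedged sketch; loader and entry names are assumptions based on the FP8Config convention.
from neural_compressor.torch.quantization import HybridGPTQConfig, convert

config = HybridGPTQConfig.from_json_file("quant_config_w4a8.json")  # assumed loader and path
model = convert(model, config)  # model: an HPU model whose GPTQ int4 weights are already loaded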
neural_compressor/torch/algorithms/mixed_low_precision/modules.py (new file)

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+import math
+from abc import abstractmethod
+
+import numpy as np
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from ..weight_only.modules import HPUWeightOnlyLinear
+from neural_compressor.torch.utils import accelerator, logger
+
+
+class HPUMixedPrecisionLinear(HPUWeightOnlyLinear):
+    """Weight and Activations quant (W4A8 gptq) Linear for HPU device."""
+
+    def __init__(
+        self, in_features, out_features,
+        **kwargs,
+    ):
+        """Init the HPUMixedPrecisionLinear object.
+        """
+        super(HPUMixedPrecisionLinear, self).__init__(in_features, out_features)
+
+    def forward(self, input):
+        """The forward function of HPUMixedPrecisionLinear."""
+        input_dtype = input.dtype
+        output_shape = input.shape[:-1] + (self.out_features,)
+        scales = self.scales
+        qweight = self.qweight
+        zeros = self.qzeros
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales/self.matmul_internal.scale_other, zeros, torch.float8_e4m3fn) # todo: div scales in init
+        output = self.matmul_internal(input, weight)
+        output = output.to(dtype=input_dtype).reshape(
+            output_shape
+        ) # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocate a float32 output.
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+    @staticmethod
+    def convert_from_weight_only(obj):
+        new_self = HPUMixedPrecisionLinear(obj.in_features, obj.out_features)
+        for attr, value in vars(obj).items():
+            setattr(new_self, attr, value)
+        new_self.matmul_internal.no_input_quant = True # flag for 8bit input, which shouldn't be quantized in matmul
+        return new_self
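A short sketch of how convert_from_weight_only is intended to be used; woq_linear stands for an already-initialized int4 HPUWeightOnlyLinear and is a hypothetical name here:

# Sketch: promote an existing int4 weight-only linear to the W4A8 mixed-precision module.
from neural_compressor.torch.algorithms.mixed_low_precision.modules import HPUMixedPrecisionLinear

w4a8_linear = HPUMixedPrecisionLinear.convert_from_weight_only(woq_linear)
# qweight, scales, qzeros, bias and matmul_internal are copied over; matmul_internal now
# carries no_input_quant, so the later FP8 pass quantizes only the activation input while
# the weight is produced directly as float8_e4m3fn by convert_from_uint4.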
neural_compressor/torch/algorithms/mixed_low_precision/quantizer.py (new file)

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.torch.algorithms import Quantizer
+from neural_compressor.torch.algorithms.mixed_low_precision.modules import HPUMixedPrecisionLinear
+from neural_compressor.torch.algorithms.weight_only.modules import HPUWeightOnlyLinear
+
+class HybridGPTQQuantizer(Quantizer):
+    def __init__(self, quant_config):
+        super().__init__(quant_config)
+        if isinstance(quant_config, dict):
+            json_file = [cfg.json_file for cfg in quant_config.values()]
+            assert len(json_file) > 0, "Cannot get json file from config."
+            self.quant_config = json_file[0]
+
+    def prepare(self, model):
+        return model
+
+    def convert(self, model):
+        _convert(model)
+        return model
+
+def set_module(model, op_name, new_module):
+    """Set module with a given op name.
+
+    Args:
+        model (object): the input model.
+        op_name (str): name of op.
+        new_module (object): the input model.
+
+    Returns:
+        module (object).
+    """
+    module = model
+    name_list = op_name.split(".")
+    for name in name_list[:-1]:
+        if hasattr(module, name):
+            module = getattr(module, name)
+        else:
+            module = module
+    setattr(module, name_list[-1], new_module)
+
+def _convert(model):
+    for name, module in model.named_modules():
+        # replace `HPUWeightOnlyLinear`s forward func
+        if isinstance(module, HPUWeightOnlyLinear):
+            module = HPUMixedPrecisionLinear.convert_from_weight_only(module)
+            set_module(model, name, module)
+
+    return model
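In practice this runs through the HYBRID_GPTQ algorithm entry added below rather than by hand, but a direct sketch of what convert does to a model:

# Sketch: every HPUWeightOnlyLinear submodule is replaced in place by an
# HPUMixedPrecisionLinear that reuses the same int4 tensors.
from neural_compressor.torch.algorithms.mixed_low_precision import HybridGPTQQuantizer

quantizer = HybridGPTQQuantizer(quant_config)  # quant_config: the configs mapping built by the framework
model = quantizer.convert(model)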

neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 10 additions & 1 deletion

@@ -29,6 +29,14 @@
 
 from .utility import quant_tensor
 
+class Matmul(torch.nn.Module):
+
+    def __init__(self, ) -> None:
+        super().__init__()
+
+    def forward(self, X, Y):
+        """Forward function."""
+        return torch.matmul(X, Y)
 
 class QDQLayer(torch.nn.Module):
     """Quantized and dequantized layer."""

@@ -672,6 +680,7 @@ def __init__(
         self.half_indim = self.in_features // 2
 
         self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
+        self.matmul_internal = Matmul()
 
     def forward(self, input):
         """The forward function of HPUWeighOnlyLinear."""

@@ -681,7 +690,7 @@
         qweight = self.qweight
         zeros = self.qzeros
         weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
-        output = torch.matmul(input, weight)
+        output = self.matmul_internal(input, weight)
         output = output.to(dtype=input_dtype).reshape(
             output_shape
         ) # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocate a float32 output.
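The wrapper is functionally identical to torch.matmul; exposing the product as a named submodule (matmul_internal) is what lets the FP8 pass later patch it as a PatchedMatmul and attach per-op scales and flags such as no_input_quant to it. A quick illustrative check:

# Illustrative only: Matmul is a thin nn.Module wrapper around torch.matmul.
import torch
from neural_compressor.torch.algorithms.weight_only.modules import Matmul

m = Matmul()
x, w = torch.randn(2, 4), torch.randn(4, 8)
assert torch.equal(m(x, w), torch.matmul(x, w))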

neural_compressor/torch/quantization/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@
     get_woq_tuning_config,
     DynamicQuantConfig,
     get_default_dynamic_config,
+    HybridGPTQConfig
 )
 
 from neural_compressor.torch.quantization.autotune import (

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 22 additions & 0 deletions

@@ -23,6 +23,7 @@
     AUTOROUND,
     AWQ,
     FP8_QUANT,
+    HYBRID_GPTQ,
     GPTQ,
     HQQ,
     MIXED_PRECISION,

@@ -45,6 +46,7 @@
     SmoothQuantConfig,
     StaticQuantConfig,
     TEQConfig,
+    HybridGPTQConfig
 )
 from neural_compressor.torch.utils import (
     dump_model_op_stats,

@@ -721,6 +723,26 @@ def fp8_entry(
     postprocess_model(model, mode, quantizer)
     return model
 
+###################### Habana MixedPrecision Algo Entry ##################################
+@register_algo(HYBRID_GPTQ)
+@torch.no_grad()
+def hybrid_gptq_entry(
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str], FP8Config],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
+) -> torch.nn.Module:
+    """The main entry to apply w4a8 gptq quantization."""
+
+    from neural_compressor.torch.algorithms.mixed_low_precision import HybridGPTQQuantizer
+
+    quantizer = get_quantizer(model, quantizer_cls=HybridGPTQQuantizer, quant_config=configs_mapping)
+    model = quantizer.execute(model, mode=mode)
+
+    fp8_entry(model, configs_mapping, mode, *args, **kwargs)
+    return model
+
 
 ###################### MX Quant Algo Entry ##################################
 @register_algo(name=MX_QUANT)
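To summarize the control flow the new entry adds, a minimal sketch (not the library code) of the two stages it chains; fp8_entry here refers to the existing FP8 entry defined in this same file:

# Sketch of the W4A8 flow behind hybrid_gptq_entry.
from neural_compressor.torch.algorithms.mixed_low_precision import HybridGPTQQuantizer

def w4a8_flow_sketch(model, configs_mapping, mode):
    # Stage 1: swap each int4 HPUWeightOnlyLinear for an HPUMixedPrecisionLinear
    # (weights stay int4, the internal matmul is flagged no_input_quant).
    quantizer = HybridGPTQQuantizer(configs_mapping)
    model = quantizer.execute(model, mode=mode)
    # Stage 2: the regular FP8 pass patches the internal matmul and quantizes only the
    # activation input to fp8, yielding int4 weights + fp8 activations.
    return fp8_entry(model, configs_mapping, mode)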
