
Commit 8bb9758

mengniwang95 (Mengni Wang) and co-author authored
[SW-233731] Support FP8 QDQ quant on CPU (#239)
Supported module types: Linear, Conv2d, EmbeddingBag (weight-only quant)
Validated scheme: per-tensor, sym, E4M3
Validated models: DLRM, ViT

---------

Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: Mengni Wang <[email protected]>
1 parent 9d13736 commit 8bb9758
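
For context, the end-to-end flow this commit enables on CPU might look like the sketch below. It is illustrative only: prepare, convert, and finalize_calibration are the INC 3.x PyTorch entry points, but the FP8Config keyword names used here (mode, fp8_config, use_qdq, dump_stats_path) simply mirror the JSON configuration keys handled in quant_config.py and are assumptions, not a verified signature. In practice the measurement and quantization passes usually run as two separate jobs sharing dump_stats_path.

# Illustrative sketch, not the library's documented example; FP8Config kwargs are assumptions.
import torch
from neural_compressor.torch.quantization import FP8Config, prepare, convert, finalize_calibration

class TinyDLRMBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(1000, 16, mode="sum")   # weight-only FP8 quant
        self.fc = torch.nn.Linear(16, 1)                          # FP8 QDQ quant

    def forward(self, ids, offsets):
        return self.fc(self.emb(ids, offsets))

model = TinyDLRMBlock()
ids, offsets = torch.randint(0, 1000, (8,)), torch.tensor([0, 4])

# 1) measurement pass on CPU
measure_cfg = FP8Config(mode="MEASURE", fp8_config="E4M3", dump_stats_path="./fp8_stats")
model = prepare(model, measure_cfg)
model(ids, offsets)
finalize_calibration(model)

# 2) quantization pass: per this commit, CPU only supports QDQ mode
quant_cfg = FP8Config(mode="QUANTIZE", fp8_config="E4M3", use_qdq=True, dump_stats_path="./fp8_stats")
model = convert(model, quant_cfg)
print(model(ids, offsets))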

File tree

13 files changed: +332 −30 lines

neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,9 @@
 from functools import lru_cache
 from ..utils.logger import logger
 from neural_compressor.torch.algorithms.fp8_quant.model_configs import ModuleConfig
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+
+cur_device = auto_detect_accelerator().current_device_name()
 
 UNMEASURED_MODELS = "UnmeasuredModels"
 
@@ -161,7 +164,7 @@ def load_scales(fname, target_format):
     return d
 
 
-def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, device="hpu"):
+def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, device=cur_device):
    scales_temp = {k: scales_obj[k].__dict__ for k in scales_obj}
    scales_temp = format_functions_rec((scales_file_format, torch.Tensor))(scales_temp)
    scales_temp = rec_fn(scales_temp, lambda x: x.to(dtype=hp_dtype, device=device))
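
The change above replaces a hard-coded device="hpu" default with the accelerator detected once at import time. A minimal sketch of that pattern follows; scales_to_tensors is a hypothetical stand-in for convert_scales_to_tensors_dict, while auto_detect_accelerator and current_device_name are the real helpers used in the diff.

# Minimal sketch of the device-default pattern: detect the accelerator once, then use it
# as the default target for scale tensors so the same code works on "cpu" and "hpu".
import torch
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

cur_device = auto_detect_accelerator().current_device_name()  # e.g. "cpu" or "hpu"

def scales_to_tensors(scales: dict, hp_dtype=torch.bfloat16, device=cur_device):
    # illustrative helper: convert plain-float scales to tensors on the detected device
    return {name: torch.tensor(value, dtype=hp_dtype, device=device) for name, value in scales.items()}

print(scales_to_tensors({"model.fc.weight": 0.125}))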

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 19 additions & 0 deletions
@@ -61,6 +61,7 @@ def create_mod_info_recursion(parent):
     "softmax": ModuleType(1, [], 1, True),
     "fused_sdpa": ModuleType(3, [], 2, True),
     "dynamic_moe": ModuleType(1, [], 1 + 8, True),
+    "embedding": ModuleType(1, ["weight"], 1, False),
 }
 
 
@@ -126,13 +127,31 @@ def _import_xpu_modules():
                                         "Matmul": ModuleInfo("matmul", PatchedMatmul),})
     PATCHED_MODULE_TYPES_TABLE["xpu"].update({"linear": _mod_types["linear"]})
 
+
+@functools.lru_cache(maxsize=None)
+def _import_cpu_modules():
+    from neural_compressor.torch.algorithms.fp8_quant.patched_module_base import (
+        PATCHED_MODULE_TABLE, PATCHED_MODULE_TYPES_TABLE
+    )
+    cur_accelerator = auto_detect_accelerator()
+    if not cur_accelerator.current_device_name().startswith("cpu"):
+        return
+    PATCHED_MODULE_TABLE["cpu"].update({"Linear": ModuleInfo("linear", PatchedLinear),
+                                        "Conv2d": ModuleInfo("linear", PatchedConv2d),
+                                        "EmbeddingBag": ModuleInfo("embedding", PatchedEmbeddingBag),
+                                        })
+    PATCHED_MODULE_TYPES_TABLE["cpu"].update({"linear": _mod_types["linear"], "embedding": _mod_types["embedding"]})
+
+
 @functools.lru_cache(maxsize=None)
 def _import_device_modules():
     cur_accelerator_type = auto_detect_accelerator().get_inc_accelerator_type()
     if cur_accelerator_type.value > INCAcceleratorType.GAUDI_MIN.value:
         _import_hpu_modules()
     elif cur_accelerator_type == INCAcceleratorType.XPU:
         _import_xpu_modules()
+    elif cur_accelerator_type == INCAcceleratorType.CPU:
+        _import_cpu_modules()
     else:
         logger.warning("No HPU or XPU devices were detected. No Patched Modules available.")
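
_import_cpu_modules follows the same registration pattern as the existing HPU/XPU importers: a per-device table maps original module class names to the patched replacement and its quantizer type. Below is a toy, self-contained version of that pattern for illustration only; ModuleInfo, PATCHED_MODULE_TABLE, and register_cpu_modules here are stand-ins, not the library's actual classes or tables.

# Toy re-implementation of the per-device patched-module registry pattern (illustrative).
from collections import namedtuple
import torch

ModuleInfo = namedtuple("ModuleInfo", ["type", "patched_module"])
PATCHED_MODULE_TABLE = {"cpu": {}, "hpu": {}, "xpu": {}}

class PatchedLinear: ...
class PatchedConv2d: ...
class PatchedEmbeddingBag: ...

def register_cpu_modules(device_name: str) -> None:
    if not device_name.startswith("cpu"):
        return
    PATCHED_MODULE_TABLE["cpu"].update({
        "Linear": ModuleInfo("linear", PatchedLinear),
        "Conv2d": ModuleInfo("linear", PatchedConv2d),          # Conv2d reuses the linear quantizer
        "EmbeddingBag": ModuleInfo("embedding", PatchedEmbeddingBag),
    })

register_cpu_modules("cpu")
# look up the patch class for a module encountered while walking the model
print(PATCHED_MODULE_TABLE["cpu"][type(torch.nn.Linear(4, 4)).__name__].patched_module)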

neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py

Lines changed: 15 additions & 12 deletions
@@ -47,13 +47,14 @@ def __init__(self, lp_dtype, hp_dtype="", *args, **kwargs):
         self.qdq_init()
 
     def qdq_init(self):
-        import habana_frameworks.torch.utils.experimental as htexp
-        if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2 and self.lp_dtype == torch.float8_e4m3fn:
-            self.quant_min = int(torch.finfo(torch.float8_e4m3fnuz).min)
-            self.quant_max = int(torch.finfo(torch.float8_e4m3fnuz).max)
-        else:
-            self.quant_min = int(torch.finfo(self.lp_dtype).min)
-            self.quant_max = int(torch.finfo(self.lp_dtype).max)
+        self.quant_min = int(torch.finfo(self.lp_dtype).min)
+        self.quant_max = int(torch.finfo(self.lp_dtype).max)
+
+        if cur_accelerator.current_device_name() == "hpu":
+            import habana_frameworks.torch.utils.experimental as htexp
+            if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2 and self.lp_dtype == torch.float8_e4m3fn:
+                self.quant_min = int(torch.finfo(torch.float8_e4m3fnuz).min)
+                self.quant_max = int(torch.finfo(torch.float8_e4m3fnuz).max)
 
         if self.scale_format == ScaleFormat.CONST:
             self.zero_point = nn.Parameter(torch.tensor(0.))
@@ -98,7 +99,8 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
                 else quantize_per_tensor_to_fp8
             )
 
-        self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
+        else:
+            self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
 
     def forward(self, x):
         return self.cast_to_op(x, self.scale_inv, False, False, self.lp_dtype)
@@ -156,8 +158,8 @@ def __init__(self, scale, lp_dtype, hp_dtype, *args, **kwargs):
                 if self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1
                 else dequantize_per_tensor_from_fp8
             )
-
-        self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
+        else:
+            self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
 
     def forward(self, x):
         return self.cast_from_op(x, self.scale, self.hp_dtype)
@@ -185,8 +187,9 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
         super(QuantDequant, self).__init__(lp_dtype, hp_dtype, *args, **kwargs)
         self.register_scale("scale_inv", scale_inv, self.scale_format)
         self.register_scale("scale", 1 / scale_inv, self.scale_format)
-        self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
-        self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
+        if not self.use_qdq:
+            self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format)
+            self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format)
 
     def forward(self, x, *args, **kwargs):
         y = self.cast_to_op(x, self.scale_inv, False, False, self.lp_dtype)
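
The reordered qdq_init now derives the quantization range from the low-precision dtype first and only applies the Gaudi2-specific float8_e4m3fnuz override when the detected device is HPU. The ranges involved can be checked directly on any PyTorch build with FP8 dtypes:

# FP8 ranges used by qdq_init: standard E4M3 vs. the fnuz variant used as a Gaudi2 override.
import torch

print(int(torch.finfo(torch.float8_e4m3fn).min), int(torch.finfo(torch.float8_e4m3fn).max))      # -448 448
print(int(torch.finfo(torch.float8_e4m3fnuz).min), int(torch.finfo(torch.float8_e4m3fnuz).max))  # -240 240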

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/quantized_func_wrapper_api.py

Lines changed: 4 additions & 1 deletion
@@ -37,9 +37,12 @@ def init_quantized_func_wrapper_factory():
     elif device_name == "xpu":
         from .xpu.xpu_quantized_func_wrapper import init_xpu_quantized_func_wrapper_factory
         init_xpu_quantized_func_wrapper_factory()
+    elif device_name == "cpu":
+        # only support QDQ now
+        pass
     else:
         raise ValueError("Unknown device type - {}".format(device_name))
 
 
 def clear_quantized_func_wrapper_factory():
-    QuantizedFuncWrapperFactory.clear()
\ No newline at end of file
+    QuantizedFuncWrapperFactory.clear()

neural_compressor/torch/algorithms/fp8_quant/_core/scale_handler.py

Lines changed: 6 additions & 3 deletions
@@ -16,6 +16,9 @@
 import types
 from .._quant_common.quant_config import ScaleFormat
 from .common import is_runtime_scale_patching
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+
+cur_device = auto_detect_accelerator().current_device_name()
 
 
 def add_scale_registry(patched_mod):
@@ -83,7 +86,7 @@ def get_scale_dtype(scale):
     raise Exception(f"Unexpected scale instance type: {type(scale).__name__}, expected Torch.tensor or float number")
 
 
-def get_param_scales_from_scalar(patched_mod, prefix, dtype=torch.bfloat16, device=torch.device('hpu')):
+def get_param_scales_from_scalar(patched_mod, prefix, dtype=torch.bfloat16, device=cur_device):
    """Get all scales in param_list, used for saving scalar scales"""
    scale_dict = {}
    for name in patched_mod.scale_members:
@@ -95,7 +98,7 @@ def get_param_scales_from_scalar(patched_mod, prefix, dtype=torch.bfloat16, devi
     return scale_dict
 
 
-def get_param_scales_from_list(patched_mod, prefix, dtype=torch.bfloat16, device=torch.device('hpu')):
+def get_param_scales_from_list(patched_mod, prefix, dtype=torch.bfloat16, device=cur_device):
    """Get all scales in param_list, used for saving scalar scales"""
    scale_dict = {}
    for name in patched_mod.scale_members:
@@ -141,7 +144,7 @@ def set_param_scales_into_list(patched_mod, state_dict):
 def get_state_dict(patched_mod, *args, destination=None, prefix='', keep_vars=False):
     """replace torch.nn.Module.state_dict"""
     cur_state_dict = torch.nn.Module.state_dict(patched_mod, *args, destination=destination, prefix=prefix, keep_vars=keep_vars)
-    device = torch.device('hpu')
+    device = cur_device
     dtype = patched_mod.hp_dtype
     if patched_mod.scale_format == ScaleFormat.SCALAR:
         scale_dict = get_param_scales_from_scalar(patched_mod, prefix, dtype=dtype, device=device)
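
Note that the replaced defaults now pass a device-name string rather than a torch.device object; torch accepts either form, which is what lets cur_device substitute for torch.device('hpu') unchanged. A quick sanity sketch:

# torch accepts the plain device-name string returned by current_device_name().
import torch

cur_device = "cpu"  # stand-in for auto_detect_accelerator().current_device_name()
scale = torch.tensor(0.125, dtype=torch.bfloat16, device=cur_device)
print(scale.device == torch.device(cur_device))  # True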

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 69 additions & 2 deletions
@@ -20,6 +20,8 @@
 from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput
 from ...utils.logger import logger
 from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+cur_device = auto_detect_accelerator().current_device_name()
 
 
 class BaseOpQuantizer:
@@ -114,7 +116,7 @@ def get_scales_module_config(self):
         rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
         if self.scales_method_factory.scale_method_config_map[QuantTensorName.WEIGHT_IN_CH].scale_value_type != ScaleValueType.DUMMY_SCALES:
             # Calculating weight in hpu to support scale calculation CGUID torch.ops.hpu.calculate_scale_for_cast
-            rescaled_weight = rescaled_weight.to("hpu")
+            rescaled_weight = rescaled_weight.to(cur_device)
         if rescaled_weight is not None:
             rescaled_weight = dequant_original_fp8_weight_if_needed(self.mod, rescaled_weight)
         if self.weight_ich_scale_calc is not None:
@@ -420,13 +422,78 @@ def scales_module_config_to_q_and_dq(self, module):
 
 
 
+class EmbeddingOpQuantizer(BaseOpQuantizer):
+
+    def __init__(self, config, mod, measurement, params, module_type):
+        super().__init__(config, mod, measurement, params, module_type)
+        self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT))
+        self.weight_och_scale_calc = self.scales_method_factory.get_scale_method(QuantTensorName.WEIGHT_OUT_CH)
+        self.weight_ich_scale_calc = self.scales_method_factory.get_scale_method(QuantTensorName.WEIGHT_IN_CH)
+        self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))
+
+    def get_scales_module_config(self):
+        weight = self.mod.weight if hasattr(self.mod, 'weight') else None
+        input_scales = self.calc_input_scales(num_of_inputs=1)
+
+        if self.weight_ich_scale_calc is not None:
+            weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
+            weight = torch.div(weight, weight_scales_in_ch.reshape([1, -1]))
+        weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(weight, QuantTensorType.CONST)
+
+        params_config = (
+            {"weight": weights_scales_out_ch}
+            if (self.weight_ich_scale_calc is None)
+            else {"weight": {0: weights_scales_out_ch, 1: weight_scales_in_ch}}
+        )
+        return ModuleConfig(
+            (),
+            (),
+            params_config,
+        )
+
+    def init_weight_config(self, scales, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant):
+        if use_qdq:
+            # to ensure the weights to be loaded to the device in fp8
+            weight_config = [
+                QuantInput(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq),
+                DequantOutput(scales, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq),
+            ]
+        else:
+            raise ValueError("For FP8 quantization, {} only supports QDQ mode now!".format(self.mod.__class__.__name__))
+        return weight_config
+
+    def init_weights_from_module(self, params_config):
+        if isinstance(params_config, dict):
+            self.weight_och_scale_calc.scale = params_config[0]
+            self.weight_ich_scale_calc.scale = params_config[1]
+        else:
+            self.weight_och_scale_calc.scale = params_config
+
+    def scales_module_config_to_q_and_dq(self, module):
+        self.init_scales_from_module_config(module)
+        self.init_weights_from_module(module.params["weight"])
+        scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = self.get_module_configuration()
+        weight_config = self.init_weight_config(
+            self.weight_och_scale_calc.scale,
+            self.weight_och_scale_calc.calc_invert_scales(),
+            lp_dtype,
+            hp_dtype,
+            scale_format,
+            use_qdq,
+            fake_quant,
+        )
+        params_config = {"weight": weight_config}
+        return ModuleConfig([], [], params_config)
+
+
 ops_quantizer_map = {"linear": LinearOpQuantizer,
                      "matmul": MatmulOpQuantizer,
                      "fused_sdpa": FsdpaOpQuantizer,
                      "softmax": SoftmaxOpQuantizer,
                      "kv_cache": KVCacheOpQuantizer,
                      "dynamic_moe": DynamicMoeOpQuantizer,
-                     "row_parallel_linear": RowParallelLinearOpQuantizer
+                     "row_parallel_linear": RowParallelLinearOpQuantizer,
+                     "embedding": EmbeddingOpQuantizer,
                      }
 
 def get_op_quantizer(config, mod, measurement, params, module_type):
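
Conceptually, EmbeddingOpQuantizer emits a QuantInput/DequantOutput pair for the weight only. Under the validated per-tensor, symmetric, E4M3 scheme that boils down to the numerics sketched below; qdq_weight_per_tensor_e4m3 is a hypothetical helper for illustration, not one of the scale-method classes used above.

# Sketch of per-tensor symmetric E4M3 weight QDQ, assuming absolute-max scaling.
import torch

def qdq_weight_per_tensor_e4m3(weight: torch.Tensor, hp_dtype=torch.bfloat16):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max                # 448.0
    scale = weight.abs().max().float() / fp8_max                  # per-tensor symmetric scale
    qweight = (weight.float() / scale).to(torch.float8_e4m3fn)    # Q: cast to FP8 (QuantInput role)
    return (qweight.float() * scale).to(hp_dtype)                 # DQ: back to hp dtype (DequantOutput role)

w = torch.randn(1000, 16)
print(qdq_weight_per_tensor_e4m3(w).dtype)  # torch.bfloat16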

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def calc_scales(self, tensor, tensor_type, **additional_kwargs):
 # used when running with dummy measurement (prepare_model_with_dummy_measurement)
 class DummyScales(ScalesMethod):
     def calc_scales(self, tensor, tensor_type, **additional_kwargs):
-        self.scale = torch.tensor(1.0).to("hpu")
+        self.scale = torch.tensor(1.0).to(self.device)
         return self.scale
 
 
neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 48 additions & 0 deletions
@@ -167,6 +167,7 @@ def init_linear(self, mod_extra_config):
             self.quant_input = self._mod_extra_config.inputs[0]
             self.dequant_output = self._mod_extra_config.outputs[0]
 
+
         # When offloading weights to disk using device_map, the module forward is overridden.
         # __dict__.update call again overrides the PatchedLinear forward with the forward that device_map planted.
         # So need to set PatchedLinear forward to be the right forward.
@@ -585,6 +586,53 @@ def forward_measure(self, input):
         return output
 
 
+class PatchedEmbeddingBag(PatchedModuleBase):
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
+            if self.use_qdq:
+                self.dequant_weights = self._mod_extra_config.params["weight"][1]
+                if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)):
+                    self.register_scale("scale_weight", mod_extra_config.scale.params["weight"], self.scale_format)
+                elif isinstance(mod_extra_config.scale.params["weight"], dict):
+                    # PCQ weight is calculated with actual weight [0] and ones [1]
+                    # only ScaleFormat.CONST is supported for per-channel scale now.
+                    self.register_scale("scale_weight", mod_extra_config.scale.params["weight"][0], ScaleFormat.CONST)
+            else:
+                raise ValueError("EmbeddingBag is only supported QDQ mode now!")
+
+    def forward_qdq(self, input, offsets, *args, **kwargs):
+        qweight = self.dequant_weights(self.weight, )
+
+        return torch.nn.functional.embedding_bag(
+            input=input,
+            offsets=offsets,
+            weight=qweight,
+            max_norm=self.max_norm,
+            norm_type=self.norm_type,
+            scale_grad_by_freq=self.scale_grad_by_freq,
+            mode=self.mode,
+            sparse=self.sparse,
+            include_last_offset=self.include_last_offset,
+            padding_idx=self.padding_idx,
+            *args,
+            **kwargs,
+        )
+
+    def forward_measure(self, input, *args, **kwargs):
+        measure_input((input,), observer=self._mod_extra_config.inputs)
+        output = self.orig_mod(input, *args, **kwargs)
+        measure_output((output,), self._mod_extra_config.outputs)
+        return output
+
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self, "scale_weight"),
+        )
+
+
 # patched vllm FusedMoE module removing the bf16 weights of all experts
 # measure and quant of the weights is done per expert using PatchedMoeMatmul
 # therefore it is configured: ModuleInfo.should_measure_and_quant = False
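
forward_qdq dequantizes the stored FP8 weight and forwards it to torch.nn.functional.embedding_bag together with the module's original attributes. A standalone toy run of that call path is sketched below; the scalar scale and manual multiply stand in for the DequantOutput op, and the tensors are made up.

# Toy illustration of the PatchedEmbeddingBag.forward_qdq data flow (not the library's code).
import torch
import torch.nn.functional as F

fp8_weight = torch.randn(1000, 16).to(torch.float8_e4m3fn)  # weight kept in FP8
scale = torch.tensor(0.02)                                   # assumed per-tensor scale

qweight = fp8_weight.float() * scale                         # DequantOutput equivalent
ids = torch.randint(0, 1000, (8,))
offsets = torch.tensor([0, 4])
out = F.embedding_bag(input=ids, weight=qweight, offsets=offsets, mode="sum")
print(out.shape)  # torch.Size([2, 16])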

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 11 additions & 1 deletion
@@ -220,6 +220,9 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         validate_and_populate_scale_method(scale_method_config)
 
 
+        if auto_detect_accelerator().current_device_name() == "cpu" and not measured_global_config["use_qdq"]:
+            raise ValueError("For FP8 quantization, only QDQ mode is supported on CPU device.")
+
         # If seperate_measure_files is True (default value), then it is assumed that there are multiple distinct measure and scale files
         # and they are stored in / loaded from paths with the correct index as a suffix. Else, only one is searched for.
         measured_global_config["local_rank"] = (
@@ -230,6 +233,11 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         logger.debug("setting device for scales config")
         Fp8cfg.set_gaudi_device_for_scales(custom_config, measured_global_config, scale_method_config)
 
+        if auto_detect_accelerator().current_device_name() == "cpu" and \
+            check_scale_method_fields(scale_method_config, granularity_weight=ScaleGranularity.PCS, reducer=any):
+            # for PCQ, there is some issue in dequantize_per_channel op on CPU device
+            raise ValueError("Don't support FP8 PCQ (Per Channel Quantization) on CPU device now")
+
         if measured_global_config["scale_format"] == ScaleFormat.SCALAR:
             if check_scale_method_fields(scale_method_config, granularity_weight=ScaleGranularity.PCS, reducer=any) or \
                 check_scale_method_fields(scale_method_config, granularity_activation=ScaleGranularity.PCS, reducer=any):
@@ -242,6 +250,8 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         dynamic_quantization = measured_global_config["dynamic_quantization"]
         # TODO [SW-217814]: get dynamic methods in a better way, or support file handling in dynamic mode
         if dynamic_quantization:
+            if auto_detect_accelerator().current_device_name() == "cpu":
+                raise ValueError("Currently CPU device doesn't support dynamic quantization")
             logger.info(f"NOTE: Using dynamic scale method, only supported ops will be quantized.")
             if measured_global_config["scale_format"] == ScaleFormat.SCALAR:
                 measured_global_config["scale_format"] = ScaleFormat.CONST
@@ -364,4 +374,4 @@ def _read_config_from_file(config_path: str) -> Mapping[str, str]:
     except JSONDecodeError as e:
         config_json.close()
         raise Exception(f"Got exception: {e}. QUANT PACKAGE: Can't load {config_path}!")
-    return config
\ No newline at end of file
+    return config
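
Taken together, the new guards in Fp8cfg.parse restrict the CPU path to QDQ, reject per-channel (PCS/PCQ) weight scales, and reject dynamic quantization. A reduced sketch of that logic follows; check_cpu_constraints and the flat dict keys below are stand-ins for the real parse flow and check_scale_method_fields.

# Reduced, illustrative version of the CPU-specific configuration guards.
def check_cpu_constraints(cfg: dict, device_name: str = "cpu") -> None:
    if device_name != "cpu":
        return
    if not cfg.get("use_qdq", False):
        raise ValueError("For FP8 quantization, only QDQ mode is supported on CPU device.")
    if cfg.get("weight_granularity") == "PCS":   # stand-in for check_scale_method_fields(...)
        raise ValueError("Don't support FP8 PCQ (Per Channel Quantization) on CPU device now")
    if cfg.get("dynamic_quantization", False):
        raise ValueError("Currently CPU device doesn't support dynamic quantization")

check_cpu_constraints({"use_qdq": True, "dynamic_quantization": False})  # passes silently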
