
Commit 01aa482

yiliu30 authored and Yi4Liu committed
[SW-221594] Re-quantize the Official DeepSeek FP8 Model (#187)
Building on the vLLM WoQ path, this PR adds support for re-quantizing FP8 weights with per-tensor or per-channel scaling.

Co-authored-by: Yi Liu <[email protected]>
1 parent 6bf45af commit 01aa482

File tree

8 files changed: +127 −23 lines changed
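To illustrate what "re-quantizing FP8 weights with per-tensor scaling" means in practice, here is a minimal, self-contained sketch in plain PyTorch. It is not the code added by this commit; the helper names and the amax-based scale choice are assumptions for illustration only.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def dequant_per_tensor(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Recover a high-precision (BF16) view of the checkpoint's FP8 weight.
    return (w_fp8.to(torch.float32) * scale).to(torch.bfloat16)

def requant_per_tensor(w_hp: torch.Tensor):
    # Derive a fresh per-tensor scale from the weight's amax and re-quantize to FP8.
    new_scale = w_hp.abs().amax().float() / FP8_MAX
    w_fp8 = (w_hp.float() / new_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return w_fp8, new_scale

# Toy stand-in for an "official" FP8 checkpoint weight and its original per-tensor scale.
orig_scale = torch.tensor(0.01)
w_fp8 = (torch.randn(16, 32) / orig_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

w_bf16 = dequant_per_tensor(w_fp8, orig_scale)   # measurement/calibration operates on BF16
w_new, new_scale = requant_per_tensor(w_bf16)    # quantization emits a new FP8 weight + scale
```

The diffs below wire this dequantize-before-measure/quantize step into the existing measurement and quantization paths.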

neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 7 additions & 1 deletion

@@ -20,7 +20,6 @@
 import torch
 from enum import Enum, auto
 from functools import lru_cache
-
 from .._quant_common.quant_config import get_hqt_config
 from ..utils.logger import logger
 from neural_compressor.torch.algorithms.fp8_quant.model_configs import (
@@ -38,6 +37,13 @@

 UNMEASURED_MODELS = "UnmeasuredModels"

+def dequant_original_fp8_weight_if_needed(mod: torch.nn.Module, param: torch.Tensor) -> torch.Tensor:
+    if param.dtype in [torch.float8_e4m3fn]:
+        if hasattr(mod, "get_dequant_weights_func"):
+            dequant_weights_func = mod.get_dequant_weights_func()
+            if dequant_weights_func is not None:
+                param = dequant_weights_func(mod)
+    return param

 class QuantTensorType(Enum):
     MEASUREMENTS = auto()
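A quick sketch of the contract this helper relies on: the original module exposes `get_dequant_weights_func()`, which returns a callable mapping the module to a high-precision weight. The toy module below is invented for illustration; only the helper and its import path come from this commit.

```python
import torch
from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed

class ToyFP8Linear(torch.nn.Module):
    """Toy stand-in for a layer that stores an FP8 weight plus a per-tensor scale."""
    def __init__(self):
        super().__init__()
        self.scale = torch.tensor(0.02)
        self.weight = (torch.randn(8, 8) / self.scale).to(torch.float8_e4m3fn)

    def get_dequant_weights_func(self):
        # Callable that reconstructs a BF16 weight from the FP8 storage.
        return lambda mod: (mod.weight.to(torch.float32) * mod.scale).to(torch.bfloat16)

mod = ToyFP8Linear()
w = dequant_original_fp8_weight_if_needed(mod, mod.weight)
assert w.dtype == torch.bfloat16          # FP8 weights are dequantized
w2 = dequant_original_fp8_weight_if_needed(mod, torch.randn(8, 8))
assert w2.dtype == torch.float32          # non-FP8 params pass through unchanged
```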

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 2 additions & 0 deletions

@@ -31,6 +31,7 @@
     OBSERVER_PARAMS,
     IMOD_DICT,
 )
+from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed
 cur_accelerator = auto_detect_accelerator()


@@ -162,6 +163,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
         if pmod._mod_extra_config:
             for param_name in pmod._mod_extra_config.params:
                 param = getattr(pmod, param_name)
+                param = dequant_original_fp8_weight_if_needed(pmod.orig_mod, param)
                 if config["measure_on_hpu"]:
                     param = param.to(cur_accelerator.name())
                 pmod._mod_extra_config.params[param_name].measure(param)

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 7 additions & 2 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 import importlib.util
-
+import os
 from ..model_configs import ModuleInfo, ModuleType
 from .._quant_common.helper_modules import *

@@ -51,14 +51,17 @@ def create_mod_info_recursion(parent):

     create_mod_info_recursion(model)

+
+INC_DYNAMIC_MOE_EXPERTS = int(os.environ.get("INC_DYNAMIC_MOE_EXPERTS", "8"))
+
 _mod_types = {
     "linear": ModuleType(1, ["weight"], 1, False),
     "row_parallel_linear": ModuleType(1, ["weight"], 2, True),
     "matmul": ModuleType(2, [], 1, False),
     "kv_cache": ModuleType(1, [], 1, False),
     "softmax": ModuleType(1, [], 1, True),
     "fused_sdpa": ModuleType(3, [], 2, True),
-    "dynamic_moe": ModuleType(1, [], 9, True),
+    "dynamic_moe": ModuleType(1, [], 1 + INC_DYNAMIC_MOE_EXPERTS, True),
 }


@@ -79,10 +82,12 @@ def create_mod_info_recursion(parent):
     "Softmax": ModuleInfo("softmax", PatchedSoftmax),
     "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
     "MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
+    "MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
     "ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
     "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
+    "VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
 }
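The number of measured outputs for the `dynamic_moe` module type is now driven by the `INC_DYNAMIC_MOE_EXPERTS` environment variable; the default of 8 reproduces the previous hard-coded `1 + 8 = 9`. A hedged example of overriding it for a model sliced into 16 experts per rank (the variable must be set before the module is imported, since it is read at import time):

```python
import os

# Set before importing the patching module; in practice this is usually exported in the shell.
os.environ["INC_DYNAMIC_MOE_EXPERTS"] = "16"

from neural_compressor.torch.algorithms.fp8_quant._core import patching_common

# The "dynamic_moe" entry now describes 1 + 16 = 17 measured outputs per op.
print(patching_common._mod_types["dynamic_moe"])
```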

neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,7 @@
 from .measure import load_measurements
 from .scale import scale_method_mapping, load_layer_scales, prepare_layer_scales
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed


 cur_accelerator = auto_detect_accelerator()
@@ -74,9 +75,12 @@ def quantize_params(mod, mod_extra_config):
         param = getattr(mod, param_name)
         if param.dtype == torch.float16:
             param = param.to(torch.bfloat16)
+        param = dequant_original_fp8_weight_if_needed(mod, param)
         quantized_param = quantizer(param.to(cur_accelerator.name()))
         delattr(mod, param_name)
         setattr(mod, param_name, nn.Parameter(quantized_param))
+        # Note: when re-quantizing the original FP8 weights, we need to set `updated_fp8_weight` to True
+        mod.updated_fp8_weight = True
         quantized_param = getattr(mod, param_name)
         quantized_param.requires_grad_(False)
         cur_accelerator.synchronize()

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 4 additions & 1 deletion

@@ -19,6 +19,7 @@
 from .scale_method_factory import QuantTensorName, ScaleMethodFactory
 from .scales_method import QuantTensorType
 from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput
+from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed


 class BaseOpQuantizer:
@@ -101,9 +102,11 @@ def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=1)
         output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
         rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
+        if rescaled_weight is not None:
+            rescaled_weight = dequant_original_fp8_weight_if_needed(self.mod, rescaled_weight)
         if self.weight_ich_scale_calc is not None:
             weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
-            rescaled_weight = torch.div(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
+            rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))
         weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
         params_config = (
             {"weight": weights_scales_out_ch}

neural_compressor/torch/algorithms/fp8_quant/_core/utils.py

Lines changed: 8 additions & 4 deletions

@@ -23,6 +23,7 @@
 from .common import is_runtime_scale_patching

 import os
+import re
 import habana_frameworks.torch.utils.experimental as htexp


@@ -42,8 +43,11 @@ def print_init_info(config):
     logger.info("neural_compressor_pt Configuration = %s", config)


-def is_substr(substr_list, target):
-    return any([x in target for x in substr_list])
+def is_re_match(substr_list, target):
+    for substr in substr_list:
+        if re.search(substr, target):
+            return True
+    return False


 def should_quantize(config, mod_type, name):
@@ -57,12 +61,12 @@ def mod_is_not_blocked(mod_type, config):
         return (mod_type in allowlist_tuple)
     def allowlist_is_empty_or_allows_mod(mod_type, name, config):
         def mod_is_in_allowlist_config(mod_type, name, config):
-            return ((mod_type in config.cfg["allowlist"]["types"]) or (is_substr(config.cfg["allowlist"]["names"], name)))
+            return ((mod_type in config.cfg["allowlist"]["types"]) or (is_re_match(config.cfg["allowlist"]["names"], name)))
         def is_allowlist_completely_empty(config):
             return ((len(config.cfg["allowlist"]["names"]) == 0) and len(config.cfg["allowlist"]["types"]) == 0)
         return (mod_is_in_allowlist_config(mod_type, name, config) or is_allowlist_completely_empty(config))
     def name_is_not_blocked(name, config):
-        return (not is_substr(config.cfg["blocklist"]["names"], name))
+        return (not is_re_match(config.cfg["blocklist"]["names"], name))
     def is_static_scale_method(config):
         return config.cfg["scale_method"] not in _dynamic_scale_methods
     def quantize_dynamic_op(config, mod_type):
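Switching from `is_substr` to `is_re_match` changes the semantics of `allowlist`/`blocklist` `names` entries: they are now regular expressions evaluated with `re.search`, so anchors and word boundaries allow precise matching. A small illustration (the patterns are examples, not shipped defaults, and importing this module assumes a Gaudi environment where `habana_frameworks` is available):

```python
from neural_compressor.torch.algorithms.fp8_quant._core.utils import is_re_match

patterns = ["lm_head", r"mlp\.gate\b"]

print(is_re_match(patterns, "model.lm_head"))             # True
print(is_re_match(patterns, "layer.1.mlp.gate"))          # True
print(is_re_match(patterns, "layer.1.mlp.gate_up_proj"))  # False: \b rejects the partial match
# The previous substring check would also have matched the plain entry "mlp.gate"
# inside "mlp.gate_up_proj".
```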

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

File mode changed: 100644 → 100755

Lines changed: 76 additions & 15 deletions

@@ -22,7 +22,8 @@
 from .quant_config import QuantMode, get_hqt_config
 from ..patched_module_base import PatchedModuleBase, get_call_wrapper
 from .._core.scale_handler import get_scale_dtype, ScaleFormat
-
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+cur_accelerator = auto_detect_accelerator()

 class BMM(nn.Module):
     def __init__(self):
@@ -75,9 +76,13 @@ def get_current_repr(cls_instance, *member_names):
         if not first_name:
             curr_repr += ", "
         cur_attr = getattr(cls_instance, name)
-        # currently, only scale is called here.
-        dtype = get_scale_dtype(cur_attr)
-        curr_repr += f"{name} dtype={dtype}"
+        if isinstance(cur_attr, list) and len(cur_attr) > 1:
+            dtype = get_scale_dtype(cur_attr[0])
+            curr_repr += f"{name} type=list of {dtype}, length={len(cur_attr)}"
+        else:
+            # currently, only scale is called here.
+            dtype = get_scale_dtype(cur_attr)
+            curr_repr += f"{name} dtype={dtype}"
         first_name = False
     return curr_repr

@@ -401,6 +406,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"]
         if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE):
             self.forward = self.forward_measure_reduce if self.reduce_results and self.tp_size > 1 else self.forward_measure_no_reduce
+
         elif self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.fake_quant or self.use_qdq:
                 self.forward = self.forward_qdq
@@ -470,7 +476,7 @@ def forward_quant_reduce_in_hp(self, input):
     def measure_input_and_matmul(self, input):
         resolved_input = self.resolve_input(input)
         measure_input((resolved_input,), observer=self._mod_extra_config.inputs)
-        return torch.matmul(resolved_input, self.weight.transpose(-1, -2))
+        return self.orig_mod.quant_method.apply(self.orig_mod, resolved_input)

     def forward_measure_no_reduce(self, input):
         output = self.measure_input_and_matmul(input)
@@ -570,7 +576,7 @@ def forward_quant(self, input):

     def forward_measure(self, input):
         measure_input((input,), observer=self._mod_extra_config.inputs)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
+        output = self.orig_mod.quant_method.apply(self.orig_mod, input)
         measure_output((output,), self._mod_extra_config.outputs)
         output, output_bias = self.add_bias(output)
         if self.gather_output:
@@ -698,6 +704,8 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         init_linear(self, mod_extra_config)
         if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
             measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs)
+        else:
+            self.weight = self.weight.squeeze()

     def forward_qdq(self, input, *args, **kwargs):
         qinput = self.quant_input(input)
@@ -823,6 +831,9 @@ def extra_repr(self) -> str:
 class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        # Get the `experts_min` and `experts_max` from the original module if they exist
+        self.experts_min = self.orig_mod.experts_min if hasattr(self.orig_mod, "experts_min") else 0
+        self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.forward = self.forward_quant
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
@@ -841,8 +852,8 @@ def forward_quant(self,
                       permuted_weights=True,
                       activation="silu"):
         experts_range = range(self.num_experts)
-        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
-        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
         scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
         scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
         qinput = self.quant_input(hidden_states)
@@ -858,8 +869,8 @@
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output

@@ -881,8 +892,8 @@ def forward_measure(self,
             w3=w2_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
         output_measure_list = [output]
@@ -892,15 +903,65 @@
         return output

     def extra_repr(self) -> str:
-        member_names = ["scale_input"]
-        for x in range(1, self.num_experts+1):
-            member_names.append("scale_intermediate["+str(x)+"]")
+        member_names = ["scale_input", "scale_intermediate"]
+        # for x in range(1, self.num_experts+1):
+        #     member_names.append("scale_intermediate["+str(x)+"]")
         return extra_representation(
             self.extra_repr_org(),
             self.class_name_org,
             get_current_repr(self, *member_names),
         )

+class PatchedVllmMixtureOfExpertsOpFP8(PatchedVllmMixtureOfExpertsOp):
+    """The patched module for the VllmMixtureOfExpertsOp module with FP8 weights.
+
+    Some models, such as DeepSeek R1/V3, ship FP8 weights that need to be re-quantized.
+
+    The main difference from PatchedVllmMixtureOfExpertsOp is that the weights are FP8:
+    - At the measurement stage, the weights are dequantized to BF16.
+    - At the quantization stage, the same `forward_quant` method as PatchedVllmMixtureOfExpertsOp is used.
+    """
+
+    def forward_measure(
+        self,
+        x,
+        topk_ids,
+        topk_weights,
+    ):
+        hidden_states = x
+        measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+        min_expert = self.experts_min
+        max_expert = self.experts_max
+        w13_list_slice = []
+        w2_list_slice = []
+        for j in range(self.num_experts):
+            w13_list_slice.append(self.w13_list[j].get_dequant_weight())
+            w2_list_slice.append(self.w2_list[j].get_dequant_weight())
+
+        output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement_fused_weights(
+            hidden_states=x,
+            expert_routing_table=topk_ids.to(torch.int64),
+            router_weights=topk_weights.to(x.dtype),
+            w12=w13_list_slice,
+            w3=w2_list_slice,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=min_expert,
+            experts_max=max_expert,
+            measurement_mode=True,
+        )
+        output_measure_list = [output]
+        for i in range(self.num_experts):
+            output_measure_list.append(intermidiate_amax[i])
+        measure_output(output_measure_list, self._mod_extra_config.outputs)
+        return output
+
+class PatchedMoeFP8Matmul(PatchedMoeMatmul):
+    """The patched module for the MoeMatmul module with FP8 weights."""
+
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        self.get_dequant_weight = self.orig_mod.get_dequant_weight

 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model
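To make the measurement-stage behavior of `PatchedVllmMixtureOfExpertsOpFP8` concrete: each expert's FP8 weight is brought back to BF16 via `get_dequant_weight()` before being passed to the fused HPU measurement op, and `experts_min`/`experts_max` fall back to 0/7 when the original module does not carry those attributes. The sketch below is a toy, hardware-independent illustration; the per-tensor scale layout of the expert is an assumption.

```python
import torch

class ToyFP8Expert:
    """Stand-in for a MoeFP8Matmul-like expert holding an FP8 weight and its scale."""
    def __init__(self, out_features=8, in_features=16):
        self.scale_weight = torch.tensor(0.05)
        self.weight = (torch.randn(out_features, in_features) / self.scale_weight).to(torch.float8_e4m3fn)

    def get_dequant_weight(self) -> torch.Tensor:
        return (self.weight.to(torch.float32) * self.scale_weight).to(torch.bfloat16)

num_experts = 4
w13_list = [ToyFP8Expert() for _ in range(num_experts)]
w2_list = [ToyFP8Expert(16, 8) for _ in range(num_experts)]

# What forward_measure assembles before calling the fused measurement kernel:
w13_list_slice = [e.get_dequant_weight() for e in w13_list]
w2_list_slice = [e.get_dequant_weight() for e in w2_list]
assert all(w.dtype == torch.bfloat16 for w in w13_list_slice + w2_list_slice)
```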
Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+import unittest
+import re
+from typing import List
+from neural_compressor.torch.algorithms.fp8_quant._core.utils import is_re_match
+
+
+class TestUtils(unittest.TestCase):
+    def test_is_re_match_found(self):
+        substr_list = ["lm_head", "mlp\\.gate\\b"]
+        target = "layer.1.mlp.gate"
+        self.assertTrue(is_re_match(substr_list, target))
+        target2 = "model.lm_head"
+        self.assertTrue(is_re_match(substr_list, target2))
+
+    def test_is_re_match_not_found(self):
+        substr_list = ["lm_head", "mlp\\.gate\\b"]
+        target = "layer.1.mlp.gate_up_proj"
+        self.assertFalse(is_re_match(substr_list, target))
+
