
Commit 42b1485

yiliu30 and Yi4Liu authored
[SW-221594] Re-quantize the Official DeepSeek FP8 Model (#187)
Building on the vLLM WoQ path, this PR adds support for re-quantizing FP8 weights with per-tensor or per-channel scaling.

Co-authored-by: Yi Liu <[email protected]>
1 parent f12e7f0 commit 42b1485
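For orientation, here is a minimal standalone sketch of what "re-quantizing FP8 weights with per-tensor or per-channel scaling" means in this PR. This is not the library's API; `requantize_fp8_weight`, `FP8_MAX`, and the argument names are illustrative only, and the snippet assumes a PyTorch build with float8 dtypes.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3

def requantize_fp8_weight(w_fp8: torch.Tensor, orig_scale: torch.Tensor, per_channel: bool = False):
    """Illustrative only: dequantize an FP8 checkpoint weight, then re-quantize it with a new scale."""
    # 1) Recover the high-precision weight using the scale shipped in the checkpoint.
    w_hp = w_fp8.to(torch.bfloat16) * orig_scale
    # 2) Derive a fresh scale from the observed amax.
    if per_channel:
        amax = w_hp.abs().amax(dim=1, keepdim=True)  # one scale per output channel
    else:
        amax = w_hp.abs().amax()                     # a single per-tensor scale
    new_scale = amax.float() / FP8_MAX
    # 3) Re-quantize to FP8 with the new scale.
    w_requant = (w_hp / new_scale).to(torch.float8_e4m3fn)
    return w_requant, new_scale
```

The diffs below wire this idea into the existing measurement and quantization flow by dequantizing original FP8 weights on the fly before observers and quantizers see them.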

File tree

8 files changed (+127, -23 lines)

neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 7 additions & 1 deletion
@@ -20,13 +20,19 @@
 import torch
 from enum import Enum, auto
 from functools import lru_cache
-
 from .._quant_common.quant_config import get_hqt_config
 from ..utils.logger import logger
 from neural_compressor.torch.algorithms.fp8_quant.model_configs import ModuleConfig

 UNMEASURED_MODELS = "UnmeasuredModels"

+def dequant_original_fp8_weight_if_needed(mod: torch.nn.Module, param: torch.Tensor) -> torch.Tensor:
+    if param.dtype in [torch.float8_e4m3fn]:
+        if hasattr(mod, "get_dequant_weights_func"):
+            dequant_weights_func = mod.get_dequant_weights_func()
+            if dequant_weights_func is not None:
+                param = dequant_weights_func(mod)
+    return param

 class QuantTensorType(Enum):
     MEASUREMENTS = auto()
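A sketch of the duck-typed contract this helper relies on; `OfficialFP8Linear` below is a hypothetical stand-in, not part of the patch. The only requirement on the original module is a `get_dequant_weights_func()` method that returns a callable mapping the module to its high-precision weight.

```python
import torch
from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed

class OfficialFP8Linear(torch.nn.Module):
    """Hypothetical stand-in for a layer that stores an FP8 weight plus its checkpoint scale."""
    def __init__(self, out_features=4, in_features=8):
        super().__init__()
        hp_weight = torch.randn(out_features, in_features)
        self.weight_scale = (hp_weight.abs().amax() / torch.finfo(torch.float8_e4m3fn).max).to(torch.bfloat16)
        self.weight = torch.nn.Parameter((hp_weight / self.weight_scale).to(torch.float8_e4m3fn),
                                         requires_grad=False)

    def get_dequant_weights_func(self):
        # The returned callable receives the module itself, matching `dequant_weights_func(mod)` above.
        return lambda mod: mod.weight.to(torch.bfloat16) * mod.weight_scale

mod = OfficialFP8Linear()
hp = dequant_original_fp8_weight_if_needed(mod, mod.weight)  # BF16 weight; non-FP8 params pass through unchanged
```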

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,7 @@
     OBSERVER_PARAMS,
     IMOD_DICT,
 )
+from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed
 cur_accelerator = auto_detect_accelerator()


@@ -162,6 +163,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
         if pmod._mod_extra_config:
             for param_name in pmod._mod_extra_config.params:
                 param = getattr(pmod, param_name)
+                param = dequant_original_fp8_weight_if_needed(pmod.orig_mod, param)
                 if config["measure_on_hpu"]:
                     param = param.to(cur_accelerator.name())
                 pmod._mod_extra_config.params[param_name].measure(param)
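The new `dequant_original_fp8_weight_if_needed` call matters here because observers need to see the weight in its real range: the raw FP8 codes only span the FP8 grid and ignore the checkpoint scale. A small illustration (assumes a PyTorch build with float8 dtypes):

```python
import torch

w_hp = torch.randn(4, 8) * 3.0
scale = w_hp.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w_hp / scale).to(torch.float8_e4m3fn)

print(w_fp8.to(torch.float32).abs().amax())             # ~448: the FP8 code range, not the weight's amax
print((w_fp8.to(torch.float32) * scale).abs().amax())   # ~w_hp.abs().amax(): what the observer should measure
```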

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 7 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.

 import importlib.util
-
+import os
 from ..model_configs import ModuleInfo, ModuleType
 from .._quant_common.helper_modules import *
 from ..utils.logger import logger
@@ -52,14 +52,17 @@ def create_mod_info_recursion(parent):

     create_mod_info_recursion(model)

+
+INC_DYNAMIC_MOE_EXPERTS = int(os.environ.get("INC_DYNAMIC_MOE_EXPERTS", "8"))
+
 _mod_types = {
     "linear": ModuleType(1, ["weight"], 1, False),
     "row_parallel_linear": ModuleType(1, ["weight"], 2, True),
     "matmul": ModuleType(2, [], 1, False),
     "kv_cache": ModuleType(1, [], 1, False),
     "softmax": ModuleType(1, [], 1, True),
     "fused_sdpa": ModuleType(3, [], 2, True),
-    "dynamic_moe": ModuleType(1, [], 9, True),
+    "dynamic_moe": ModuleType(1, [], 1 + INC_DYNAMIC_MOE_EXPERTS, True),
 }


@@ -80,10 +83,12 @@ def create_mod_info_recursion(parent):
     "Softmax": ModuleInfo("softmax", PatchedSoftmax),
     "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA),
     "MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
+    "MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
     "ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
     "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
+    "VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),
 }
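Usage note: the measured output count for a `dynamic_moe` op is `1 + INC_DYNAMIC_MOE_EXPERTS` (one routed output plus one intermediate amax per expert), so the previous hard-coded 9 corresponds to the default of 8 experts. Models whose fused-MoE op exposes a different expert count per rank can override it, as long as the variable is set before this module is imported:

```python
import os

# Hypothetical example: a deployment whose fused-MoE op carries 16 experts per rank.
# Must be set before neural_compressor's patching code is imported, since the value
# is read once at import time.
os.environ["INC_DYNAMIC_MOE_EXPERTS"] = "16"
```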

neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,7 @@
 from .measure import load_measurements
 from .scale import scale_method_mapping, load_layer_scales, prepare_layer_scales
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed


 cur_accelerator = auto_detect_accelerator()
@@ -74,9 +75,12 @@ def quantize_params(mod, mod_extra_config):
         param = getattr(mod, param_name)
         if param.dtype == torch.float16:
             param = param.to(torch.bfloat16)
+        param = dequant_original_fp8_weight_if_needed(mod, param)
         quantized_param = quantizer(param.to(cur_accelerator.name()))
         delattr(mod, param_name)
         setattr(mod, param_name, nn.Parameter(quantized_param))
+        # Note: when re-quantizing FP8 weights, we need to set `updated_fp8_weight` to True
+        mod.updated_fp8_weight = True
         quantized_param = getattr(mod, param_name)
         quantized_param.requires_grad_(False)
         cur_accelerator.synchronize()

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 4 additions & 1 deletion
@@ -18,6 +18,7 @@
 from .scale_method_factory import ScaleMethodFactory, QuantTensorName
 from ..common import ModuleConfig, QuantTensorType
 from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput
+from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed


 class BaseOpQuantizer:
@@ -100,9 +101,11 @@ def get_scales_module_config(self):
         input_scales = self.calc_input_scales(num_of_inputs=1)
         output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
         rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
+        if rescaled_weight is not None:
+            rescaled_weight = dequant_original_fp8_weight_if_needed(self.mod, rescaled_weight)
         if self.weight_ich_scale_calc is not None:
             weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
-            rescaled_weight = torch.div(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
+            rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1]))
         weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST)
         params_config = (
             {"weight": weights_scales_out_ch}

neural_compressor/torch/algorithms/fp8_quant/_core/utils.py

Lines changed: 8 additions & 4 deletions
@@ -23,6 +23,7 @@
 from .common import is_runtime_scale_patching

 import os
+import re
 import habana_frameworks.torch.utils.experimental as htexp


@@ -42,8 +43,11 @@ def print_init_info(config):
     logger.info("neural_compressor_pt Configuration = %s", config)


-def is_substr(substr_list, target):
-    return any([x in target for x in substr_list])
+def is_re_match(substr_list, target):
+    for substr in substr_list:
+        if re.search(substr, target):
+            return True
+    return False


 def should_quantize(config, mod_type, name):
@@ -57,12 +61,12 @@ def mod_is_not_blocked(mod_type, config):
         return (mod_type in allowlist_tuple)
     def allowlist_is_empty_or_allows_mod(mod_type, name, config):
         def mod_is_in_allowlist_config(mod_type, name, config):
-            return ((mod_type in config.cfg["allowlist"]["types"]) or (is_substr(config.cfg["allowlist"]["names"], name)))
+            return ((mod_type in config.cfg["allowlist"]["types"]) or (is_re_match(config.cfg["allowlist"]["names"], name)))
         def is_allowlist_completely_empty(config):
             return ((len(config.cfg["allowlist"]["names"]) == 0) and len(config.cfg["allowlist"]["types"]) == 0)
         return (mod_is_in_allowlist_config(mod_type, name, config) or is_allowlist_completely_empty(config))
     def name_is_not_blocked(name, config):
-        return (not is_substr(config.cfg["blocklist"]["names"], name))
+        return (not is_re_match(config.cfg["blocklist"]["names"], name))
     def is_static_scale_method(config):
         return config.cfg["scale_method"] not in _dynamic_scale_methods
     def quantize_dynamic_op(config, mod_type):
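Behavior change worth noting: allowlist/blocklist `names` entries are now treated as regular expressions via `re.search` rather than plain substrings, so a pattern like `mlp\.gate\b` can match `mlp.gate` without also matching `mlp.gate_up_proj`. The example below mirrors the unit test added at the end of this commit:

```python
import re

patterns = ["lm_head", r"mlp\.gate\b"]

print(any(re.search(p, "layer.1.mlp.gate") for p in patterns))          # True: \b matches after "gate"
print(any(re.search(p, "layer.1.mlp.gate_up_proj") for p in patterns))  # False: "_" is a word character
# The old substring check ("mlp.gate" in name) would have matched both names.
```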

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

File mode changed: 100644 → 100755
Lines changed: 76 additions & 15 deletions
@@ -22,7 +22,8 @@
 from .quant_config import QuantMode, get_hqt_config
 from ..patched_module_base import PatchedModuleBase, get_call_wrapper
 from .._core.scale_handler import get_scale_dtype, ScaleFormat
-
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+cur_accelerator = auto_detect_accelerator()

 class BMM(nn.Module):
     def __init__(self):
@@ -75,9 +76,13 @@ def get_current_repr(cls_instance, *member_names):
         if not first_name:
             curr_repr += ", "
         cur_attr = getattr(cls_instance, name)
-        # currently, only scale is called here.
-        dtype = get_scale_dtype(cur_attr)
-        curr_repr += f"{name} dtype={dtype}"
+        if isinstance(cur_attr, list) and len(cur_attr) > 1:
+            dtype = get_scale_dtype(cur_attr[0])
+            curr_repr += f"{name} type=list of {dtype}, length={len(cur_attr)}"
+        else:
+            # currently, only scale is called here.
+            dtype = get_scale_dtype(cur_attr)
+            curr_repr += f"{name} dtype={dtype}"
         first_name = False
     return curr_repr

@@ -398,6 +403,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"]
         if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE):
             self.forward = self.forward_measure_reduce if self.reduce_results and self.tp_size > 1 else self.forward_measure_no_reduce
+
         elif self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.fake_quant or self.use_qdq:
                 self.forward = self.forward_qdq
@@ -467,7 +473,7 @@ def forward_quant_reduce_in_hp(self, input):
     def measure_input_and_matmul(self, input):
         resolved_input = self.resolve_input(input)
         measure_input((resolved_input,), observer=self._mod_extra_config.inputs)
-        return torch.matmul(resolved_input, self.weight.transpose(-1, -2))
+        return self.orig_mod.quant_method.apply(self.orig_mod, resolved_input)

     def forward_measure_no_reduce(self, input):
         output = self.measure_input_and_matmul(input)
@@ -567,7 +573,7 @@ def forward_quant(self, input):

     def forward_measure(self, input):
         measure_input((input,), observer=self._mod_extra_config.inputs)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
+        output = self.orig_mod.quant_method.apply(self.orig_mod, input)
         measure_output((output,), self._mod_extra_config.outputs)
         output, output_bias = self.add_bias(output)
         if self.gather_output:
@@ -695,6 +701,8 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         init_linear(self, mod_extra_config)
         if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
             measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs)
+        else:
+            self.weight = self.weight.squeeze()

     def forward_qdq(self, input, *args, **kwargs):
         qinput = self.quant_input(input)
@@ -820,6 +828,9 @@ def extra_repr(self) -> str:
 class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        # Get `experts_min` and `experts_max` from the original module if they exist
+        self.experts_min = self.orig_mod.experts_min if hasattr(self.orig_mod, "experts_min") else 0
+        self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
             self.quant_input = self._mod_extra_config.inputs[0]
@@ -837,8 +848,8 @@ def forward_quant(self,
                       permuted_weights=True,
                       activation="silu"):
         experts_range = range(self.num_experts)
-        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
-        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
         scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
         scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
         qinput = self.quant_input(hidden_states)
@@ -854,8 +865,8 @@ def forward_quant(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output
@@ -877,8 +888,8 @@ def forward_measure(self,
             w3=w2_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
         output_measure_list = [output]
@@ -888,15 +899,65 @@ def forward_measure(self,
         return output

     def extra_repr(self) -> str:
-        member_names = ["scale_input"]
-        for x in range(1, self.num_experts+1):
-            member_names.append("scale_intermediate["+str(x)+"]")
+        member_names = ["scale_input", "scale_intermediate"]
+        # for x in range(1, self.num_experts+1):
+        #     member_names.append("scale_intermediate["+str(x)+"]")
         return extra_representation(
             self.extra_repr_org(),
             self.class_name_org,
             get_current_repr(self, *member_names),
         )

+class PatchedVllmMixtureOfExpertsOpFP8(PatchedVllmMixtureOfExpertsOp):
+    """The patched module for the VllmMixtureOfExpertsOp module with FP8 weights.
+
+    Some models, such as DeepSeek R1/V3, ship with FP8 weights that need to be re-quantized.
+
+    The main difference from PatchedVllmMixtureOfExpertsOp is that the weights are FP8:
+    - At the measurement stage, we dequantize the weights to BF16.
+    - At the quantization stage, we use the same `forward_quant` method as PatchedVllmMixtureOfExpertsOp.
+    """
+
+    def forward_measure(
+        self,
+        x,
+        topk_ids,
+        topk_weights,
+    ):
+        hidden_states = x
+        measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+        min_expert = self.experts_min
+        max_expert = self.experts_max
+        w13_list_slice = []
+        w2_list_slice = []
+        for j in range(self.num_experts):
+            w13_list_slice.append(self.w13_list[j].get_dequant_weight())
+            w2_list_slice.append(self.w2_list[j].get_dequant_weight())
+
+        output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement_fused_weights(
+            hidden_states=x,
+            expert_routing_table=topk_ids.to(torch.int64),
+            router_weights=topk_weights.to(x.dtype),
+            w12=w13_list_slice,
+            w3=w2_list_slice,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=min_expert,
+            experts_max=max_expert,
+            measurement_mode=True,
+        )
+        output_measure_list = [output]
+        for i in range(self.num_experts):
+            output_measure_list.append(intermidiate_amax[i])
+        measure_output(output_measure_list, self._mod_extra_config.outputs)
+        return output
+
+
+class PatchedMoeFP8Matmul(PatchedMoeMatmul):
+    """The patched module for the MoeMatmul module with FP8 weights."""
+
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        self.get_dequant_weight = self.orig_mod.get_dequant_weight

 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model
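Taken together, the FP8-aware patched classes assume a small duck-typed surface on the original vLLM-side modules. The stubs below are hypothetical (attribute and method names are inferred from the calls in the hunks above, not taken from vLLM's source) and only show what the patched code reads: per-expert `weight`, `scale_weight`, and `get_dequant_weight()`, plus optional `experts_min`/`experts_max` with fallbacks of 0 and 7.

```python
import torch

class MoeMatmulFP8Stub(torch.nn.Module):
    """Hypothetical per-expert matmul exposing what PatchedMoeFP8Matmul and the FP8 MoE op use."""
    def __init__(self, weight_fp8: torch.Tensor, scale: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight_fp8, requires_grad=False)  # FP8 storage
        self.scale_weight = scale                                          # checkpoint weight scale

    def get_dequant_weight(self) -> torch.Tensor:
        # forward_measure collects these high-precision tensors for the measurement kernel.
        return (self.weight.to(torch.bfloat16) * self.scale_weight).to(torch.bfloat16)

class VllmMixtureOfExpertsOpFP8Stub(torch.nn.Module):
    """Hypothetical fused-MoE op; only attributes read by the patched classes are shown."""
    def __init__(self, w13_experts, w2_experts, experts_min=0, experts_max=7):
        super().__init__()
        self.w13_list = torch.nn.ModuleList(w13_experts)
        self.w2_list = torch.nn.ModuleList(w2_experts)
        self.num_experts = len(w13_experts)
        self.experts_min = experts_min  # optional on the real module; patched code falls back to 0
        self.experts_max = experts_max  # optional on the real module; patched code falls back to 7
```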
(new unit-test file; path not shown in this view)

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import unittest
+import re
+from typing import List
+from neural_compressor.torch.algorithms.fp8_quant._core.utils import is_re_match
+
+
+class TestUtils(unittest.TestCase):
+    def test_is_re_match_found(self):
+        substr_list = ["lm_head", "mlp\\.gate\\b"]
+        target = "layer.1.mlp.gate"
+        self.assertTrue(is_re_match(substr_list, target))
+        target2 = "model.lm_head"
+        self.assertTrue(is_re_match(substr_list, target2))
+
+    def test_is_re_match_not_found(self):
+        substr_list = ["lm_head", "mlp\\.gate\\b"]
+        target = "layer.1.mlp.gate_up_proj"
+        self.assertFalse(is_re_match(substr_list, target))
+
