
Commit d726f92

yiliu30 authored and XuehaoSun committed
[SW-228576] Add Dynamic Quant Support For FusedMoE (#243)
Signed-off-by: Yi Liu <[email protected]>
1 parent 833c107 commit d726f92

8 files changed (+75, -22 lines)

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 2 additions & 1 deletion
@@ -83,7 +83,8 @@ def create_mod_info_recursion(parent):
     "MoeMatmul": ModuleInfo("linear", PatchedMoeMatmul),
     "MoeFP8Matmul": ModuleInfo("linear", PatchedMoeFP8Matmul),
     "ReplicatedLinear": ModuleInfo("linear", PatchedReplicatedLinear),
-    "FusedMoE": ModuleInfo("matmul", PatchedMixtralMoE, False),
+    # Note: `no_quantize_op` indicates that this module is patched but does not require measurement or quantization.
+    "FusedMoE": ModuleInfo("no_quantize_op", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
     "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
     "VllmMixtureOfExpertsOpFP8": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpFP8),

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py

Lines changed: 30 additions & 13 deletions
@@ -33,8 +33,8 @@ class QuantizedHpuFuncWrapperBase(QuantizedFuncWrapperBase, metaclass=ABCMeta):
     Concrete class may override base class methods in case custom op logic is unique, see examples in concrete
     classes below.
     """
-    def __init__(self, scale_format):
-        self._quantized_func_ = self.get_quantized_func(scale_format)
+    def __init__(self, scale_format, is_dynamic=False):
+        self._quantized_func_ = self.get_quantized_func(scale_format, is_dynamic)

     @abstractmethod
     def get_default_quantized_func(self):
@@ -45,19 +45,32 @@ def get_scalar_quantized_func(self):
             return self.get_default_quantized_func()
         return self.get_default_quantized_func().scalar

-    def get_quantized_func(self, scale_format):
-        if scale_format == ScaleFormat.SCALAR:
-            return self.get_scalar_quantized_func()
-        elif scale_format == ScaleFormat.CONST:
-            return self.get_default_quantized_func()
+    def get_dynamic_scalar_quantized_func(self):
+        # By default, dynamic scalar quantized function is the same as scalar quantized function.
+        return self.get_scalar_quantized_func()
+
+    def get_dynamic_quantized_func(self):
+        # By default, dynamic quantized function is the same as default quantized function.
+        return self.get_default_quantized_func()
+
+    def get_quantized_func(self, scale_format, is_dynamic=False):
+        if scale_format not in [ScaleFormat.SCALAR, ScaleFormat.CONST]:
+            raise ValueError("Unsupported scale format - {}".format(scale_format))
+        if is_dynamic:
+            if scale_format == ScaleFormat.SCALAR:
+                return self.get_dynamic_scalar_quantized_func()
+            else:
+                return self.get_dynamic_quantized_func()
         else:
-            raise ValueError("Unexpected scale format - {}".format(scale_format))
+            if scale_format == ScaleFormat.SCALAR:
+                return self.get_scalar_quantized_func()
+            else:
+                return self.get_default_quantized_func()

     def __call__(self, *args, **kwargs):
         return self._quantized_func_(*args, **kwargs)


-
 class QuantizedHpuMatmul(QuantizedHpuFuncWrapperBase):
     def get_default_quantized_func(self):
         return torch.ops.hpu.fp8_gemm_v2
@@ -134,7 +147,6 @@ def get_default_quantized_func(self):
     def get_scalar_quantized_func(self):
         return torch.ops.hpu.mixture_of_experts.fp8_scalars

-
 class QuantizedHPUCastToFP8(QuantizedHpuFuncWrapperBase):
     def get_default_quantized_func(self):
         return torch.ops.hpu.cast_to_fp8_v2
@@ -144,18 +156,23 @@ def __call__(self, *args, **kwargs):

 class QuantizedHPUCastFromFP8(QuantizedHpuFuncWrapperBase):

-    def __init__(self, scale_format):
-        super().__init__(scale_format)
-
     def get_default_quantized_func(self):
         return torch.ops.hpu.cast_from_fp8

+
 class QuantizedHpuDynamicMoeFusedWeights(QuantizedHpuFuncWrapperBase):
     def get_default_quantized_func(self):
         return torch.ops.hpu.mixture_of_experts.fp8_fused_weights
+
     def get_scalar_quantized_func(self):
         return torch.ops.hpu.mixture_of_experts.fp8_fused_weights_scalars

+    def get_dynamic_scalar_quantized_func(self):
+        return torch.ops.hpu.mixture_of_experts.fp8_fused_weights_scalars_dynamic
+
+    def get_dynamic_quantized_func(self):
+        return torch.ops.hpu.mixture_of_experts.fp8_fused_weights_dynamic
+

 _OP_TYPE_HPU_QUANTIZED_WRAPPER_CLASSES = {OP_TYPE.LINEAR_GEMM : QuantizedHpuMatmul,
                                           OP_TYPE.MATMUL_GEMM: QuantizedHpuMatmul,
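
The base-class change above adds an is_dynamic axis on top of the existing scale-format dispatch, so (scale_format, is_dynamic) now selects one of up to four backend kernels, with the dynamic getters falling back to the static ones unless a subclass such as QuantizedHpuDynamicMoeFusedWeights overrides them. The self-contained sketch below reproduces that selection logic with plain Python callables standing in for the torch.ops.hpu.* kernels, so it runs without Gaudi hardware; the class and lambda names are illustrative only.

# Sketch of the four-way dispatch; callables stand in for the HPU kernels.
from enum import Enum

class ScaleFormat(Enum):
    CONST = "const"
    SCALAR = "scalar"

class FuncWrapperSketch:
    def __init__(self, scale_format, is_dynamic=False):
        self._fn = self.get_quantized_func(scale_format, is_dynamic)

    def get_default_quantized_func(self):
        return lambda x: f"static_const({x})"

    def get_scalar_quantized_func(self):
        return lambda x: f"static_scalar({x})"

    def get_dynamic_quantized_func(self):
        # Falls back to the static kernel unless a subclass overrides it,
        # mirroring the base-class defaults introduced in this commit.
        return self.get_default_quantized_func()

    def get_dynamic_scalar_quantized_func(self):
        return self.get_scalar_quantized_func()

    def get_quantized_func(self, scale_format, is_dynamic=False):
        if scale_format not in (ScaleFormat.SCALAR, ScaleFormat.CONST):
            raise ValueError(f"Unsupported scale format - {scale_format}")
        if is_dynamic:
            return (self.get_dynamic_scalar_quantized_func()
                    if scale_format == ScaleFormat.SCALAR else self.get_dynamic_quantized_func())
        return (self.get_scalar_quantized_func()
                if scale_format == ScaleFormat.SCALAR else self.get_default_quantized_func())

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

class DynamicMoeSketch(FuncWrapperSketch):
    # A subclass with a dedicated dynamic kernel, analogous to
    # QuantizedHpuDynamicMoeFusedWeights overriding the dynamic getters.
    def get_dynamic_scalar_quantized_func(self):
        return lambda x: f"dynamic_scalar({x})"

print(FuncWrapperSketch(ScaleFormat.SCALAR, is_dynamic=True)("x"))  # static_scalar(x), via the fallback
print(DynamicMoeSketch(ScaleFormat.SCALAR, is_dynamic=True)("x"))   # dynamic_scalar(x)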

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/quantized_func_wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -73,10 +73,10 @@ def initialize(cls, device_quantized_func_wrapper_dict):


     @classmethod
-    def get_quantized_func_wrapper_object(cls, op_type, scale_format):
+    def get_quantized_func_wrapper_object(cls, op_type, scale_format, is_dynamic=False):
         if op_type not in cls.__quantized_func_wrapper_instances:
             quantized_wrapper_class = cls.__device_func_wrappers_mapping[op_type]
-            cls.__quantized_func_wrapper_instances[op_type] = quantized_wrapper_class(scale_format)
+            cls.__quantized_func_wrapper_instances[op_type] = quantized_wrapper_class(scale_format, is_dynamic)

         return cls.__quantized_func_wrapper_instances[op_type]
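
For context, the factory above is a per-process cache: the first request for an op type constructs the wrapper with the given scale_format and (now) is_dynamic, and later requests for the same op type return that cached instance, since the cache is keyed by op_type alone (visible in the diff). A small standalone sketch of that behavior, with generic names in place of the real wrapper classes:

# Standalone sketch of the per-op-type cache; names are generic stand-ins.
class WrapperFactorySketch:
    _instances = {}   # op_type -> wrapper instance (created once)
    _mapping = {}     # op_type -> wrapper class

    @classmethod
    def initialize(cls, mapping):
        cls._mapping = dict(mapping)

    @classmethod
    def get_wrapper(cls, op_type, scale_format, is_dynamic=False):
        if op_type not in cls._instances:
            cls._instances[op_type] = cls._mapping[op_type](scale_format, is_dynamic)
        return cls._instances[op_type]

class DummyWrapper:
    def __init__(self, scale_format, is_dynamic=False):
        self.scale_format, self.is_dynamic = scale_format, is_dynamic

WrapperFactorySketch.initialize({"dynamic_moe_fused_weights": DummyWrapper})
a = WrapperFactorySketch.get_wrapper("dynamic_moe_fused_weights", "scalar", is_dynamic=True)
b = WrapperFactorySketch.get_wrapper("dynamic_moe_fused_weights", "scalar")  # cached instance is reused
print(a is b, a.is_dynamic)  # True True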

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/quantized_func_wrapper_api.py

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@
 Functions to interact with QuantizedFuncWrapperFactory singleton object
 """

-def get_quantized_func_wrapper(op_type, scale_format):
-    return QuantizedFuncWrapperFactory.get_quantized_func_wrapper_object(op_type, scale_format)
+def get_quantized_func_wrapper(op_type, scale_format, is_dynamic=False):
+    return QuantizedFuncWrapperFactory.get_quantized_func_wrapper_object(op_type, scale_format, is_dynamic)


 def init_quantized_func_wrapper_factory():

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/xpu/xpu_quantized_func_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ class QuantizedXPUFuncWrapperBase(QuantizedFuncWrapperBase, metaclass=ABCMeta):
     """
     Placeholder for base class for XPU (Falcon/Jaguar Shores) quantized func wrapper.
     """
-    def __init__(self, scale_format):
+    def __init__(self, scale_format, is_dynamic=False):
         self._quantized_func_ = self.get_default_quantized_func()

 class QuantizedXPUMatmul(QuantizedXPUFuncWrapperBase):

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -382,7 +382,7 @@ def __init__(self, config, mod, measurement, params, module_type):
         num_of_experts = 8

         self.inputs_scales_creators = [
-            self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)
+            self.scales_method_factory.get_scale_method(QuantTensorName.INPUT, is_dynamic=self.is_dynamic)
             for i in range(num_of_inputs + num_of_experts)
         ]
         self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))
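
The one-line change above threads the quantizer's is_dynamic flag into the factory that builds the per-input scale creators. The sketch below is only an illustration of the underlying distinction, not the library's scale methods: a static creator returns a scale fixed from offline measurement, while a dynamic creator derives the scale from the tensor it receives at call time (the 448.0 FP8-E4M3 max is likewise just an assumed constant).

# Illustrative contrast between static and dynamic scale creation; not the
# library's scale-method implementations.
import torch

FP8_E4M3_MAX = 448.0  # assumed FP8 range, used only for this sketch

def make_static_scale(measured_absmax):
    scale = torch.tensor(measured_absmax / FP8_E4M3_MAX)
    return lambda tensor: scale                                 # same scale for every call

def make_dynamic_scale():
    return lambda tensor: tensor.abs().amax() / FP8_E4M3_MAX    # recomputed per tensor

def get_scale_method_sketch(is_dynamic, measured_absmax=1.0):
    return make_dynamic_scale() if is_dynamic else make_static_scale(measured_absmax)

x = torch.randn(4, 8)
print(get_scale_method_sketch(is_dynamic=False, measured_absmax=6.0)(x))  # constant
print(get_scale_method_sketch(is_dynamic=True)(x))                        # depends on x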

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 36 additions & 1 deletion
@@ -734,14 +734,20 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
         self.experts_used = self.local_num_experts if hasattr(self.orig_mod, "local_num_experts") else self.num_experts
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
-            self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
+
             self.quant_input = self._mod_extra_config.inputs[0]
             self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
             self.register_scale(
                 "scale_intermediate",
                 [mod_extra_config.scale.inputs[x] for x in range(1, self.experts_used+1)],
                 self.scale_format,
             )
+            self.is_dynamic_quantization = isinstance(self.quant_input, QuantDynamicInput)
+            self.dynamic_moe_op = get_quantized_func_wrapper(
+                OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, scale_format=self.scale_format, is_dynamic=self.is_dynamic_quantization
+            )
+            if self.is_dynamic_quantization:
+                self.forward = self.forward_dynamic_quant

     def forward_quant(self,
                       hidden_states,
@@ -772,6 +778,35 @@ def forward_quant(self,
         )
         return output

+    def forward_dynamic_quant(
+        self, hidden_states, expert_routing_table, router_weights, permuted_weights=True, layer=None, activation="silu"
+    ):
+        # This is the dynamic version of the forward_quant method.
+        # Compared to the `forward_quant` method, the main differences are:
+        # 1) The `quant_input` is of type `QuantDynamicInput`.
+        # 2) There is no need to pass the `d_scale_intermediate_hidden_states` to the dynamic moe op.
+        experts_range = range(self.num_experts)
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
+        scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
+        scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
+        qinput_fp8, input_scale = self.quant_input(hidden_states)
+        output = self.dynamic_moe_op(
+            hidden_states=qinput_fp8,
+            expert_routing_table=expert_routing_table,
+            router_weights=router_weights,
+            w12=w1_list,
+            w3=w2_list,
+            d_scale_w12=scale_w1,
+            d_scale_w3=scale_w2,
+            d_scale_hidden_states=input_scale,
+            permuted_weights=False,
+            activation=activation,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max
+        )
+        return output
+
     def forward_measure(self,
                         hidden_states,
                         expert_routing_table,
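
Relative to forward_quant, the new forward_dynamic_quant path gets both the FP8 activations and their scale back from a QuantDynamicInput, forwards that per-call scale as d_scale_hidden_states, and no longer passes the pre-measured intermediate scales. The hardware-free sketch below only mimics that calling-convention difference; the stub classes and the 448.0 range are assumptions, and the real module dispatches to the torch.ops.hpu.mixture_of_experts.* kernels shown earlier.

# Hardware-free sketch of the static vs. dynamic input-quantization convention;
# class names and the FP8 range are stand-ins for illustration only.
import torch

class QuantStaticInputSketch:
    """Scale measured offline; returns only the quantized tensor."""
    def __init__(self, scale):
        self.scale = torch.tensor(scale)
    def __call__(self, x):
        return x / self.scale

class QuantDynamicInputSketch:
    """Scale computed from the tensor itself; returns (quantized, scale)."""
    def __call__(self, x):
        scale = x.abs().amax() / 448.0
        return x / scale, scale

hidden_states = torch.randn(2, 8)

# Static path: the input scale is a registered buffer known before inference.
q_static = QuantStaticInputSketch(scale=0.05)(hidden_states)

# Dynamic path: the quantizer also hands back the per-call scale, which
# forward_dynamic_quant passes on to the MoE kernel as d_scale_hidden_states.
q_dynamic, input_scale = QuantDynamicInputSketch()(hidden_states)
print(q_static.shape, q_dynamic.shape, float(input_scale))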

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ class DeviceForScalesType(Enum):

 # TODO [SW-217813]: support dynamic quantization in all ops and remove
 # TODO [SW-228723]: get a better way to list all linear ops, like set in ModuleInfo if supports dynamic
-supported_dynamic_ops = ["linear", "row_parallel_linear"]
+supported_dynamic_ops = ["linear", "row_parallel_linear", "no_quantize_op", "dynamic_moe"]
 def is_supported_dynamic_op(op_type):
     ret = op_type.lower() in [op.lower() for op in supported_dynamic_ops]
     logger.trace("Checking if %s is supported for dynamic quantization: %s", op_type, ret)
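
Since logger.trace comes from the package's own logging helper, here is a self-contained, runnable copy of the extended check that uses the standard logging module instead; the membership test itself is unchanged from the code above.

# Runnable copy of the check above; standard logging stands in for logger.trace.
import logging

logger = logging.getLogger("fp8_quant_sketch")

supported_dynamic_ops = ["linear", "row_parallel_linear", "no_quantize_op", "dynamic_moe"]

def is_supported_dynamic_op(op_type):
    ret = op_type.lower() in [op.lower() for op in supported_dynamic_ops]
    logger.debug("Checking if %s is supported for dynamic quantization: %s", op_type, ret)
    return ret

print(is_supported_dynamic_op("Dynamic_MoE"))     # True  (comparison is case-insensitive)
print(is_supported_dynamic_op("no_quantize_op"))  # True  (FusedMoE's new op type)
print(is_supported_dynamic_op("matmul"))          # False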
