 from .quant_config import QuantMode, get_hqt_config
 from ..patched_module_base import PatchedModuleBase, get_call_wrapper
 from .._core.scale_handler import get_scale_dtype, ScaleFormat
-
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+cur_accelerator = auto_detect_accelerator()
 
 class BMM(nn.Module):
     def __init__(self):
@@ -75,9 +76,13 @@ def get_current_repr(cls_instance, *member_names):
             if not first_name:
                 curr_repr += ", "
             cur_attr = getattr(cls_instance, name)
-            # currently, only scale is called here.
-            dtype = get_scale_dtype(cur_attr)
-            curr_repr += f"{name} dtype={dtype}"
+            if isinstance(cur_attr, list) and len(cur_attr) > 1:
+                dtype = get_scale_dtype(cur_attr[0])
+                curr_repr += f"{name} type=list of {dtype}, length={len(cur_attr)}"
+            else:
+                # currently, only scale is called here.
+                dtype = get_scale_dtype(cur_attr)
+                curr_repr += f"{name} dtype={dtype}"
             first_name = False
     return curr_repr
 
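For illustration, here is a minimal standalone sketch of the list-aware branch above. `_get_scale_dtype` and `_describe` are stand-ins written for this sketch (not the module's real helpers), and the tensors are placeholders for per-layer and per-expert scales.

# Illustrative sketch only: mimics the list-aware repr branch with a stubbed get_scale_dtype.
import torch

def _get_scale_dtype(scale):          # stand-in for scale_handler.get_scale_dtype
    return scale.dtype

def _describe(name, cur_attr):
    if isinstance(cur_attr, list) and len(cur_attr) > 1:
        # per-expert scales: report the element dtype and the list length
        return f"{name} type=list of {_get_scale_dtype(cur_attr[0])}, length={len(cur_attr)}"
    return f"{name} dtype={_get_scale_dtype(cur_attr)}"

print(_describe("scale_input", torch.tensor(1.0)))           # scale_input dtype=torch.float32
print(_describe("scale_intermediate", [torch.tensor(1.0)] * 8))  # ... list of torch.float32, length=8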
@@ -398,6 +403,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"]
         if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE):
             self.forward = self.forward_measure_reduce if self.reduce_results and self.tp_size > 1 else self.forward_measure_no_reduce
+
         elif self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.fake_quant or self.use_qdq:
                 self.forward = self.forward_qdq
@@ -467,7 +473,7 @@ def forward_quant_reduce_in_hp(self, input):
     def measure_input_and_matmul(self, input):
         resolved_input = self.resolve_input(input)
         measure_input((resolved_input,), observer=self._mod_extra_config.inputs)
-        return torch.matmul(resolved_input, self.weight.transpose(-1, -2))
+        return self.orig_mod.quant_method.apply(self.orig_mod, resolved_input)
 
     def forward_measure_no_reduce(self, input):
         output = self.measure_input_and_matmul(input)
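The measurement path now delegates the matmul to the wrapped layer's own `quant_method` instead of re-implementing `torch.matmul` against `self.weight`, presumably so the observed output follows the same compute path the serving engine uses for that layer (including checkpoints whose weights are not plain BF16 matrices). A minimal sketch of that delegation pattern, with stand-in classes rather than the real vLLM types:

# Sketch of the delegation pattern; UnquantizedMethod/Layer are stand-ins, not vLLM classes.
import torch

class UnquantizedMethod:
    def apply(self, layer, x):                      # mirrors the quant_method.apply(layer, x) call above
        return torch.matmul(x, layer.weight.transpose(-1, -2))

class Layer:
    def __init__(self, out_features, in_features):
        self.weight = torch.randn(out_features, in_features)
        self.quant_method = UnquantizedMethod()

layer = Layer(8, 16)
x = torch.randn(2, 16)
# Instead of re-implementing the matmul, the patched module asks the layer how to run it.
out = layer.quant_method.apply(layer, x)
print(out.shape)   # torch.Size([2, 8])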
@@ -567,7 +573,7 @@ def forward_quant(self, input):
 
     def forward_measure(self, input):
         measure_input((input,), observer=self._mod_extra_config.inputs)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
+        output = self.orig_mod.quant_method.apply(self.orig_mod, input)
         measure_output((output,), self._mod_extra_config.outputs)
         output, output_bias = self.add_bias(output)
         if self.gather_output:
@@ -695,6 +701,8 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         init_linear(self, mod_extra_config)
         if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
             measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs)
+        else:
+            self.weight = self.weight.squeeze()
 
     def forward_qdq(self, input, *args, **kwargs):
         qinput = self.quant_input(input)
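In quantize/load mode the per-expert weight is now squeezed once at init time, which is why a later hunk can drop the per-call `.squeeze()` when building `w1_list`/`w2_list` in `forward_quant`. A small sketch of what the squeeze does, assuming the stored expert weight carries a leading singleton dimension (the dimensions below are placeholders):

# Sketch: squeezing once at init removes singleton dims, so later code can use the 2-D weight directly.
import torch

w = torch.randn(1, 8, 16)    # hypothetical per-expert weight stored with a leading singleton dim
print(w.squeeze().shape)     # torch.Size([8, 16])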
@@ -820,6 +828,9 @@ def extra_repr(self) -> str:
 class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        # Get the `experts_min` and `experts_max` from the original module if they exist
+        self.experts_min = self.orig_mod.experts_min if hasattr(self.orig_mod, "experts_min") else 0
+        self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
             self.quant_input = self._mod_extra_config.inputs[0]
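The expert index range is now read from the original module when available, with 0 and 7 kept as fallbacks matching the previously hard-coded values. A standalone sketch of the equivalent lookup using `getattr` defaults; `DummyMoE` and its values are placeholders, not the real MoE op:

# Sketch of the fallback lookup; DummyMoE is a placeholder, not the real vLLM op.
class DummyMoE:
    experts_min = 8      # e.g. the second card in a hypothetical 2-card, 16-expert split
    experts_max = 15

def expert_range(mod):
    # Same effect as the hasattr/else pattern in __init__ above.
    return getattr(mod, "experts_min", 0), getattr(mod, "experts_max", 7)

print(expert_range(DummyMoE()))   # (8, 15)
print(expert_range(object()))     # (0, 7) -- the fallback defaults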
@@ -837,8 +848,8 @@ def forward_quant(self,
                       permuted_weights=True,
                       activation="silu"):
         experts_range = range(self.num_experts)
-        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
-        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
         scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
         scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
         qinput = self.quant_input(hidden_states)
@@ -854,8 +865,8 @@ def forward_quant(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output
 
@@ -877,8 +888,8 @@ def forward_measure(self,
             w3=w2_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
         output_measure_list = [output]
@@ -888,15 +899,65 @@ def forward_measure(self,
         return output
 
     def extra_repr(self) -> str:
-        member_names = ["scale_input"]
-        for x in range(1, self.num_experts+1):
-            member_names.append("scale_intermediate["+str(x)+"]")
+        member_names = ["scale_input", "scale_intermediate"]
+        # for x in range(1, self.num_experts+1):
+        #     member_names.append("scale_intermediate["+str(x)+"]")
         return extra_representation(
             self.extra_repr_org(),
             self.class_name_org,
             get_current_repr(self, *member_names),
         )
 
+class PatchedVllmMixtureOfExpertsOpFP8(PatchedVllmMixtureOfExpertsOp):
+    """The patched module for the VLLMMixtureOfExpertsOp module with FP8 weights.
+
+    Some models, such as DeepSeek R1/V3, ship with FP8 weights, so we need to requantize them.
+
+    The main difference from PatchedVllmMixtureOfExpertsOp is that the weights here are FP8:
+    - At the measurement stage, we dequantize the weights to BF16.
+    - At the quantization stage, we reuse the same `forward_quant` method as PatchedVllmMixtureOfExpertsOp.
+    """
+
+    def forward_measure(
+        self,
+        x,
+        topk_ids,
+        topk_weights,
+    ):
+        hidden_states = x
+        measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+        min_expert = self.experts_min
+        max_expert = self.experts_max
+        w13_list_slice = []
+        w2_list_slice = []
+        for j in range(self.num_experts):
+            w13_list_slice.append(self.w13_list[j].get_dequant_weight())
+            w2_list_slice.append(self.w2_list[j].get_dequant_weight())
+
+        output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement_fused_weights(
+            hidden_states=x,
+            expert_routing_table=topk_ids.to(torch.int64),
+            router_weights=topk_weights.to(x.dtype),
+            w12=w13_list_slice,
+            w3=w2_list_slice,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=min_expert,
+            experts_max=max_expert,
+            measurement_mode=True,
+        )
+        output_measure_list = [output]
+        for i in range(self.num_experts):
+            output_measure_list.append(intermidiate_amax[i])
+        measure_output(output_measure_list, self._mod_extra_config.outputs)
+        return output
+
+class PatchedMoeFP8Matmul(PatchedMoeMatmul):
+    """The patched module for the MoeMatmul module with FP8 weights."""
+
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        self.get_dequant_weight = self.orig_mod.get_dequant_weight
 
 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model
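The docstring of PatchedVllmMixtureOfExpertsOpFP8 above says measurement runs on weights dequantized back to BF16 via `get_dequant_weight()`. A rough standalone sketch of that dequantization step, under the assumption that it amounts to casting the stored FP8 tensor to BF16 and undoing a per-tensor weight scale (the real `get_dequant_weight` on the original module may differ):

# Sketch only: assumes per-tensor scale dequantization; requires a PyTorch build with float8 support.
import torch

def dequant_weight(w_fp8: torch.Tensor, scale_weight: torch.Tensor) -> torch.Tensor:
    # FP8 storage -> BF16 compute dtype, then undo the quantization scale.
    return w_fp8.to(torch.bfloat16) * scale_weight

w_fp8 = torch.randn(16, 32).to(torch.float8_e4m3fn)   # stand-in FP8 expert weight
scale = torch.tensor(0.05, dtype=torch.bfloat16)       # stand-in weight scale
w_bf16 = dequant_weight(w_fp8, scale)
print(w_bf16.dtype, w_bf16.shape)   # torch.bfloat16 torch.Size([16, 32])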