 from .quant_config import QuantMode, get_hqt_config
 from ..patched_module_base import PatchedModuleBase, get_call_wrapper
 from .._core.scale_handler import get_scale_dtype, ScaleFormat
-
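+# auto-detect the current accelerator (e.g. HPU) once at import time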
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+cur_accelerator = auto_detect_accelerator()

 class BMM(nn.Module):
     def __init__(self):
@@ -75,9 +76,13 @@ def get_current_repr(cls_instance, *member_names):
             if not first_name:
                 curr_repr += ", "
             cur_attr = getattr(cls_instance, name)
-            # currently, only scale is called here.
-            dtype = get_scale_dtype(cur_attr)
-            curr_repr += f"{name} dtype={dtype}"
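+            # the attribute may be a list of scales (e.g. the per-expert intermediate scales); report element dtype and list length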
+            if isinstance(cur_attr, list) and len(cur_attr) > 1:
+                dtype = get_scale_dtype(cur_attr[0])
+                curr_repr += f"{name} type=list of {dtype}, length={len(cur_attr)}"
+            else:
+                # currently, only scale is called here.
+                dtype = get_scale_dtype(cur_attr)
+                curr_repr += f"{name} dtype={dtype}"
             first_name = False
     return curr_repr

@@ -401,6 +406,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"]
         if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE):
             self.forward = self.forward_measure_reduce if self.reduce_results and self.tp_size > 1 else self.forward_measure_no_reduce
+
         elif self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             if self.fake_quant or self.use_qdq:
                 self.forward = self.forward_qdq
@@ -470,7 +476,7 @@ def forward_quant_reduce_in_hp(self, input):
     def measure_input_and_matmul(self, input):
         resolved_input = self.resolve_input(input)
         measure_input((resolved_input,), observer=self._mod_extra_config.inputs)
-        return torch.matmul(resolved_input, self.weight.transpose(-1, -2))
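+        # measure through the original module's quant_method so the result follows its own apply() path rather than a hand-rolled matmul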
+        return self.orig_mod.quant_method.apply(self.orig_mod, resolved_input)

     def forward_measure_no_reduce(self, input):
         output = self.measure_input_and_matmul(input)
@@ -570,7 +576,7 @@ def forward_quant(self, input):

     def forward_measure(self, input):
         measure_input((input,), observer=self._mod_extra_config.inputs)
-        output = torch.matmul(input, self.weight.transpose(-1, -2))
+        output = self.orig_mod.quant_method.apply(self.orig_mod, input)
         measure_output((output,), self._mod_extra_config.outputs)
         output, output_bias = self.add_bias(output)
         if self.gather_output:
@@ -698,6 +704,8 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         init_linear(self, mod_extra_config)
         if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
            measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs)
+        else:
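+            # squeeze the stored weight once at init; the MoE forward_quant no longer squeezes it on every call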
+            self.weight = self.weight.squeeze()

     def forward_qdq(self, input, *args, **kwargs):
         qinput = self.quant_input(input)
@@ -823,6 +831,9 @@ def extra_repr(self) -> str:
 class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        # Get the `experts_min` and `experts_max` from the original module if they exist
+        self.experts_min = self.orig_mod.experts_min if hasattr(self.orig_mod, "experts_min") else 0
+        self.experts_max = self.orig_mod.experts_max if hasattr(self.orig_mod, "experts_max") else 7
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.forward = self.forward_quant
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
@@ -841,8 +852,8 @@ def forward_quant(self,
                       permuted_weights=True,
                       activation="silu"):
         experts_range = range(self.num_experts)
-        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
-        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+        w1_list = [self.w13_list[i].weight for i in experts_range]
+        w2_list = [self.w2_list[i].weight for i in experts_range]
         scale_w1 = [self.w13_list[i].scale_weight for i in experts_range]
         scale_w2 = [self.w2_list[i].scale_weight for i in experts_range]
         qinput = self.quant_input(hidden_states)
@@ -858,8 +869,8 @@ def forward_quant(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output

@@ -881,8 +892,8 @@ def forward_measure(self,
             w3=w2_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
         output_measure_list = [output]
@@ -892,15 +903,65 @@ def forward_measure(self,
         return output

     def extra_repr(self) -> str:
-        member_names = ["scale_input"]
-        for x in range(1, self.num_experts+1):
-            member_names.append("scale_intermediate["+str(x)+"]")
+        member_names = ["scale_input", "scale_intermediate"]
+        # for x in range(1, self.num_experts+1):
+        #     member_names.append("scale_intermediate["+str(x)+"]")
         return extra_representation(
             self.extra_repr_org(),
             self.class_name_org,
             get_current_repr(self, *member_names),
         )

+class PatchedVllmMixtureOfExpertsOpFP8(PatchedVllmMixtureOfExpertsOp):
+    """The patched module for the VLLMMixtureOfExpertsOp module with FP8 weights.
+
+    Some models, such as DeepSeek R1/V3, come with FP8 weights, so we need to requantize them.
+
+    The main difference from PatchedVllmMixtureOfExpertsOp is that the weights are FP8:
+    - At the measurement stage, we dequantize the weights to BF16.
+    - At the quantization stage, we use the same `forward_quant` method as PatchedVllmMixtureOfExpertsOp.
+    """
+
+    def forward_measure(
+        self,
+        x,
+        topk_ids,
+        topk_weights,
+    ):
+        hidden_states = x
+        measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+        min_expert = self.experts_min
+        max_expert = self.experts_max
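+        # dequantize the FP8 expert weights to BF16 so the measurement runs on high-precision weights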
+        w13_list_slice = []
+        w2_list_slice = []
+        for j in range(self.num_experts):
+            w13_list_slice.append(self.w13_list[j].get_dequant_weight())
+            w2_list_slice.append(self.w2_list[j].get_dequant_weight())
+
+        output, intermediate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement_fused_weights(
+            hidden_states=x,
+            expert_routing_table=topk_ids.to(torch.int64),
+            router_weights=topk_weights.to(x.dtype),
+            w12=w13_list_slice,
+            w3=w2_list_slice,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=min_expert,
+            experts_max=max_expert,
+            measurement_mode=True,
+        )
+        output_measure_list = [output]
+        for i in range(self.num_experts):
+            output_measure_list.append(intermediate_amax[i])
+        measure_output(output_measure_list, self._mod_extra_config.outputs)
+        return output
+
+class PatchedMoeFP8Matmul(PatchedMoeMatmul):
+    """The patched module for the MoeMatmul module with FP8 weights."""
+
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
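+        # expose the original module's dequant helper so the FP8 MoE measurement path can recover high-precision weights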
+        self.get_dequant_weight = self.orig_mod.get_dequant_weight

 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model