@@ -740,13 +740,19 @@ class PatchedGaudiMixtralSparseMoeBlock(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.forward = self.forward_orig
+        self.ep_size = mod.ep_size
+        self.experts_min = mod.experts_min
+        self.experts_max = mod.experts_max
+        self.experts_range = mod.experts_range
+        self.num_experts = mod.num_experts
+
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
             self.quant_input = self._mod_extra_config.inputs[0]
             self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
             self.register_scale(
                 "scale_intermediate",
-                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts + 1)],
+                [mod_extra_config.scale.inputs[x] for x in range(self.experts_min + 1, self.experts_max + 2)],
                 self.scale_format,
             )
             mod.call_dynamic_moe_op = get_call_wrapper(self, "call_dynamic_moe_quant_op")
@@ -759,12 +765,12 @@ def call_dynamic_moe_quant_op(self,
                                   router_weights,
                                   permuted_weights=False,
                                   activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
-        scale_w1 = [expert.w1.scale_weight for expert in self.experts]
-        scale_w2 = [expert.w2.scale_weight for expert in self.experts]
-        scale_w3 = [expert.w3.scale_weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
+        scale_w1 = [self.experts[i].w1.scale_weight for i in self.experts_range]
+        scale_w2 = [self.experts[i].w2.scale_weight for i in self.experts_range]
+        scale_w3 = [self.experts[i].w3.scale_weight for i in self.experts_range]
         qinput = self.quant_input(hidden_states)
         output = self.dynamic_moe_op(
             hidden_states=qinput,
@@ -780,8 +786,8 @@ def call_dynamic_moe_quant_op(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output

@@ -791,10 +797,11 @@ def call_dynamic_moe_measure_op(self,
                                     router_weights,
                                     permuted_weights=True,
                                     activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
         measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+
         output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement(
             hidden_states=hidden_states,
             expert_routing_table=expert_routing_table,
@@ -804,13 +811,21 @@ def call_dynamic_moe_measure_op(self,
             w2=w3_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
-        output_measure_list = [output]
+
+
+        amax = []
         for i in range(self.num_experts):
-            output_measure_list.append(intermidiate_amax[i])
+            if i in self.experts_range:
+                amax.append(intermidiate_amax[i - self.experts_min])
+            else:
+                amax.append(torch.tensor(0, device="hpu", dtype=intermidiate_amax[0].dtype))
+
+        output_measure_list = [output] + amax
+
         measure_output(output_measure_list, self._mod_extra_config.outputs)
         return output

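The new `ep_size`, `experts_min`, `experts_max`, and `experts_range` attributes are read from the wrapped module, and this diff does not show how they are computed. As a minimal sketch only (a hypothetical helper, not code from this PR), an expert-parallel rank could derive them from `ep_size` and `num_experts` like this, which reduces to the previously hard-coded `experts_min=0, experts_max=7` when 8 experts live on a single rank:

```python
# Hypothetical illustration: slice the global expert pool evenly across
# expert-parallel (EP) ranks and report the local id range for one rank.
def local_expert_slice(num_experts: int, ep_size: int, ep_rank: int):
    """Return (experts_min, experts_max, experts_range) for one EP rank."""
    assert num_experts % ep_size == 0, "assumes experts split evenly across ranks"
    experts_per_rank = num_experts // ep_size
    experts_min = ep_rank * experts_per_rank           # first global expert id owned here
    experts_max = experts_min + experts_per_rank - 1   # last global expert id (inclusive)
    experts_range = range(experts_min, experts_max + 1)
    return experts_min, experts_max, experts_range

# 8 experts on 1 rank -> (0, 7, range(0, 8)), matching the old constants.
# 8 experts on 2 ranks -> rank 1 owns (4, 7, range(4, 8)).
print(local_expert_slice(num_experts=8, ep_size=1, ep_rank=0))
print(local_expert_slice(num_experts=8, ep_size=2, ep_rank=1))
```

With this slicing, the measurement path only produces real amax values for the locally owned experts, which is why the loop above pads the remaining expert slots with zero tensors so that `output_measure_list` always holds `num_experts` per-expert entries after the output.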