@@ -743,13 +743,19 @@ class PatchedGaudiMixtralSparseMoeBlock(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.forward = self.forward_orig
+        self.ep_size = mod.ep_size
+        self.experts_min = mod.experts_min
+        self.experts_max = mod.experts_max
+        self.experts_range = mod.experts_range
+        self.num_experts = mod.num_experts
+
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
             self.quant_input = self._mod_extra_config.inputs[0]
             self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
             self.register_scale(
                 "scale_intermediate",
-                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts + 1)],
+                [mod_extra_config.scale.inputs[x] for x in range(self.experts_min + 1, self.experts_max + 2)],
                 self.scale_format,
             )
             mod.call_dynamic_moe_op = get_call_wrapper(self, "call_dynamic_moe_quant_op")
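
Note (not part of the patch): in `mod_extra_config.scale.inputs`, index 0 holds the hidden_states input scale and indices 1..num_experts hold one intermediate scale per expert, so a rank that owns experts `experts_min`..`experts_max` registers the slice at indices `experts_min + 1`..`experts_max + 1` (the `+ 2` above is the exclusive end of `range`). A minimal sketch of how the per-rank bounds could be derived, assuming an even split of the experts across `ep_size` ranks; the names `local_expert_bounds` and `ep_rank` are illustrative, the patch simply reads the values from `mod`:

    # Illustrative only: one way a module could compute the values this patch reads from `mod`.
    # Assumes num_experts is divisible by ep_size; the real module may use a different scheme.
    def local_expert_bounds(num_experts: int, ep_size: int, ep_rank: int):
        per_rank = num_experts // ep_size          # experts owned by each EP rank
        experts_min = ep_rank * per_rank           # first global expert index on this rank
        experts_max = experts_min + per_rank - 1   # last global expert index (inclusive)
        return experts_min, experts_max, range(experts_min, experts_max + 1)

    # e.g. 8 experts, ep_size=2: rank 0 -> (0, 3, range(0, 4)), rank 1 -> (4, 7, range(4, 8))
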
@@ -762,12 +768,12 @@ def call_dynamic_moe_quant_op(self,
                                   router_weights,
                                   permuted_weights=False,
                                   activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
-        scale_w1 = [expert.w1.scale_weight for expert in self.experts]
-        scale_w2 = [expert.w2.scale_weight for expert in self.experts]
-        scale_w3 = [expert.w3.scale_weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
+        scale_w1 = [self.experts[i].w1.scale_weight for i in self.experts_range]
+        scale_w2 = [self.experts[i].w2.scale_weight for i in self.experts_range]
+        scale_w3 = [self.experts[i].w3.scale_weight for i in self.experts_range]
         qinput = self.quant_input(hidden_states)
         output = self.dynamic_moe_op(
             hidden_states=qinput,
@@ -783,8 +789,8 @@ def call_dynamic_moe_quant_op(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output

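
Note (not part of the patch): the quantized dynamic MoE call previously hard-coded `experts_min=0, experts_max=7`, which only holds when all 8 Mixtral experts are local to one device. With expert parallelism each rank now passes its own bounds, and the weight/scale lists built above cover only the local experts, so list slot `i` corresponds to global expert `experts_min + i`; the bounds presumably let the fused op map routing-table ids onto those slots and skip tokens routed to experts owned by other ranks. A worked example under the assumed even split from the sketch above:

    # Illustrative only: ep_size=2, rank 1 of a Mixtral layer with 8 experts.
    experts_min, experts_max = 4, 7
    experts_range = range(4, 8)
    # w1_list / scale_w1 then hold 4 entries, for global experts 4, 5, 6, 7,
    # and scale_intermediate was registered from scale.inputs[5:9] in __init__.
    assert list(range(experts_min + 1, experts_max + 2)) == [5, 6, 7, 8]
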
@@ -794,10 +800,11 @@ def call_dynamic_moe_measure_op(self,
                                     router_weights,
                                     permuted_weights=True,
                                     activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
         measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+
         output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement(
             hidden_states=hidden_states,
             expert_routing_table=expert_routing_table,
@@ -807,13 +814,21 @@ def call_dynamic_moe_measure_op(self,
             w2=w3_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
-        output_measure_list = [output]
+
+
+        amax = []
         for i in range(self.num_experts):
-            output_measure_list.append(intermidiate_amax[i])
+            if i in self.experts_range:
+                amax.append(intermidiate_amax[i - self.experts_min])
+            else:
+                amax.append(torch.tensor(0, device="hpu", dtype=intermidiate_amax[0].dtype))
+
+        output_measure_list = [output] + amax
+
         measure_output(output_measure_list, self._mod_extra_config.outputs)
         return output

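
Note (not part of the patch): `fp8_measurement` now returns one intermediate amax per local expert, and the loop re-expands them to global positions: global expert `i` takes `intermidiate_amax[i - self.experts_min]` when it is local and a zero tensor otherwise, so `output_measure_list` keeps a fixed length of `1 + num_experts` (the routed output plus one slot per global expert) on every EP rank. A small check of that layout, using stand-in strings instead of HPU tensors and the same assumed rank-1 split as above:

    # Illustrative only: rank 1 owns global experts 4..7 out of 8.
    num_experts, experts_min, experts_range = 8, 4, range(4, 8)
    local_amax = ["amax4", "amax5", "amax6", "amax7"]   # stand-ins for the op's per-local-expert amax tensors
    amax = [local_amax[i - experts_min] if i in experts_range else 0 for i in range(num_experts)]
    assert amax == [0, 0, 0, 0, "amax4", "amax5", "amax6", "amax7"]
    # output_measure_list = [output] + amax -> length 1 + num_experts on every rank
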