
Commit 35d6ad0

[SW-218277] Add support for Mixtral with expert parallelism (#177)

* Add support for Mixtral with expert parallelism
* Remove allreduce from measurement
1 parent 9d9188b commit 35d6ad0
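
For readers skimming the diff: expert parallelism (EP) shards a Mixtral MoE layer's experts across ranks, so the patched block can no longer hard-code the full expert range. This change reads the rank-local slice (ep_size, experts_min, experts_max, experts_range, num_experts) from the wrapped module. Below is a minimal sketch of how such a slice could be derived, assuming a uniform contiguous split with inclusive bounds; the helper is hypothetical and only mirrors the attribute names visible in the diff:

# Hypothetical sketch, not part of this commit: derive the per-rank expert
# slice that mod.ep_size / mod.experts_min / mod.experts_max / mod.experts_range
# appear to describe (contiguous split, inclusive bounds).
def local_expert_slice(num_experts: int, ep_size: int, ep_rank: int):
    assert num_experts % ep_size == 0, "assumes experts divide evenly across ranks"
    per_rank = num_experts // ep_size
    experts_min = ep_rank * per_rank           # first global expert id on this rank
    experts_max = experts_min + per_rank - 1   # last global expert id (inclusive)
    return experts_min, experts_max, range(experts_min, experts_max + 1)

# Example: Mixtral's 8 experts on 2 ranks.
print(local_expert_slice(8, 2, 0))  # (0, 3, range(0, 4))
print(local_expert_slice(8, 2, 1))  # (4, 7, range(4, 8))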

File tree

1 file changed: +31 −16 lines

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 31 additions & 16 deletions
@@ -740,13 +740,19 @@ class PatchedGaudiMixtralSparseMoeBlock(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.forward = self.forward_orig
+        self.ep_size = mod.ep_size
+        self.experts_min = mod.experts_min
+        self.experts_max = mod.experts_max
+        self.experts_range = mod.experts_range
+        self.num_experts = mod.num_experts
+
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
             self.quant_input = self._mod_extra_config.inputs[0]
             self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
             self.register_scale(
                 "scale_intermediate",
-                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
+                [mod_extra_config.scale.inputs[x] for x in range(self.experts_min+1, self.experts_max+2)],
                 self.scale_format,
             )
             mod.call_dynamic_moe_op = get_call_wrapper(self, "call_dynamic_moe_quant_op")
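
The index shift in the scale_intermediate hunk above follows from the scale layout already visible in this file: mod_extra_config.scale.inputs[0] holds the hidden-state input scale, so the per-expert intermediate scales sit at inputs[1] through inputs[num_experts] in global expert order. A rank owning experts experts_min..experts_max therefore takes that slice offset by one. A small worked example with illustrative values:

# Illustrative only: rank 1 of 2 on an 8-expert Mixtral block.
experts_min, experts_max = 4, 7
# inputs[0] -> hidden-state scale; inputs[1..8] -> per-expert intermediate scales.
local_scale_idx = list(range(experts_min + 1, experts_max + 2))
print(local_scale_idx)  # [5, 6, 7, 8] -> scales for global experts 4..7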
@@ -759,12 +765,12 @@ def call_dynamic_moe_quant_op(self,
                                   router_weights,
                                   permuted_weights=False,
                                   activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
-        scale_w1 = [expert.w1.scale_weight for expert in self.experts]
-        scale_w2 = [expert.w2.scale_weight for expert in self.experts]
-        scale_w3 = [expert.w3.scale_weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
+        scale_w1 = [self.experts[i].w1.scale_weight for i in self.experts_range]
+        scale_w2 = [self.experts[i].w2.scale_weight for i in self.experts_range]
+        scale_w3 = [self.experts[i].w3.scale_weight for i in self.experts_range]
         qinput = self.quant_input(hidden_states)
         output = self.dynamic_moe_op(
             hidden_states=qinput,
@@ -780,8 +786,8 @@ def call_dynamic_moe_quant_op(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output

@@ -791,10 +797,11 @@ def call_dynamic_moe_measure_op(self,
                                     router_weights,
                                     permuted_weights=True,
                                     activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
         measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+
         output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement(
             hidden_states=hidden_states,
             expert_routing_table=expert_routing_table,
@@ -804,13 +811,21 @@ def call_dynamic_moe_measure_op(self,
             w2=w3_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
-        output_measure_list = [output]
+
+
+        amax = []
         for i in range(self.num_experts):
-            output_measure_list.append(intermidiate_amax[i])
+            if i in self.experts_range:
+                amax.append(intermidiate_amax[i-self.experts_min])
+            else:
+                amax.append(torch.tensor(0, device="hpu", dtype=intermidiate_amax[0].dtype))
+
+        output_measure_list = [output] + amax
+
         measure_output(output_measure_list, self._mod_extra_config.outputs)
         return output

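Even though fp8_measurement now runs only over the rank-local experts, the measurement hunk keeps the observer layout global: amax slots for experts owned by other ranks are padded with zero tensors instead of being combined with an allreduce, which per the commit title is what was removed. A standalone sketch of that padding logic, assuming intermidiate_amax carries one value per local expert in rank order; pad_local_amax is a hypothetical helper and "cpu" stands in for "hpu":

import torch

def pad_local_amax(local_amax, num_experts, experts_min, experts_max, device="cpu"):
    # Experts owned by this rank, inclusive bounds as in the patched module.
    experts_range = range(experts_min, experts_max + 1)
    amax = []
    for i in range(num_experts):
        if i in experts_range:
            amax.append(local_amax[i - experts_min])  # locally measured amax
        else:
            amax.append(torch.tensor(0, device=device, dtype=local_amax[0].dtype))  # off-rank placeholder
    return amax

# Rank owning experts 2..3 of 4; slots 0..1 become zero placeholders.
local = [torch.tensor(3.5), torch.tensor(1.25)]
print(pad_local_amax(local, num_experts=4, experts_min=2, experts_max=3))
# [tensor(0.), tensor(0.), tensor(3.5000), tensor(1.2500)]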