
Commit 648a23d

kwisniewski98 authored and XuehaoSun committed
[SW-218277] Add support for mixtral with expert parallelism (#177)

* Add support for mixtral with expert parallelism
* Remove allreduce from measurement
1 parent 01aa482 commit 648a23d
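
Context for the diff below: the patched Mixtral MoE block now reads its expert-parallelism layout (ep_size, experts_min, experts_max, experts_range, num_experts) from the wrapped module. How those attributes are computed is outside this diff; the following is a minimal sketch, assuming each rank owns a contiguous, equal-sized slice of the experts (the helper name and signature are hypothetical, not from this repo).

# Hypothetical helper (not part of this commit): derive a rank's local expert
# slice under expert parallelism, assuming num_experts % ep_size == 0 and a
# contiguous partition of experts across ranks.
def local_expert_slice(num_experts: int, ep_size: int, rank: int):
    per_rank = num_experts // ep_size
    experts_min = rank * per_rank             # first global expert id on this rank
    experts_max = experts_min + per_rank - 1  # last global expert id (inclusive)
    experts_range = range(experts_min, experts_max + 1)
    return experts_min, experts_max, experts_range

# Example: Mixtral's 8 experts split over 2 ranks; rank 1 owns experts 4..7.
print(local_expert_slice(8, 2, 1))  # -> (4, 7, range(4, 8))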

File tree

1 file changed (+31, -16 lines)


neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 31 additions & 16 deletions
@@ -743,13 +743,19 @@ class PatchedGaudiMixtralSparseMoeBlock(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.forward = self.forward_orig
+        self.ep_size = mod.ep_size
+        self.experts_min = mod.experts_min
+        self.experts_max = mod.experts_max
+        self.experts_range = mod.experts_range
+        self.num_experts = mod.num_experts
+
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
             self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
             self.quant_input = self._mod_extra_config.inputs[0]
             self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
             self.register_scale(
                 "scale_intermediate",
-                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
+                [mod_extra_config.scale.inputs[x] for x in range(self.experts_min+1, self.experts_max+2)],
                 self.scale_format,
             )
             mod.call_dynamic_moe_op = get_call_wrapper(self, "call_dynamic_moe_quant_op")
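
The changed comprehension shifts the intermediate-scale indices to the local expert slice. Reading the surrounding code, scale.inputs[0] is registered as scale_input, so the per-expert intermediate scales appear to sit at indices 1..num_experts; a rank owning experts experts_min..experts_max therefore reads indices experts_min+1 through experts_max+1, i.e. range(experts_min+1, experts_max+2). A worked check of that inference (values illustrative):

# Illustrative check of the new index window: a rank owning experts 4..7
# should read scale.inputs[5..8], since index 0 is the block-input scale.
experts_min, experts_max = 4, 7
indices = list(range(experts_min + 1, experts_max + 2))
assert indices == [5, 6, 7, 8]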
@@ -762,12 +768,12 @@ def call_dynamic_moe_quant_op(self,
                                   router_weights,
                                   permuted_weights=False,
                                   activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
-        scale_w1 = [expert.w1.scale_weight for expert in self.experts]
-        scale_w2 = [expert.w2.scale_weight for expert in self.experts]
-        scale_w3 = [expert.w3.scale_weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
+        scale_w1 = [self.experts[i].w1.scale_weight for i in self.experts_range]
+        scale_w2 = [self.experts[i].w2.scale_weight for i in self.experts_range]
+        scale_w3 = [self.experts[i].w3.scale_weight for i in self.experts_range]
         qinput = self.quant_input(hidden_states)
         output = self.dynamic_moe_op(
             hidden_states=qinput,
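
In the quantized path, the per-expert weight and scale lists are now built by indexing self.experts with global expert ids from experts_range rather than iterating over every expert, so each rank gathers only the experts it owns. A toy illustration of the indexing pattern (SimpleNamespace stands in for the real expert modules):

from types import SimpleNamespace

# Toy stand-ins for expert modules, indexed by global expert id.
experts = [SimpleNamespace(w1=SimpleNamespace(weight=f"w1_{i}")) for i in range(8)]

experts_range = range(4, 8)  # this rank's local experts
w1_list = [experts[i].w1.weight for i in experts_range]
assert w1_list == ["w1_4", "w1_5", "w1_6", "w1_7"]  # only local experts gathered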
@@ -783,8 +789,8 @@ def call_dynamic_moe_quant_op(self,
             d_scale_intermediate_hidden_states=self.scale_intermediate,
             permuted_weights=False,
             activation=activation,
-            experts_min=0,
-            experts_max=7
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
         )
         return output
 
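The previously hardcoded bounds experts_min=0, experts_max=7 are exactly the single-device layout of Mixtral's 8 experts; passing self.experts_min and self.experts_max instead lets the fused op address any rank's slice. A quick sanity check, reusing the contiguous-slice assumption from the sketch above:

# Sanity check (illustrative): the old constants 0 and 7 are the ep_size == 1
# special case of the per-rank slice for an 8-expert model.
num_experts, ep_size, rank = 8, 1, 0
per_rank = num_experts // ep_size
experts_min = rank * per_rank
experts_max = experts_min + per_rank - 1
assert (experts_min, experts_max) == (0, 7)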
@@ -794,10 +800,11 @@ def call_dynamic_moe_measure_op(self,
                                     router_weights,
                                     permuted_weights=True,
                                     activation="silu"):
-        w1_list = [expert.w1.weight for expert in self.experts]
-        w2_list = [expert.w2.weight for expert in self.experts]
-        w3_list = [expert.w3.weight for expert in self.experts]
+        w1_list = [self.experts[i].w1.weight for i in self.experts_range]
+        w2_list = [self.experts[i].w2.weight for i in self.experts_range]
+        w3_list = [self.experts[i].w3.weight for i in self.experts_range]
         measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+
         output, intermidiate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement(
             hidden_states=hidden_states,
             expert_routing_table=expert_routing_table,
@@ -807,13 +814,21 @@ def call_dynamic_moe_measure_op(self,
             w2=w3_list,
             permuted_weights=permuted_weights,
             activation=activation,
-            experts_min=0,
-            experts_max=7,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
             measurement_mode=True,
         )
-        output_measure_list = [output]
+
+
+        amax = []
         for i in range(self.num_experts):
-            output_measure_list.append(intermidiate_amax[i])
+            if i in self.experts_range:
+                amax.append(intermidiate_amax[i-self.experts_min])
+            else:
+                amax.append(torch.tensor(0, device="hpu", dtype=intermidiate_amax[0].dtype))
+
+        output_measure_list = [output] + amax
+
         measure_output(output_measure_list, self._mod_extra_config.outputs)
         return output
 
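The measurement path changes the most: instead of appending every expert's amax directly, each rank now pads the positions of non-local experts with zero tensors, so output_measure_list always has num_experts + 1 entries regardless of ep_size. Presumably this position-aligned, full-length layout is what allows the commit to drop the allreduce from measurement (per the commit message). A standalone sketch of the padding, with synthetic local results (run on CPU here; the real code allocates on device="hpu"):

import torch

# Standalone sketch of the amax padding above. This rank owns experts 4..7.
num_experts, experts_min, experts_max = 8, 4, 7
experts_range = range(experts_min, experts_max + 1)
# The fused measurement op returns amax only for this rank's experts.
intermidiate_amax = [torch.tensor(float(i)) for i in experts_range]

amax = []
for i in range(num_experts):
    if i in experts_range:
        amax.append(intermidiate_amax[i - experts_min])  # local expert: real amax
    else:
        amax.append(torch.tensor(0, dtype=intermidiate_amax[0].dtype))  # zero padding

output_measure_list = [torch.zeros(1)] + amax  # stand-in for [output, amax_0, ..., amax_7]
assert len(output_measure_list) == num_experts + 1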