Skip to content

Commit a1f476c

Browse files
yiliu30 and Yi4Liu authored
Fix PatchedMoeMatmul and Get num_experts from Module (#202)
Fix `PatchedMoeMatmul` and Get `num_experts` from Module --------- Signed-off-by: Yi Liu <[email protected]> Co-authored-by: Yi Liu <[email protected]>
1 parent c4ae066 commit a1f476c

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
146146
patched_types.add(type(mod))
147147

148148
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
149+
if mod_type == "dynamic_moe" and hasattr(mod, "num_experts"):
150+
# override default number of outputs for dynamic moe
151+
mod_types[mod_type].num_outputs = mod.num_experts+1
152+
logger.warning(f"Dynamic moe num_outputs set to {mod.num_experts+1}")
149153
mod_extra_config = (
150154
init_measure_object(
151155
mod,

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def create_mod_info_recursion(parent):
5252

5353
create_mod_info_recursion(model)
5454

55-
56-
INC_DYNAMIC_MOE_EXPERTS = int(os.environ.get("INC_DYNAMIC_MOE_EXPERTS", "8"))
5755

5856
_mod_types = {
5957
"linear": ModuleType(1, ["weight"], 1, False),
@@ -62,7 +60,7 @@ def create_mod_info_recursion(parent):
6260
"kv_cache": ModuleType(1, [], 1, False),
6361
"softmax": ModuleType(1, [], 1, True),
6462
"fused_sdpa": ModuleType(3, [], 2, True),
65-
"dynamic_moe": ModuleType(1, [], 1 + INC_DYNAMIC_MOE_EXPERTS, True),
63+
"dynamic_moe": ModuleType(1, [], 1 + 8, True),
6664
}
6765

6866

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
702702
if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
703703
measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs)
704704
else:
705-
self.weight = self.weight.squeeze()
705+
self.weight = torch.nn.Parameter(self.weight.squeeze(), requires_grad=False)
706706

707707
def forward_qdq(self, input, *args, **kwargs):
708708
qinput = self.quant_input(input)

0 commit comments

Comments (0)