Skip to content

Commit b3d334c

Browse files
yiliu30 and Yi4Liu
authored and committed
Fix PatchedMoeMatmul and Get num_experts from Module (#202)
Fix `PatchedMoeMatmul` and Get `num_experts` from Module --------- Signed-off-by: Yi Liu <[email protected]> Co-authored-by: Yi Liu <[email protected]>
1 parent 3d9a24c commit b3d334c

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
143143
patched_types.add(type(mod))
144144

145145
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
146+
if mod_type == "dynamic_moe" and hasattr(mod, "num_experts"):
147+
# override default number of outputs for dynamic moe
148+
mod_types[mod_type].num_outputs = mod.num_experts+1
149+
logger.warning(f"Dynamic moe num_outputs set to {mod.num_experts+1}")
146150
mod_extra_config = (
147151
init_measure_object(
148152
mod,

neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,14 @@ def create_mod_info_recursion(parent):
5252

5353
create_mod_info_recursion(model)
5454

55-
56-
INC_DYNAMIC_MOE_EXPERTS = int(os.environ.get("INC_DYNAMIC_MOE_EXPERTS", "8"))
5755

5856
_mod_types = {
5957
"linear": ModuleType(1, ["weight"], 1, False),
6058
"matmul": ModuleType(2, [], 1, False),
6159
"kv_cache": ModuleType(1, [], 1, False),
6260
"softmax": ModuleType(1, [], 1, True),
6361
"fused_sdpa": ModuleType(3, [], 2, True),
64-
"dynamic_moe": ModuleType(1, [], 1 + INC_DYNAMIC_MOE_EXPERTS, True),
62+
"dynamic_moe": ModuleType(1, [], 1 + 8, True),
6563
}
6664

6765

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
630630
if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
631631
measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs)
632632
else:
633-
self.weight = self.weight.squeeze()
633+
self.weight = torch.nn.Parameter(self.weight.squeeze(), requires_grad=False)
634634

635635
def forward_qdq(self, input, *args, **kwargs):
636636
qinput = self.quant_input(input)

0 commit comments

Comments (0)