From b4eb50dfdfa172f9f9fb5af04cbc751725eccd34 Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 02:14:38 +0800 Subject: [PATCH 01/17] init tokenizer configs --- convert_hf_to_gguf.py | 3 +++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 24 ++++++++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bbc21813f81ca..a08c0e61b7e85 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -861,6 +861,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224": # ref: https://huggingface.co/inclusionAI/Ling-lite res = "bailingmoe" + if chkhsh == "0cb00ada81d42e4d3c5c1a74a67cfb4aa250a31de9c47b951a21f82a54c7e336": + # ref: https://huggingface.co/inclusionAI/Ling-lite + res = "bailing-bt2" if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct res = "llama4" diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 575e05e193c2e..2a9e39698d6c2 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -131,6 +131,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, + {"name": "bailingmoe-v2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-2.0", }, {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1e88b6505bae0..3de568b7d48d8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -380,6 +380,7 @@ class MODEL_ARCH(IntEnum): WAVTOKENIZER_DEC = auto() PLM = auto() BAILINGMOE = auto() + BAILINGMOE_V2 = auto() DOTS1 = auto() ARCEE = auto() ERNIE4_5 = auto() @@ -715,6 +716,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", MODEL_ARCH.PLM: "plm", MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.BAILINGMOE_V2: "bailingmoe-v2", MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", MODEL_ARCH.ERNIE4_5: "ernie4_5", @@ -2492,6 +2494,25 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.BAILINGMOE_V2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.DOTS1: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2740,6 +2761,9 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE: [ MODEL_TENSOR.ROPE_FREQS, ], + MODEL_ARCH.BAILINGMOE_V2: [ + MODEL_TENSOR.ROPE_FREQS, + ], } # From 7968c69baf7337c83209a4b7d810841186a98d0e Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 02:25:37 +0800 Subject: [PATCH 02/17] init model --- src/llama-arch.cpp | 24 +++++ src/llama-arch.h | 1 + src/llama-model.cpp | 211 ++++++++++++++++++++++++++++++++++++++++++++ src/llama-vocab.cpp | 4 + 4 files changed, 240 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 81f9746818d4a..ed33f340c3898 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -84,6 +84,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_BAILINGMOE_V2, "bailingmoe-v2" }, { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, @@ -1910,6 +1911,29 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_BAILINGMOE_V2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_DOTS1, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 6ee3707dcfbf6..854878055feb6 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -88,6 +88,7 @@ enum llm_arch { LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, + LLM_ARCH_BAILINGMOE_V2, LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 818b209641a5a..accea90483c07 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1800,6 +1800,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_BAILINGMOE_V2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + switch (hparams.n_layer) { + case 20: type = LLM_TYPE_16B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_DOTS1: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5309,6 +5323,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); } } break; + case LLM_ARCH_BAILINGMOE_V2: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } break; case LLM_ARCH_DOTS1: { const int64_t n_ff_exp = hparams.n_ff_exp; @@ -6061,6 +6115,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } + if (arch == LLM_ARCH_BAILINGMOE_V2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } + if (arch == LLM_ARCH_SMALLTHINKER) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); @@ -16566,6 +16628,150 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; +struct llm_build_bailingmoe_v2 : public llm_graph_context { + llm_build_bailingmoe_v2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + false, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_dots1 : public llm_graph_context { llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -19042,6 +19248,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_BAILINGMOE_V2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_SEED_OSS: { llm = std::make_unique(*this, params); @@ -19245,6 +19455,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: + case LLM_ARCH_BAILINGMOE_V2: case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index ca02b63a58407..6cab7397a3874 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1958,6 +1958,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "bailingmoe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; + } else if ( + tokenizer_pre == "bailing-bt2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + clean_spaces = false; } else if ( tokenizer_pre == "seed-coder") { pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; From ba8e746fa98eb65b93a29cf18bc477de677f5e2c Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 03:55:44 +0800 Subject: [PATCH 03/17] =?UTF-8?q?[feat]=20Ling=20mini=202.0=20gguf=20?= =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- convert_hf_to_gguf.py | 107 +++++++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 6 +++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a08c0e61b7e85..f8294b8530ab0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -861,7 +861,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224": # ref: https://huggingface.co/inclusionAI/Ling-lite res = "bailingmoe" - if chkhsh == "0cb00ada81d42e4d3c5c1a74a67cfb4aa250a31de9c47b951a21f82a54c7e336": + if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": # ref: https://huggingface.co/inclusionAI/Ling-lite res = "bailing-bt2" if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": @@ -7824,6 +7824,111 @@ def prepare_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("BailingMoeV2ForCausalLM") +class BailingMoeV2Model(TextModel): + model_arch = gguf.MODEL_ARCH.BAILINGMOE_V2 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + if (rope_dim := hparams.get("head_dim")) is None: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + + self.gguf_writer.add_rope_dimension_count(rope_dim) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_weights_scale(1.0) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + _experts: list[dict[str, Tensor]] | None = None + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.endswith("query_key_value.weight"): + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + n_embd = self.hparams["hidden_size"] + if (head_dim := self.hparams.get("head_dim")) is None: + head_dim = n_embd // n_head + q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) + + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), self.permute(q, n_head, n_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), self.permute(k, n_head, n_kv_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) + ] + elif name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + tensors: list[tuple[str, Tensor]] = [] + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + + return tensors + + pre_tensor_name_mapping = { + 'attention.dense': 'self_attn.dense', + 'attention.key_layernorm': 'self_attn.key_layernorm', + 'attention.query_layernorm': 'self_attn.query_layernorm', + 'mlp.gate.expert_bias': 'mlp.gate.e_score_correction', + } + for k, v in pre_tensor_name_mapping.items(): + name = name.replace(k, v) + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + @ModelBase.register("ChameleonForConditionalGeneration") @ModelBase.register("ChameleonForCausalLM") # obsolete diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3de568b7d48d8..699e324c3f8b8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2501,10 +2501,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_EXP_PROBS_B, MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, @@ -2512,6 +2515,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE, ], MODEL_ARCH.DOTS1: [ MODEL_TENSOR.TOKEN_EMBD, From 1dee442c24e30bd15e2d2c144f6274b0b8e05053 Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 04:07:36 +0800 Subject: [PATCH 04/17] update set_gguf_parameters --- convert_hf_to_gguf.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f8294b8530ab0..4774abf667b66 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7834,9 +7834,8 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - if (rope_dim := hparams.get("head_dim")) is None: - rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + rope_dim = int(hparams['partial_rotary_factor'] * hparams['head_dim']) self.gguf_writer.add_rope_dimension_count(rope_dim) rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: @@ -7848,10 +7847,16 @@ def set_gguf_parameters(self): self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) - self.gguf_writer.add_expert_weights_scale(1.0) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + if hparams["score_function"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["score_function"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported score_function value: {hparams['score_function']}") _experts: list[dict[str, Tensor]] | None = None From 63a2f54e98278cf5a69d3c9d0f885e2f697e5ab5 Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 04:27:24 +0800 Subject: [PATCH 05/17] update llm tensor names --- src/llama-arch.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ed33f340c3898..313fba2b94d8a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1920,10 +1920,13 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, @@ -1932,6 +1935,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, { From 69177c7bf520cd79912c620e692c85b5162eb17e Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 05:05:40 +0800 Subject: [PATCH 06/17] update load tensors --- src/llama-model.cpp | 45 ++++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index accea90483c07..57d35c51efffa 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5339,28 +5339,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } } } break; case LLM_ARCH_DOTS1: From 3fe767697aeea2775115ed8a5bad13407394b2fc Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sat, 13 Sep 2025 05:34:39 +0800 Subject: [PATCH 07/17] update llm graph --- src/llama-model.cpp | 86 +++++++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 57d35c51efffa..4e25a7547f7d5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6118,15 +6118,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } - - if (arch == LLM_ARCH_BAILINGMOE_V2) { + if (arch == LLM_ARCH_BAILINGMOE || + arch == LLM_ARCH_BAILINGMOE_V2) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); @@ -16693,12 +16686,18 @@ struct llm_build_bailingmoe_v2 : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -16727,41 +16726,52 @@ struct llm_build_bailingmoe_v2 : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - ggml_tensor * moe_out = - build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - false, hparams.expert_weights_scale, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(moe_out, "ffn_moe_out", il); - - // FFN shared expert - { - ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "ffn_shexp", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); cb(cur, "ffn_out", il); - } + } else { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); - cur = ggml_add(ctx0, cur, ffn_inp); + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } - // input for next layer - inpL = cur; + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } } cur = inpL; From 94ec7dc841f21ba81e6a91185803ac61490b6845 Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sun, 14 Sep 2025 01:42:55 +0800 Subject: [PATCH 08/17] [fix] fix expert_bias convert hf to gguf for ling mini 2.0 --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4774abf667b66..df20b633d9201 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7917,7 +7917,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter 'attention.dense': 'self_attn.dense', 'attention.key_layernorm': 'self_attn.key_layernorm', 'attention.query_layernorm': 'self_attn.query_layernorm', - 'mlp.gate.expert_bias': 'mlp.gate.e_score_correction', + 'mlp.gate.expert_bias': 'mlp.gate.e_score_correction.bias', } for k, v in pre_tensor_name_mapping.items(): name = name.replace(k, v) From a2a2299ce7ef917e767ed4acab91d5f196a1d2aa Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sun, 14 Sep 2025 03:25:20 +0800 Subject: [PATCH 09/17] [fix] fix llm graph for ling mini 2.0 --- src/llama-model.cpp | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4e25a7547f7d5..d3e4c29dd2ff3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1808,6 +1808,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // Ling 2.0 use sigmoid gating func + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } switch (hparams.n_layer) { case 20: type = LLM_TYPE_16B; break; @@ -16634,6 +16639,9 @@ struct llm_build_bailingmoe : public llm_graph_context { struct llm_build_bailingmoe_v2 : public llm_graph_context { llm_build_bailingmoe_v2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + ggml_tensor * cur; ggml_tensor * inpL; @@ -16682,9 +16690,9 @@ struct llm_build_bailingmoe_v2 : public llm_graph_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -16710,7 +16718,7 @@ struct llm_build_bailingmoe_v2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16763,15 +16771,15 @@ struct llm_build_bailingmoe_v2 : public llm_graph_context { cur = ggml_add(ctx0, moe_out, ffn_shexp); cb(cur, "ffn_out", il); } + } - cur = ggml_add(ctx0, cur, ffn_inp); + cur = ggml_add(ctx0, cur, ffn_inp); - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); - // input for next layer - inpL = cur; - } + // input for next layer + inpL = cur; } cur = inpL; From a6b3ca8272d2c6860dad28443f8744629183bad3 Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sun, 14 Sep 2025 04:10:23 +0800 Subject: [PATCH 10/17] [feat] add chat template for ling 2.0 --- src/llama-chat.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 9d8e57eac1f69..985cb6de1df0c 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -63,6 +63,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, { "bailing", LLM_CHAT_TEMPLATE_BAILING }, + { "bailing-v2", LLM_CHAT_TEMPLATE_BAILING }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, From c72e3994c783fc935494ffcf6aeb287bac7a03bd Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sun, 14 Sep 2025 06:53:23 +0800 Subject: [PATCH 11/17] [feat] add chat template --- src/llama-chat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 985cb6de1df0c..e9bac0f74136e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -189,7 +189,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MEGREZ; } else if (tmpl_contains(" Ассистент:")) { return LLM_CHAT_TEMPLATE_YANDEX; - } else if (tmpl_contains("ASSISTANT") && tmpl_contains("'HUMAN'")) { + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("HUMAN")) { return LLM_CHAT_TEMPLATE_BAILING; } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { return LLM_CHAT_TEMPLATE_LLAMA4; From b359533f0dad227a4c3ebde3d003c0dde3421e9a Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sun, 14 Sep 2025 19:06:36 +0800 Subject: [PATCH 12/17] [fix] fix for eog token --- convert_hf_to_gguf.py | 12 ++---------- src/llama-vocab.cpp | 1 + 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index df20b633d9201..07796023776aa 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7860,14 +7860,6 @@ def set_gguf_parameters(self): _experts: list[dict[str, Tensor]] | None = None - @staticmethod - def permute(weights: Tensor, n_head: int, n_head_kv: int | None): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.endswith("query_key_value.weight"): n_head = self.hparams["num_attention_heads"] @@ -7878,8 +7870,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), self.permute(q, n_head, n_head)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), self.permute(k, n_head, n_kv_head)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) ] elif name.find("mlp.experts") != -1: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 6cab7397a3874..e103c5b8e39e8 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2327,6 +2327,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "_" || t.first == "<|end_of_text|>" || t.first == "" // smoldocling + || t.first == "<|role_end|>" // Ling v2 ) { special_eog_ids.insert(t.second); if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { From e078a635c76b548f6e661f60ea1317286113902d Mon Sep 17 00:00:00 2001 From: im0qianqian Date: Sun, 14 Sep 2025 20:00:10 +0800 Subject: [PATCH 13/17] [fix] skip mtp layer --- convert_hf_to_gguf.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 07796023776aa..92ac4674ec01b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7861,6 +7861,12 @@ def set_gguf_parameters(self): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + if name.endswith("query_key_value.weight"): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") From 09e3df4bb5035fb3b214adfa94d539ff515f0e54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=81=E6=84=9A?= Date: Tue, 16 Sep 2025 17:29:18 +0800 Subject: [PATCH 14/17] =?UTF-8?q?[fix]=20fix=20weight=20convert=20for=20Li?= =?UTF-8?q?ng=20mini=202.0=20w/=20half=20rotary=20=F0=9F=90=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- convert_hf_to_gguf.py | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 92ac4674ec01b..52ca3ecbda911 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7828,6 +7828,28 @@ def prepare_tensors(self): class BailingMoeV2Model(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE_V2 + @staticmethod + def permute( + weights: Tensor, n_head: int, n_head_kv: int | None, rope_dim: int | None + ): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + if rope_dim is None: + rope_dim = weights.shape[0] // n_head + weights_rope, weights_nope = weights.reshape( + n_head, weights.shape[0] // n_head, *weights.shape[1:] + ).split([rope_dim, weights.shape[0] // n_head - rope_dim], dim=1) + return torch.cat( + [ + weights_rope.reshape( + n_head, 2, rope_dim // 2, *weights_rope.shape[2:] + ) + .swapaxes(1, 2) + .reshape(weights_rope.shape), + weights_nope, + ], dim=1 + ).reshape(weights.shape) + def set_vocab(self): self._set_vocab_gpt2() @@ -7867,6 +7889,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if match and int(match.group(1)) >= block_count: return [] + rope_dim = int(self.hparams['partial_rotary_factor'] * self.hparams['head_dim']) if name.endswith("query_key_value.weight"): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -7876,10 +7899,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeV2Model.permute(q, n_head, n_head, rope_dim)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeV2Model.permute(k, n_head, n_kv_head, rope_dim)), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) ] + elif "attention.key_layernorm" in name or "attention.query_layernorm" in name: + mapping = { + "attention.key_layernorm": "self_attn.key_layernorm", + "attention.query_layernorm": "self_attn.query_layernorm", + } + for k, v in mapping.items(): + name = name.replace(k, v) + return [(self.map_tensor_name(name), BailingMoeV2Model.permute(data_torch, 1, 1, rope_dim))] elif name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None @@ -7912,10 +7943,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return tensors pre_tensor_name_mapping = { - 'attention.dense': 'self_attn.dense', - 'attention.key_layernorm': 'self_attn.key_layernorm', - 'attention.query_layernorm': 'self_attn.query_layernorm', - 'mlp.gate.expert_bias': 'mlp.gate.e_score_correction.bias', + "attention.dense": "self_attn.dense", + "mlp.gate.expert_bias": "mlp.gate.e_score_correction.bias", } for k, v in pre_tensor_name_mapping.items(): name = name.replace(k, v) From 1c6ec2a897dda6bf965c0aa09a972a129d8a0c4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=81=E6=84=9A?= Date: Wed, 17 Sep 2025 08:19:00 +0800 Subject: [PATCH 15/17] [fix] update vocab ref to Ling mini 2.0 --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 52ca3ecbda911..89f48dbfa8900 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -862,7 +862,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: # ref: https://huggingface.co/inclusionAI/Ling-lite res = "bailingmoe" if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": - # ref: https://huggingface.co/inclusionAI/Ling-lite + # ref: https://huggingface.co/inclusionAI/Ling-mini-2.0 res = "bailing-bt2" if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct From 3709961f4b805d488efc386b19483a22c0c51067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=81=E6=84=9A?= Date: Wed, 17 Sep 2025 14:42:06 +0800 Subject: [PATCH 16/17] [feat] add support for Ling Lite 2.0 --- src/llama-model.cpp | 1 + src/llama-model.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d3e4c29dd2ff3..11e42032397b2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1816,6 +1816,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 20: type = LLM_TYPE_16B; break; + case 32: type = LLM_TYPE_100B; break; default: type = LLM_TYPE_UNKNOWN; } } break; diff --git a/src/llama-model.h b/src/llama-model.h index 10b1767f27228..7ac6e01d6b821 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -80,6 +80,7 @@ enum llm_type { LLM_TYPE_40B, LLM_TYPE_65B, LLM_TYPE_70B, + LLM_TYPE_100B, LLM_TYPE_120B, LLM_TYPE_142B, LLM_TYPE_236B, From 48ddb75588a0090f3e42de7c575d3f19bf8a275e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=81=E6=84=9A?= Date: Thu, 18 Sep 2025 21:17:41 +0800 Subject: [PATCH 17/17] [fix] use NEOX, and remove permute & split in convert process --- convert_hf_to_gguf.py | 37 ++++--------------------------------- src/llama-model.cpp | 2 +- 2 files changed, 5 insertions(+), 34 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 89f48dbfa8900..0db59b1cd782c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7828,28 +7828,6 @@ def prepare_tensors(self): class BailingMoeV2Model(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE_V2 - @staticmethod - def permute( - weights: Tensor, n_head: int, n_head_kv: int | None, rope_dim: int | None - ): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - if rope_dim is None: - rope_dim = weights.shape[0] // n_head - weights_rope, weights_nope = weights.reshape( - n_head, weights.shape[0] // n_head, *weights.shape[1:] - ).split([rope_dim, weights.shape[0] // n_head - rope_dim], dim=1) - return torch.cat( - [ - weights_rope.reshape( - n_head, 2, rope_dim // 2, *weights_rope.shape[2:] - ) - .swapaxes(1, 2) - .reshape(weights_rope.shape), - weights_nope, - ], dim=1 - ).reshape(weights.shape) - def set_vocab(self): self._set_vocab_gpt2() @@ -7889,7 +7867,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if match and int(match.group(1)) >= block_count: return [] - rope_dim = int(self.hparams['partial_rotary_factor'] * self.hparams['head_dim']) if name.endswith("query_key_value.weight"): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -7899,18 +7876,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeV2Model.permute(q, n_head, n_head, rope_dim)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeV2Model.permute(k, n_head, n_kv_head, rope_dim)), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) ] - elif "attention.key_layernorm" in name or "attention.query_layernorm" in name: - mapping = { - "attention.key_layernorm": "self_attn.key_layernorm", - "attention.query_layernorm": "self_attn.query_layernorm", - } - for k, v in mapping.items(): - name = name.replace(k, v) - return [(self.map_tensor_name(name), BailingMoeV2Model.permute(data_torch, 1, 1, rope_dim))] elif name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None @@ -7945,6 +7914,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter pre_tensor_name_mapping = { "attention.dense": "self_attn.dense", "mlp.gate.expert_bias": "mlp.gate.e_score_correction.bias", + "attention.key_layernorm": "self_attn.key_layernorm", + "attention.query_layernorm": "self_attn.query_layernorm", } for k, v in pre_tensor_name_mapping.items(): name = name.replace(k, v) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 11e42032397b2..dcb970dcb25ef 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -19485,7 +19485,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: - case LLM_ARCH_BAILINGMOE_V2: case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: @@ -19539,6 +19538,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: + case LLM_ARCH_BAILINGMOE_V2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: