From ee7b595de3425b856be4c96df9f388225cbaf0b7 Mon Sep 17 00:00:00 2001 From: Lpzhan931 <3209311628@qq.com> Date: Sun, 2 Nov 2025 23:35:29 +0800 Subject: [PATCH 1/5] Model: add openPangu-Embedded --- convert_hf_to_gguf.py | 64 ++++++++++++++++++ gguf-py/gguf/constants.py | 20 ++++++ src/CMakeLists.txt | 1 + src/llama-arch.cpp | 18 +++++ src/llama-arch.h | 1 + src/llama-chat.cpp | 35 ++++++++++ src/llama-chat.h | 1 + src/llama-model.cpp | 58 ++++++++++++++++ src/llama-vocab.cpp | 14 ++++ src/models/models.h | 4 ++ src/models/pangu_embedded.cpp | 122 ++++++++++++++++++++++++++++++++++ 11 files changed, 338 insertions(+) create mode 100644 src/models/pangu_embedded.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c6f5ba6a04c54..d4eb6d2768bd4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1270,6 +1270,28 @@ def _set_vocab_llama_hf(self): special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_pangu_embedded(self): + tokens, scores, toktypes = self._create_vocab_sentencepiece() + + self.gguf_writer.add_tokenizer_model("pangu_embedded") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + tokenizer_config_file = self.dir_model / "tokenizer_config.json" + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "chat_template" in tokenizer_config_json: + self.gguf_writer.add_chat_template(tokenizer_config_json["chat_template"]) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_rwkv_world(self): assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file() vocab_size = self.hparams.get("vocab_size", 65536) @@ -7186,6 +7208,48 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("PanguEmbeddedForCausalLM") +class PanguEmbeddedModel(TextModel): + model_arch = gguf.MODEL_ARCH.PANGU_EMBED + + def set_vocab(self): + try: + self._set_vocab_pangu_embedded() + except FileNotFoundError: + print("pangu vocab set fail, fallback to sentencepiece!") + self._set_vocab_sentencepiece() + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + # PanguEmbedded's hparam loaded from config.json without head_dim + if (rope_dim := hparams.get("head_dim")) is None: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if (head_dim := hparams.get("head_dim")) is None: + if "hidden_size" in hparams and "num_attention_heads" in hparams: + head_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + + if head_dim is not None: + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid + n_head = self.find_hparam(["n_heads", "num_attention_heads"]) + n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"]) + return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("Dots1ForCausalLM") class Dots1Model(Qwen2MoeModel): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 77e3b0650ff0b..5c9d49cf14c2f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -426,6 +426,7 @@ class MODEL_ARCH(IntEnum): APERTUS = auto() COGVLM = auto() MINIMAXM2 = auto() + PANGU_EMBED = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -793,6 +794,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.APERTUS: "apertus", MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", + MODEL_ARCH.PANGU_EMBED: "pangu_embedded", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2958,6 +2960,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.VISEXP_UP, MODEL_TENSOR.VISEXP_DOWN, ], + MODEL_ARCH.PANGU_EMBED: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -3013,6 +3029,10 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE: [ MODEL_TENSOR.ROPE_FREQS, ], + MODEL_ARCH.PANGU_EMBED: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 832b58e315d09..9074eb3ac84af 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -89,6 +89,7 @@ add_library(llama models/mamba.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/pangu_embedded.cpp models/mpt.cpp models/nemotron-h.cpp models/nemotron.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7c7953b83dda8..03d87466cfb48 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -107,6 +107,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, + {LLM_ARCH_PANGU_EMBED, "pangu_embedded" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2377,6 +2378,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_PANGU_EMBED, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_COGVLM, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 3f893a2dc6916..a769dd1e85741 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -111,6 +111,7 @@ enum llm_arch { LLM_ARCH_APERTUS, LLM_ARCH_MINIMAX_M2, LLM_ARCH_COGVLM, + LLM_ARCH_PANGU_EMBED, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 0285006d73caa..9f92ff039c09b 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -73,6 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, + { "pangu_embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -213,6 +214,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_SEED_OSS; } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { return LLM_CHAT_TEMPLATE_GROK_2; + } else if (tmpl_contains("[unused9]") && tmpl_contains("[unused10]")) { + return LLM_CHAT_TEMPLATE_PANGU_EMBED; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -813,6 +816,38 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "Assistant:"; } + }else if (tmpl == LLM_CHAT_TEMPLATE_PANGU_EMBED) { + // [unused9]系统:xxx[unused10] + // [unused9]用户:xxx[unused10] + // [unused9]助手:xxx[unused10] + // ... + for (size_t i = 0; i < chat.size(); ++i) { + const auto & msg = chat[i]; + const std::string & role = msg->role; + const std::string & content = msg->content; + + if (i == 0 && role != "system") { + ss << "[unused9]系统:[unused10]"; + } + + if (role == "system") { + ss << "[unused9]系统:" << content << "[unused10]"; + } else if (role == "user") { + ss << "[unused9]用户:" << content << "[unused10]"; + } else if (role == "assistant") { + ss << "[unused9]助手:" << content << "[unused10]"; + } else if (role == "tool") { + ss << "[unused9]工具:" << content << "[unused10]"; + } else if (role == "function") { + ss << "[unused9]方法:" << content << "[unused10]"; + } else { + // unknown role + ss << "[unused9]" << role << ":" << content << "[unused10]"; + } + } + if (add_ass) { + ss << "[unused9]助手:"; + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index da1b7c47997ca..684efb4d67f45 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -53,6 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, + LLM_CHAT_TEMPLATE_PANGU_EMBED, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 896725466ce24..6e6733ee4405c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2177,6 +2177,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PANGU_EMBED: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 + case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -6263,6 +6272,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PANGU_EMBED: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // openPanguEmbedded-1B model's lm_head/output is 'tie_word_embeddings', the 7B model is not + if(type == LLM_TYPE_1B){ + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // weight tensors + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7260,6 +7313,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PANGU_EMBED: + { + llm = std::make_unique(*this, params); + }break; default: GGML_ABORT("fatal error"); } @@ -7479,6 +7536,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_APERTUS: case LLM_ARCH_MINIMAX_M2: case LLM_ARCH_COGVLM: + case LLM_ARCH_PANGU_EMBED: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 735c5d547f9e4..62feb5472a6cb 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1805,6 +1805,20 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "pangu_embedded") { + type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + special_bos_id = 1; + special_eos_id = 45892; + special_unk_id = 0; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 0; + special_mask_id = LLAMA_TOKEN_NULL; + + add_space_prefix = true; + add_bos = true; + add_eos = false; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } diff --git a/src/models/models.h b/src/models/models.h index af203343a4d71..b41e2a4e7db6f 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -317,6 +317,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_pangu_embedded : public llm_graph_context { + llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/pangu_embedded.cpp b/src/models/pangu_embedded.cpp new file mode 100644 index 0000000000000..fdd15dfe12cc9 --- /dev/null +++ b/src/models/pangu_embedded.cpp @@ -0,0 +1,122 @@ +#include "models.h" + + +llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (model.output_b != nullptr) { + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } \ No newline at end of file From 1c754baae2dcc5c456cc619c0c6aff967e1ce476 Mon Sep 17 00:00:00 2001 From: Lpzhan931 <3209311628@qq.com> Date: Mon, 3 Nov 2025 23:53:21 +0800 Subject: [PATCH 2/5] fixed according to reviewer's comments --- convert_hf_to_gguf.py | 47 +++---------- src/CMakeLists.txt | 2 +- src/llama-arch.cpp | 2 +- src/llama-chat.cpp | 5 +- src/llama-model.cpp | 14 ++-- src/llama-vocab.cpp | 14 ---- src/models/models.h | 8 +-- src/models/pangu-embedded.cpp | 121 +++++++++++++++++++++++++++++++++ src/models/pangu_embedded.cpp | 122 ---------------------------------- 9 files changed, 144 insertions(+), 191 deletions(-) create mode 100644 src/models/pangu-embedded.cpp delete mode 100644 src/models/pangu_embedded.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d4eb6d2768bd4..06403497e6ba5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1270,28 +1270,6 @@ def _set_vocab_llama_hf(self): special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab.add_to_gguf(self.gguf_writer) - def _set_vocab_pangu_embedded(self): - tokens, scores, toktypes = self._create_vocab_sentencepiece() - - self.gguf_writer.add_tokenizer_model("pangu_embedded") - self.gguf_writer.add_tokenizer_pre("default") - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - - tokenizer_config_file = self.dir_model / "tokenizer_config.json" - if tokenizer_config_file.is_file(): - with open(tokenizer_config_file, "r", encoding="utf-8") as f: - tokenizer_config_json = json.load(f) - if "chat_template" in tokenizer_config_json: - self.gguf_writer.add_chat_template(tokenizer_config_json["chat_template"]) - if "add_prefix_space" in tokenizer_config_json: - self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) - - special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) - special_vocab.add_to_gguf(self.gguf_writer) - - def _set_vocab_rwkv_world(self): assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file() vocab_size = self.hparams.get("vocab_size", 65536) @@ -7212,12 +7190,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): class PanguEmbeddedModel(TextModel): model_arch = gguf.MODEL_ARCH.PANGU_EMBED - def set_vocab(self): - try: - self._set_vocab_pangu_embedded() - except FileNotFoundError: - print("pangu vocab set fail, fallback to sentencepiece!") - self._set_vocab_sentencepiece() + def set_vocab(self): + self._set_vocab_sentencepiece() tokenizer_config_file = self.dir_model / 'tokenizer_config.json' if tokenizer_config_file.is_file(): @@ -7236,18 +7210,15 @@ def set_gguf_parameters(self): rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) - if (head_dim := hparams.get("head_dim")) is None: - if "hidden_size" in hparams and "num_attention_heads" in hparams: - head_dim = hparams["hidden_size"] // hparams["num_attention_heads"] - - if head_dim is not None: - self.gguf_writer.add_key_length(head_dim) - self.gguf_writer.add_value_length(head_dim) + if hparams.get("head_dim") is None: + self.gguf_writer.add_key_length(rope_dim) + self.gguf_writer.add_value_length(rope_dim) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid - n_head = self.find_hparam(["n_heads", "num_attention_heads"]) - n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"]) + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] return [(self.map_tensor_name(name), data_torch)] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9074eb3ac84af..630b2cddf67e8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -89,7 +89,6 @@ add_library(llama models/mamba.cpp models/minicpm3.cpp models/minimax-m2.cpp - models/pangu_embedded.cpp models/mpt.cpp models/nemotron-h.cpp models/nemotron.cpp @@ -100,6 +99,7 @@ add_library(llama models/openai-moe-iswa.cpp models/openelm.cpp models/orion.cpp + models/pangu-embedded.cpp models/phi2.cpp models/phi3.cpp models/plamo.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 03d87466cfb48..d5654767359ac 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -107,7 +107,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, - {LLM_ARCH_PANGU_EMBED, "pangu_embedded" }, + { LLM_ARCH_PANGU_EMBED, "pangu_embedded" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 9f92ff039c09b..41181b577ad95 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -214,7 +214,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_SEED_OSS; } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { return LLM_CHAT_TEMPLATE_GROK_2; - } else if (tmpl_contains("[unused9]") && tmpl_contains("[unused10]")) { + } else if (tmpl_contains("[unused9]") && tmpl_contains("message['content'] + '[unused10]'")) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; } return LLM_CHAT_TEMPLATE_UNKNOWN; @@ -840,9 +840,6 @@ int32_t llm_chat_apply_template( ss << "[unused9]工具:" << content << "[unused10]"; } else if (role == "function") { ss << "[unused9]方法:" << content << "[unused10]"; - } else { - // unknown role - ss << "[unused9]" << role << ":" << content << "[unused10]"; } } if (add_ass) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6e6733ee4405c..c66fadd7ff187 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6275,11 +6275,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_PANGU_EMBED: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // openPanguEmbedded-1B model's lm_head/output is 'tie_word_embeddings', the 7B model is not - if(type == LLM_TYPE_1B){ + // if output is NULL, init from the input tok embed + if(output == NULL){ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } @@ -6295,18 +6297,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { + } else { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } @@ -6314,7 +6315,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } - } break; default: throw std::runtime_error("unknown architecture"); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 62feb5472a6cb..735c5d547f9e4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1805,20 +1805,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "pangu_embedded") { - type = LLAMA_VOCAB_TYPE_SPM; - - // default special tokens - special_bos_id = 1; - special_eos_id = 45892; - special_unk_id = 0; - special_sep_id = LLAMA_TOKEN_NULL; - special_pad_id = 0; - special_mask_id = LLAMA_TOKEN_NULL; - - add_space_prefix = true; - add_bos = true; - add_eos = false; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } diff --git a/src/models/models.h b/src/models/models.h index b41e2a4e7db6f..2fffb382df2e5 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -317,10 +317,6 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_pangu_embedded : public llm_graph_context { - llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params); -}; - struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; @@ -365,6 +361,10 @@ struct llm_build_orion : public llm_graph_context { llm_build_orion(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_pangu_embedded : public llm_graph_context { + llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_phi2 : public llm_graph_context { llm_build_phi2(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/pangu-embedded.cpp b/src/models/pangu-embedded.cpp new file mode 100644 index 0000000000000..c4f8efd8749ab --- /dev/null +++ b/src/models/pangu-embedded.cpp @@ -0,0 +1,121 @@ +#include "models.h" + + +llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (model.output_b != nullptr) { + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/pangu_embedded.cpp b/src/models/pangu_embedded.cpp deleted file mode 100644 index fdd15dfe12cc9..0000000000000 --- a/src/models/pangu_embedded.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include "models.h" - - -llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self attention - { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - if (model.output_b != nullptr) { - cur = ggml_add(ctx0, cur, model.output_b); - } - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } \ No newline at end of file From 5becaad3cf4e6dfe45133834400f7251274cb794 Mon Sep 17 00:00:00 2001 From: Lpzhan931 <3209311628@qq.com> Date: Tue, 4 Nov 2025 11:19:37 +0800 Subject: [PATCH 3/5] fixed the chat template check condition --- gguf-py/gguf/constants.py | 2 +- src/llama-arch.cpp | 2 +- src/llama-chat.cpp | 4 ++-- src/llama-model.cpp | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5c9d49cf14c2f..6b4b6c5ab075d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -794,7 +794,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.APERTUS: "apertus", MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", - MODEL_ARCH.PANGU_EMBED: "pangu_embedded", + MODEL_ARCH.PANGU_EMBED: "pangu-embedded", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d5654767359ac..b7642b568dffb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -107,7 +107,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, - { LLM_ARCH_PANGU_EMBED, "pangu_embedded" }, + { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 41181b577ad95..bfa878e6a274b 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -73,7 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, - { "pangu_embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, + { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -214,7 +214,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_SEED_OSS; } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { return LLM_CHAT_TEMPLATE_GROK_2; - } else if (tmpl_contains("[unused9]") && tmpl_contains("message['content'] + '[unused10]'")) { + } else if (tmpl_contains(LU8("[unused9]系统:[unused10]")) && tmpl_contains("message['content'] + '[unused10]'")) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; } return LLM_CHAT_TEMPLATE_UNKNOWN; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c66fadd7ff187..18159ac5cfd91 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6281,7 +6281,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed - if(output == NULL){ + if (output == NULL){ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } From 09739dfd40a9d2e5df57a8e774e6129a3fe2ffaf Mon Sep 17 00:00:00 2001 From: Li Pengzhan <151381994+Lpzhan931@users.noreply.github.com> Date: Tue, 4 Nov 2025 20:51:53 +0800 Subject: [PATCH 4/5] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit change the chat-template check condition and some formatting issue Co-authored-by: Sigbjørn Skjæret --- src/llama-chat.cpp | 2 +- src/llama-model.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index bfa878e6a274b..743cf2aa32d6f 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -214,7 +214,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_SEED_OSS; } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { return LLM_CHAT_TEMPLATE_GROK_2; - } else if (tmpl_contains(LU8("[unused9]系统:[unused10]")) && tmpl_contains("message['content'] + '[unused10]'")) { + } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; } return LLM_CHAT_TEMPLATE_UNKNOWN; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 18159ac5cfd91..c94af286f2fea 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6281,7 +6281,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed - if (output == NULL){ + if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } From 95280c06477189081964d83c2920ba31cde3383a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 4 Nov 2025 22:16:02 +0100 Subject: [PATCH 5/5] whitespace cleanup --- convert_hf_to_gguf.py | 5 +++-- src/llama-chat.cpp | 2 +- src/llama-model.cpp | 2 +- src/models/pangu-embedded.cpp | 8 ++++---- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 06403497e6ba5..222f6ed6dc40f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7186,11 +7186,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): return super().modify_tensors(data_torch, name, bid) + @ModelBase.register("PanguEmbeddedForCausalLM") class PanguEmbeddedModel(TextModel): model_arch = gguf.MODEL_ARCH.PANGU_EMBED - - def set_vocab(self): + + def set_vocab(self): self._set_vocab_sentencepiece() tokenizer_config_file = self.dir_model / 'tokenizer_config.json' diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 743cf2aa32d6f..fc6a6223cfe2f 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -820,7 +820,7 @@ int32_t llm_chat_apply_template( // [unused9]系统:xxx[unused10] // [unused9]用户:xxx[unused10] // [unused9]助手:xxx[unused10] - // ... + // ... for (size_t i = 0; i < chat.size(); ++i) { const auto & msg = chat[i]; const std::string & role = msg->role; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c94af286f2fea..1987135ca6a2e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6279,7 +6279,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - + // if output is NULL, init from the input tok embed if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); diff --git a/src/models/pangu-embedded.cpp b/src/models/pangu-embedded.cpp index c4f8efd8749ab..664572a500146 100644 --- a/src/models/pangu-embedded.cpp +++ b/src/models/pangu-embedded.cpp @@ -53,8 +53,8 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -62,7 +62,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -75,7 +75,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // feed-forward network + // feed-forward network cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);