diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 222f6ed6dc40f..a4971c3825cd9 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -9807,6 +9807,86 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [] # skip other tensors


+@ModelBase.register("Glm4vForConditionalGeneration")
+class GLM4VModel(Glm4Model):
+    """Text model from [zai-org/GLM-4.1V-9B-Thinking](https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking)
+
+    ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""
+    model_arch = gguf.MODEL_ARCH.GLM4V
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # skip vision tensors for the text model
+        if name.startswith("model.visual."):
+            return []
+
+        # the Glm4Model class expects tensor names to start with 'model.',
+        # so we strip the 'language_model.' part
+        if name.startswith("model.language_model."):
+            name = name.replace("model.language_model.", "model.", 1)
+
+        # let the Glm4Model class handle the tensor mapping
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Glm4vMoeForConditionalGeneration")
+class GLM4VMoEModel(Glm4MoeModel):
+    """Text model from [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)
+
+    ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""
+    model_arch = gguf.MODEL_ARCH.GLM4V_MOE
+
+    def set_gguf_parameters(self):
+        # parameters specific to GLM-4.5V like rope_theta=10000 and context_length=65536
+        # should be correctly picked up from the text_config by the base classes
+        super().set_gguf_parameters()
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        # skip vision tensors for the text model
+        if name.startswith("model.visual."):
+            return []
+
+        # the Glm4MoeModel class expects tensor names to start with 'model.',
+        # so we strip the 'language_model.' part
+        if name.startswith("model.language_model."):
+            name = name.replace("model.language_model.", "model.", 1)
+
+        # let the Glm4MoeModel class handle the MoE logic and tensor mapping
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Glm4vMoeForConditionalGeneration", "Glm4vForConditionalGeneration")
+class GLM4VisionModel(MmprojModel):
+    """Multimodal projector from:
+    - [zai-org/GLM-4.1V-9B-Thinking](https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking)
+    - [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)
+
+    ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""
+    #
+    # TODO: conversion logic is still WIP!
+    #
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+        vparams = self.hparams_vision
+        ln_eps = vparams.get("layer_norm_eps", 1e-5)
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
+        self.gguf_writer.add_vision_attention_layernorm_eps(ln_eps)
+        self.gguf_writer.add_vision_use_silu(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("model.visual."):
+            yield self.map_tensor_name(name), data_torch
+        else:
+            return
+
+
 @ModelBase.register("CogVLMForCausalLM")
 class CogVLMVisionModel(MmprojModel):
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 6b4b6c5ab075d..d70a74db39f5f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -391,6 +391,8 @@ class MODEL_ARCH(IntEnum):
     CHATGLM      = auto()
     GLM4         = auto()
     GLM4_MOE     = auto()
+    GLM4V        = auto()
+    GLM4V_MOE    = auto()
     BITNET       = auto()
     T5           = auto()
     T5ENCODER    = auto()
@@ -437,6 +439,7 @@ class VISION_PROJECTOR_TYPE(IntEnum):
     GLM_EDGE = auto()
     MERGER   = auto()
     GEMMA3   = auto()
+    GLM4V    = auto()
     QWEN3VL  = auto()
     COGVLM   = auto()
@@ -683,10 +686,10 @@ class MODEL_TENSOR(IntEnum):
     A_MM_NORM_PRE      = auto()
     A_MM_NORM_MID      = auto()
     # nextn/mtp
-    NEXTN_EH_PROJ      = auto()
-    NEXTN_EMBED_TOKENS = auto()
-    NEXTN_ENORM        = auto()
-    NEXTN_HNORM        = auto()
+    NEXTN_EH_PROJ          = auto()
+    NEXTN_EMBED_TOKENS     = auto()
+    NEXTN_ENORM            = auto()
+    NEXTN_HNORM            = auto()
     NEXTN_SHARED_HEAD_HEAD = auto()
     NEXTN_SHARED_HEAD_NORM = auto()
@@ -757,7 +760,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK2:        "deepseek2",
     MODEL_ARCH.CHATGLM:          "chatglm",
     MODEL_ARCH.GLM4:             "glm4",
-    MODEL_ARCH.GLM4_MOE:         "glm4moe",
+    MODEL_ARCH.GLM4_MOE:         "glm4_moe",
+    MODEL_ARCH.GLM4V:            "glm4v",
+    MODEL_ARCH.GLM4V_MOE:        "glm4v_moe",
     MODEL_ARCH.BITNET:           "bitnet",
     MODEL_ARCH.T5:               "t5",
     MODEL_ARCH.T5ENCODER:        "t5encoder",
@@ -805,6 +810,7 @@ class MODEL_TENSOR(IntEnum):
     VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter",
     VISION_PROJECTOR_TYPE.MERGER:   "qwen2vl_merger",
     VISION_PROJECTOR_TYPE.GEMMA3:   "gemma3",
+    VISION_PROJECTOR_TYPE.GLM4V:    "glm4v",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -2365,6 +2371,46 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
         MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
     ],
+    MODEL_ARCH.GLM4V : [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
+    MODEL_ARCH.GLM4V_MOE: [  # same as GLM4_MOE without MTP tensors
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
@@ -3204,6 +3250,7 @@ class VisionProjectorType:
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
+    GLM4V = "glm4v"
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 630b2cddf67e8..0988f3114819d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -70,6 +70,8 @@ add_library(llama
             models/gemma3n-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp
+            models/glm4v-moe.cpp
+            models/glm4v.cpp
             models/gpt2.cpp
             models/gptneox.cpp
             models/granite-hybrid.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b7642b568dffb..8346af56eb569 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -67,7 +67,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
     { LLM_ARCH_CHATGLM,         "chatglm"      },
     { LLM_ARCH_GLM4,            "glm4"         },
-    { LLM_ARCH_GLM4_MOE,        "glm4moe"      },
+    { LLM_ARCH_GLM4_MOE,        "glm4_moe"     },
+    { LLM_ARCH_GLM4V,           "glm4v"        },
+    { LLM_ARCH_GLM4V_MOE,       "glm4v_moe"    },
     { LLM_ARCH_BITNET,          "bitnet"       },
     { LLM_ARCH_T5,              "t5"           },
     { LLM_ARCH_T5ENCODER,       "t5encoder"    },
@@ -1506,7 +1508,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         LLM_ARCH_GLM4,
         {
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" }, // does this really exist?
             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
             { LLM_TENSOR_OUTPUT,          "output" },
             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
@@ -1555,6 +1557,51 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
         },
     },
+    {
+        LLM_ARCH_GLM4V,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        },
+    },
+    {
+        LLM_ARCH_GLM4V_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index a769dd1e85741..df221c9bd2480 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -72,6 +72,8 @@ enum llm_arch {
     LLM_ARCH_CHATGLM,
     LLM_ARCH_GLM4,
     LLM_ARCH_GLM4_MOE,
+    LLM_ARCH_GLM4V,
+    LLM_ARCH_GLM4V_MOE,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
@@ -129,7 +131,6 @@ enum llm_kv {
     LLM_KV_GENERAL_LICENSE,
     LLM_KV_GENERAL_SOURCE_URL,
     LLM_KV_GENERAL_SOURCE_HF_REPO,

-    LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b199e94628fff..475e4f3e0ecdc 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -820,8 +820,13 @@ ggml_tensor * llm_graph_context::build_ffn(
     if (down) {
         cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
-            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+        if (
+            arch == LLM_ARCH_GLM4     ||
+            arch == LLM_ARCH_GLM4_MOE ||
+            arch == LLM_ARCH_GLM4V    ||
+            arch == LLM_ARCH_GLM4V_MOE
+        ) {
+            // GLM4 models seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
@@ -1618,8 +1623,13 @@ ggml_tensor * llm_graph_context::build_attn(
     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
-            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+        if (
+            arch == LLM_ARCH_GLM4     ||
+            arch == LLM_ARCH_GLM4_MOE ||
+            arch == LLM_ARCH_GLM4V    ||
+            arch == LLM_ARCH_GLM4V_MOE
+        ) {
+            // GLM4 models seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 829f1e3c14f82..c23f0d9fb565a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1659,6 +1659,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GLM4V:
+            {
+                // TODO
+            } break;
+        case LLM_ARCH_GLM4V_MOE:
+            {
+                // TODO
+            } break;
         case LLM_ARCH_BITNET:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5002,6 +5010,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_GLM4V:
+            {
+                // TODO
+            }
+            break;
+        case LLM_ARCH_GLM4V_MOE:
+            {
+                // TODO
+            }
+            break;
         case LLM_ARCH_NEMOTRON:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7138,6 +7156,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_glm4_moe>(*this, params);
             } break;
+        case LLM_ARCH_GLM4V:
+            {
+                llm = std::make_unique<llm_build_glm4v>(*this, params);
+            } break;
+        case LLM_ARCH_GLM4V_MOE:
+            {
+                llm = std::make_unique<llm_build_glm4v_moe>(*this, params);
+            } break;
         case LLM_ARCH_BITNET:
            {
                 llm = std::make_unique<llm_build_bitnet>(*this, params);
@@ -7531,6 +7557,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
             return LLAMA_ROPE_TYPE_NEOX;

         case LLM_ARCH_QWEN2VL:
+        case LLM_ARCH_GLM4V:
+        case LLM_ARCH_GLM4V_MOE:
             return LLAMA_ROPE_TYPE_MROPE;
         case LLM_ARCH_QWEN3VL:
         case LLM_ARCH_QWEN3VLMOE:
diff --git a/src/models/glm4v-moe.cpp b/src/models/glm4v-moe.cpp
new file mode 100644
index 0000000000000..09bf8abbb2db9
--- /dev/null
+++ b/src/models/glm4v-moe.cpp
@@ -0,0 +1,156 @@
+#include "models.h"
+
+llm_build_glm4v_moe::llm_build_glm4v_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    //
+    // TODO -- currently this is just copied from `llm_build_glm4_moe` -- still WIP
+    //
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    // Only process up to last layer (skip final NextN layer)
+    // Final layer tensors are loaded but not processed in forward pass
+    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // Pre-attention norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            if (model.layers[il].bq) {
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            }
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            if (model.layers[il].bk) {
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            }
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            if (model.layers[il].bv) {
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            }
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // Apply Q/K norm if available (GLM-4.5 355B variant)
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // Post-attention norm
+        cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "post_attn_norm", il);
+
+        // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
+        if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+            // Dense FFN layer
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // Process routed experts using existing MoE infrastructure
+            ggml_tensor * routed_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, hparams.expert_weights_norm,
+                    true, hparams.expert_weights_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(routed_out, "ffn_moe_out", il);
+
+            // Process shared expert on original input
+            ggml_tensor * shared_out = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp,   NULL, NULL,
+                    model.layers[il].ffn_gate_shexp, NULL, NULL,
+                    model.layers[il].ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(shared_out, "ffn_shexp_out", il);
+
+            // Final output: routed_output + shared_output
+            cur = ggml_add(ctx0, routed_out, shared_out);
+            cb(cur, "ffn_out", il);
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/src/models/glm4v.cpp b/src/models/glm4v.cpp
new file mode 100644
index 0000000000000..0ec311f3f0739
--- /dev/null
+++ b/src/models/glm4v.cpp
@@ -0,0 +1,128 @@
+#include "models.h"
+
+llm_build_glm4v::llm_build_glm4v(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    //
+    // TODO -- currently this is just copied from `llm_build_glm4` -- still WIP
+    //
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // Pre-attention norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            ggml_tensor * Qcur = nullptr;
+            ggml_tensor * Kcur = nullptr;
+            ggml_tensor * Vcur = nullptr;
+
+            if (model.layers[il].wqkv == nullptr) {
+                Qcur = build_lora_mm(model.layers[il].wq, cur);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                }
+                Kcur = build_lora_mm(model.layers[il].wk, cur);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                }
+                Vcur = build_lora_mm(model.layers[il].wv, cur);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                }
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            } else {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+                if (model.layers[il].bqkv) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+                }
+                Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1],
+                                    0 * sizeof(float) * (n_embd));
+                Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                                    cur->nb[1], 1 * sizeof(float) * (n_embd));
+                Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                                    cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
+            }
+
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                 ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                 ext_factor, attn_factor, beta_fast, beta_slow);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Post-attention norm (new!)
+        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "post_attn_norm", il);
+
+        // Add the input (residual connection after post-attention norm)
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FF
+        {
+            // Pre-MLP norm
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            // Post-MLP norm
+            cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "post_mlp_norm", il);
+        }
+
+        // Add residual connection after post-MLP norm
+        inpL = ggml_add(ctx0, cur, ffn_inp);
+        cb(inpL, "l_out", il);
+    }
+
+    // Final norm
+    cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output projection
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/src/models/models.h b/src/models/models.h
index 2fffb382df2e5..8a4bfa4e730d7 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -216,6 +216,14 @@ struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
 };

+struct llm_build_glm4v : public llm_graph_context {
+    llm_build_glm4v(const llama_model & model, const llm_graph_params & params);
+};
+
+struct llm_build_glm4v_moe : public llm_graph_context {
+    llm_build_glm4v_moe(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_gpt2 : public llm_graph_context {
     llm_build_gpt2(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 722b1a4948d6f..f61e1e0ee1f78 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -156,32 +156,34 @@ enum projector_type {
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_GLM4V,
     PROJECTOR_TYPE_UNKNOWN,
 };

 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
-    { PROJECTOR_TYPE_MLP,       "mlp" },
-    { PROJECTOR_TYPE_LDP,       "ldp" },
-    { PROJECTOR_TYPE_LDPV2,     "ldpv2"},
-    { PROJECTOR_TYPE_MINICPMV,  "resampler"},
-    { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
-    { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
-    { PROJECTOR_TYPE_QWEN25VL,  "qwen2.5vl_merger"},
-    { PROJECTOR_TYPE_QWEN3VL,   "qwen3vl_merger"},
-    { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
-    { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
-    { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
-    { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
-    { PROJECTOR_TYPE_INTERNVL,  "internvl"},
-    { PROJECTOR_TYPE_LLAMA4,    "llama4"},
-    { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
-    { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
-    { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
-    { PROJECTOR_TYPE_LFM2,      "lfm2"},
-    { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
-    { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
-    { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
-    { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_MLP,        "mlp"              },
+    { PROJECTOR_TYPE_LDP,        "ldp"              },
+    { PROJECTOR_TYPE_LDPV2,      "ldpv2"            },
+    { PROJECTOR_TYPE_MINICPMV,   "resampler"        },
+    { PROJECTOR_TYPE_GLM_EDGE,   "adapter"          },
+    { PROJECTOR_TYPE_QWEN2VL,    "qwen2vl_merger"   },
+    { PROJECTOR_TYPE_QWEN25VL,   "qwen2.5vl_merger" },
+    { PROJECTOR_TYPE_QWEN3VL,    "qwen3vl_merger"   },
+    { PROJECTOR_TYPE_GEMMA3,     "gemma3"           },
+    { PROJECTOR_TYPE_IDEFICS3,   "idefics3"         },
+    { PROJECTOR_TYPE_PIXTRAL,    "pixtral"          },
+    { PROJECTOR_TYPE_ULTRAVOX,   "ultravox"         },
+    { PROJECTOR_TYPE_INTERNVL,   "internvl"         },
+    { PROJECTOR_TYPE_LLAMA4,     "llama4"           },
+    { PROJECTOR_TYPE_QWEN2A,     "qwen2a"           },
+    { PROJECTOR_TYPE_QWEN25O,    "qwen2.5o"         },
+    { PROJECTOR_TYPE_VOXTRAL,    "voxtral"          },
+    { PROJECTOR_TYPE_LFM2,       "lfm2"             },
+    { PROJECTOR_TYPE_KIMIVL,     "kimivl"           },
+    { PROJECTOR_TYPE_LIGHTONOCR, "lightonocr"       },
+    { PROJECTOR_TYPE_COGVLM,     "cogvlm"           },
+    { PROJECTOR_TYPE_JANUS_PRO,  "janus_pro"        },
+    { PROJECTOR_TYPE_GLM4V,      "glm4v"            },
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index d1423b67f9865..c0b7bc5f0c36c 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -281,7 +281,7 @@ struct clip_model {
     // embeddings
     ggml_tensor * class_embedding     = nullptr;
     ggml_tensor * patch_embeddings_0  = nullptr;
-    ggml_tensor * patch_embeddings_1  = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
+    ggml_tensor * patch_embeddings_1  = nullptr; // second Conv2D kernel when we decouple Conv3D along temporal dimension (Qwen2VL, GLM4V)
     ggml_tensor * patch_bias          = nullptr;
     ggml_tensor * position_embeddings = nullptr;
@@ -400,6 +400,22 @@ struct clip_model {
     ggml_tensor * mm_boi = nullptr;
     ggml_tensor * mm_eoi = nullptr;

+    // GLM4V projection
+    ggml_tensor * mm_post_conv_ln_w = nullptr;
+    ggml_tensor * mm_post_conv_ln_b = nullptr;
+    ggml_tensor * mm_downsample_w   = nullptr;
+    ggml_tensor * mm_downsample_b   = nullptr;
+    ggml_tensor * mm_merger_proj_w  = nullptr;
+    ggml_tensor * mm_merger_proj_b  = nullptr;
+    ggml_tensor * mm_merger_norm_w  = nullptr;
+    ggml_tensor * mm_merger_norm_b  = nullptr;
+    ggml_tensor * mm_merger_gate_w  = nullptr;
+    ggml_tensor * mm_merger_gate_b  = nullptr;
+    ggml_tensor * mm_merger_up_w    = nullptr;
+    ggml_tensor * mm_merger_up_b    = nullptr;
+    ggml_tensor * mm_merger_down_w  = nullptr;
+    ggml_tensor * mm_merger_down_b  = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -1082,6 +1098,125 @@ struct clip_graph {
         return gf;
     }

+    ggml_cgraph * build_glm4v() {
+        GGML_ASSERT(model.patch_embeddings_0 != nullptr);
+        GGML_ASSERT(model.patch_embeddings_1 != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
+        // 2D RoPE input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp;
+
+        // patch embedding
+        //  - this is similar to Qwen2VL's handling of Conv3d for video/image inputs
+        //  - for single images, the input is duplicated along the temporal axis
+        //
+        // ref: `class Glm4vVisionPatchEmbed(Qwen2_5_VisionPatchEmbed):`
+
+        inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        if (model.patch_embeddings_1) {
+            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_add(ctx0, inp, inp_1);
+        }
+
+        const int batch_size = 1;
+        inp = ggml_permute(ctx0, inp, 1, 2, 0, 3);  // [w, h, c, b] -> [c, w, h, b]
+        inp = ggml_cont_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        inp = ggml_reshape_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_3d(ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size);
+        cb(inp, "patch_embed", -1);
+
+        // post-convolution layernorm
+        //
+        // ref: `self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)`
+        inp = build_norm(inp, model.mm_post_conv_ln_w, model.mm_post_conv_ln_b, NORM_TYPE_RMS, eps, -1);
+        cb(inp, "post_conv_ln", -1);
+
+        // absolute position embeddings (interpolated)
+        //
+        // ref: self.embeddings
+        ggml_tensor * learned_pos_embd = resize_position_embeddings();
+        inp = ggml_add(ctx0, inp, learned_pos_embd);
+        cb(inp, "abs_pos_embed", -1);
+
+        // RoPE to be applied inside ViT blocks
+        //
+        // ref: self.rotary_pos_emb
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+
+        // ViT blocks
+        ggml_tensor * cur = build_vit(
+            inp, n_patches,
+            NORM_TYPE_RMS,
+            FFN_SILU,  // hidden_act is "silu"
+            nullptr,   // absolute embeddings already added
+            add_pos);
+
+        // post-ViT layernorm
+        cur = build_norm(cur, model.post_ln_w, model.post_ln_b, NORM_TYPE_RMS, eps, -1);
+        cb(cur, "post_vit_ln", -1);
+
+        // reshape and permute to prepare for conv2d
+        const int merge_size = model.hparams.n_merge; // WIP: is this the correct value to use?
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 1, 2, 0, 3); // [C, W, H, B] -> [W, H, C, B] for ggml
+        cb(cur, "pre_downsample_permute", -1);
+
+        // downsampling conv2d
+        cur = ggml_conv_2d(ctx0, model.mm_downsample_w, cur, merge_size, merge_size, 0, 0, 1, 1);
+        cb(cur, "downsample_conv", -1);
+
+        // reshape to [tokens, features]
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+        cb(cur, "post_downsample_reshape", -1);
+
+        // patch merger FFN
+        //
+        // ref: `class Glm4vVisionPatchMerger(nn.Module):`
+        {
+            // input projection
+            cur = ggml_mul_mat(ctx0, model.mm_merger_proj_w, cur);
+
+            // apply norm + GELU
+            cur = build_norm(cur, model.mm_merger_norm_w, model.mm_merger_norm_b, NORM_TYPE_NORMAL, 1e-5f, -1);
+            cur = ggml_gelu(ctx0, cur);
+            ggml_tensor * ffn_input = cur;
+            cb(cur, "merger_ffn_inp", -1);
+
+            // gate projection
+            ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_merger_gate_w, ffn_input);
+            cb(gate, "merger_gate", -1);
+
+            // up projection
+            ggml_tensor * up = ggml_mul_mat(ctx0, model.mm_merger_up_w, ffn_input);
+            cb(up, "merger_up", -1);
+
+            // activation + down projection
+            cur = ggml_silu(ctx0, gate);
+            cur = ggml_mul(ctx0, cur, up);
+            cur = ggml_mul_mat(ctx0, model.mm_merger_down_w, cur);
+            cb(cur, "merger_ffn_out", -1);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     ggml_cgraph * build_minicpmv() {
         GGML_ASSERT(model.class_embedding == nullptr);
         const int n_pos = n_patches;
@@ -2516,13 +2651,17 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_kimivl();
             } break;
+        case PROJECTOR_TYPE_COGVLM:
+            {
+                res = graph.build_cogvlm();
+            } break;
         case PROJECTOR_TYPE_JANUS_PRO:
             {
                 res = graph.build_siglip();
             } break;
-        case PROJECTOR_TYPE_COGVLM:
+        case PROJECTOR_TYPE_GLM4V:
             {
-                res = graph.build_cogvlm();
+                res = graph.build_glm4v();
             } break;
         default:
             {