From 79a43b2026b09a84f8964ceff0048494fea04646 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 4 Sep 2025 12:35:56 +0200 Subject: [PATCH] llama : add support for EmbeddingGemma 300m This commit add support for the EmbeddingGemma 300m. This model supports sliding window attention (SWA) and a new swq_type is introduced to support symmetric SWA masking. This commit also extracts the code from the function llama_is_masked_swa in llama-impl.h, so that the logic can be shared by both llm_graph_input_attn_no_cache::set_input and llama_kv_cache::set_input_kq_mask. With this commit the EmbeddingGemma 300m model can be converted to to GGUF and used with llama.cpp. Once the model has been uploaded to HuggingFace it can be used like this: ```console ./build/bin/llama-cli -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0 ``` --- convert_hf_to_gguf.py | 9 ++ gguf-py/gguf/constants.py | 20 +++++ gguf-py/gguf/tensor_mapping.py | 14 +++ src/llama-arch.cpp | 22 +++++ src/llama-arch.h | 1 + src/llama-graph.cpp | 59 ++++++++++-- src/llama-graph.h | 8 ++ src/llama-hparams.cpp | 37 ++++++++ src/llama-hparams.h | 9 +- src/llama-kv-cache-iswa.cpp | 4 +- src/llama-kv-cache.cpp | 27 +----- src/llama-kv-cache.h | 3 - src/llama-memory-hybrid.cpp | 2 - src/llama-memory-hybrid.h | 1 - src/llama-model.cpp | 159 ++++++++++++++++++++++++++++++++- 15 files changed, 328 insertions(+), 47 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 908435d7e5f00..717b8a6595010 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5122,6 +5122,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] +@ModelBase.register("Gemma3TextModel") +class EmbeddingGemma(Gemma3Model): + model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self._try_set_pooling_type() + + @ModelBase.register("Gemma3ForConditionalGeneration") class Gemma3VisionModel(MmprojModel): def set_gguf_parameters(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6156d35c2ad67..c922bacf6eb97 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum): GEMMA2 = auto() GEMMA3 = auto() GEMMA3N = auto() + GEMMA_EMBEDDING = auto() STARCODER2 = auto() RWKV6 = auto() RWKV6QWEN2 = auto() @@ -674,6 +675,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.GEMMA3N: "gemma3n", + MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", @@ -1719,6 +1721,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.LAUREL_R, MODEL_TENSOR.LAUREL_POST_NORM, ], + MODEL_ARCH.GEMMA_EMBEDDING: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], MODEL_ARCH.STARCODER2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 497f48809f8a7..b0c3d65e95847 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -14,6 +14,7 @@ class TensorNameMap: "transformer.word_embeddings", # falcon "word_embeddings", # bloom "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid + "embed_tokens", # embeddinggemma "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -141,6 +142,7 @@ class TensorNameMap: "rwkv.blocks.{bid}.ln1", # rwkv6 "model.layers.{bid}.ln1", # rwkv7 "model.layers.{bid}.input_layernorm", # llama4 + "layers.{bid}.input_layernorm", # embeddinggemma "transformer_encoder.{bid}.attention_norm", # neobert "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada @@ -179,6 +181,7 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe + "layers.{bid}.self_attn.q_proj", # embeddinggemma "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert @@ -197,6 +200,7 @@ class TensorNameMap: # Attention key MODEL_TENSOR.ATTN_K: ( "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe + "layers.{bid}.self_attn.k_proj", # embeddinggemma "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert @@ -216,6 +220,7 @@ class TensorNameMap: # Attention value MODEL_TENSOR.ATTN_V: ( "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe + "layers.{bid}.self_attn.v_proj", # embeddinggemma "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.layer.{bid}.attention.v_lin", # distillbert @@ -239,6 +244,7 @@ class TensorNameMap: "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe + "layers.{bid}.self_attn.o_proj", # embeddinggemma "model.layers.{bid}.self_attn.out_proj", # lfm2 "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth @@ -277,6 +283,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_POST_NORM: ( "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "layers.{bid}.post_attention_layernorm", # embeddinggemma "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2 ), @@ -320,12 +327,14 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_PRE_NORM: ( "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 + "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma "model.layers.{bid}.pre_ff_layernorm.weight", ), # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 + "layers.{bid}.post_feedforward_layernorm", # embeddinggemma "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2 "model.layers.{bid}.feed_forward.up_proj", @@ -362,6 +371,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon "h.{bid}.mlp.dense_h_to_4h", # bloom "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 + "layers.{bid}.mlp.up_proj", # embeddinggemma "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.layer.{bid}.ffn.lin1", # distillbert @@ -421,6 +431,7 @@ class TensorNameMap: # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2 + "layers.{bid}.mlp.gate_proj", # embeddinggemma "layers.{bid}.feed_forward.w1", # llama-pth "transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.c_fc2", # jais @@ -461,6 +472,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "h.{bid}.mlp.dense_4h_to_h", # bloom "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 + "layers.{bid}.mlp.down_proj", # embeddinggemma "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert "transformer.layer.{bid}.ffn.lin2", # distillbert @@ -513,6 +525,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_layernorm", # persimmon "model.layers.{bid}.self_attn.query_layernorm", # hunyuan "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 + "layers.{bid}.self_attn.q_norm", # embeddinggemma "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm @@ -525,6 +538,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_layernorm", # persimmon "model.layers.{bid}.self_attn.key_layernorm", # hunyuan "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 + "layers.{bid}.self_attn.k_norm", # embeddinggemma "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d5c8477f4aa39..d92bf2dd8515f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -45,6 +45,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, @@ -1038,6 +1039,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, }, }, + { + LLM_ARCH_GEMMA_EMBEDDING, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 86c119692d8cc..41f377923f50f 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -49,6 +49,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 49ea5da7cb422..7ce2960eb311b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } +static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { + LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); + const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : + (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" : + (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : + (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); + LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); + + LLAMA_LOG_DEBUG(" "); + for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { + LLAMA_LOG_DEBUG("%2d", j); + } + LLAMA_LOG_DEBUG("\n"); + + for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { + LLAMA_LOG_DEBUG(" %2d ", i); + for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { + float val = data[i * n_kv + j]; + if (val == -INFINITY) { + LLAMA_LOG_DEBUG(" ∞"); + } else { + LLAMA_LOG_DEBUG(" 0"); + } + } + LLAMA_LOG_DEBUG("\n"); + } +} + void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; @@ -277,21 +307,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { const llama_seq_id s0 = ubatch->seq_id[i0][0]; - // TODO: reimplement this like in llama_kv_cache - if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) { - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; - } - break; + if (s0 != s1) { + continue; // skip different sequences } - } + if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { + continue; // skip future tokens for causal attention + } + + if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { + continue; // skip masked tokens for SWA + } + + // TODO: reimplement this like in llama_kv_cache_unified + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); + } else { + f = 0.0f; + } + } data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } } } + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } } void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 3c85333fde514..ca90fdf613f6d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -78,6 +78,11 @@ struct llm_graph_params; class llm_graph_input_i { public: + llm_graph_input_i() { + const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG"); + debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0; + } + virtual ~llm_graph_input_i() = default; virtual void set_input(const llama_ubatch * ubatch) = 0; @@ -90,6 +95,9 @@ class llm_graph_input_i { GGML_UNUSED(params); return false; } +protected: + // env: LLAMA_GRAPH_INPUT_DEBUG + int debug = 0; }; using llm_graph_input_ptr = std::unique_ptr; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 91636572da8b2..4b7a65362b9f0 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -1,6 +1,7 @@ #include "llama-hparams.h" #include "ggml.h" +#include void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { @@ -178,3 +179,39 @@ uint32_t llama_hparams::n_layer_kv() const { return res; } + +bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; +} diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 60415f0c202a4..06d1e51db0552 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -16,9 +16,10 @@ enum llama_expert_gating_func_type { }; enum llama_swa_type { - LLAMA_SWA_TYPE_NONE = 0, - LLAMA_SWA_TYPE_STANDARD = 1, - LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_NONE = 0, + LLAMA_SWA_TYPE_STANDARD = 1, + LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_SYMMETRIC = 3, }; struct llama_hparams_posnet { @@ -227,6 +228,8 @@ struct llama_hparams { // number of layers for which has_kv() returns true uint32_t n_layer_kv() const; + + bool is_masked_swa(llama_pos p0, llama_pos p1) const; }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index d7342914c6b7c..7c51a1981e05e 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( kv_base = std::make_unique( model, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, filter_base, reuse); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, filter_swa, reuse); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index f1c6918738684..1564faedf4eb4 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -27,11 +27,10 @@ llama_kv_cache::llama_kv_cache( uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, - llama_swa_type swa_type, const layer_filter_cb & filter, const layer_reuse_cb & reuse) : model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) { GGML_ASSERT(kv_size % n_pad == 0); @@ -1393,29 +1392,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co } bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - } - - return false; + return hparams.is_masked_swa(p0, p1); } void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 07d29bb8187e2..55ee355b22928 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -89,7 +89,6 @@ class llama_kv_cache : public llama_memory_i { uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, - llama_swa_type swa_type, const layer_filter_cb & filter, const layer_reuse_cb & reuse); @@ -212,8 +211,6 @@ class llama_kv_cache : public llama_memory_i { // env: LLAMA_KV_CACHE_DEBUG int debug = 0; - const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - std::vector ctxs; std::vector bufs; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index ba61ebaa885fe..38f92da11820d 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -17,7 +17,6 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t kv_size, uint32_t n_pad, uint32_t n_swa, - llama_swa_type swa_type, /* recurrent */ ggml_type type_r, ggml_type type_s, @@ -41,7 +40,6 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_pad, n_swa, - swa_type, filter_attn == nullptr ? [&](int32_t il) { return !hparams.is_recurrent(il); } : filter_attn, diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 11a3565178297..0eb63f5ef86d7 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -27,7 +27,6 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t kv_size, uint32_t n_pad, uint32_t n_swa, - llama_swa_type swa_type, /* recurrent */ ggml_type type_r, ggml_type type_s, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5e54ce25f1298..afc4cb48098d2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1142,6 +1142,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA_EMBEDDING: + { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + hparams.set_swa_pattern(6); + + hparams.causal_attn = false; // embeddings do not use causal attention + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3484,6 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA_EMBEDDING: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -11045,6 +11066,136 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } }; +struct llm_build_gemma_embedding_iswa : public llm_graph_context { + llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -18481,6 +18632,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: { @@ -18529,7 +18681,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_kv_size */ cparams.n_ctx, /* attn_n_pad */ padding, /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), @@ -18599,7 +18750,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, padding, hparams.n_swa, - hparams.swa_type, nullptr, nullptr); } @@ -18761,6 +18911,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GEMMA_EMBEDDING: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_STARCODER2: { llm = std::make_unique(*this, params); @@ -19161,6 +19315,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: