@@ -208,6 +208,7 @@ enum llm_arch {
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
     LLM_ARCH_MINICPM,
+    LLM_ARCH_GEMMA,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -234,6 +235,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_ORION,           "orion"      },
     { LLM_ARCH_INTERNLM2,       "internlm2"  },
     { LLM_ARCH_MINICPM,         "minicpm"    },
+    { LLM_ARCH_GEMMA,           "gemma"      },
 };
 
 enum llm_kv {
@@ -760,6 +762,22 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
         },
     },
+    {
+        LLM_ARCH_GEMMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -3243,6 +3261,16 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 18: model.type = e_model::MODEL_2B; break;
+                    case 28: model.type = e_model::MODEL_7B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -4360,6 +4388,37 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_GEMMA:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
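+                    // note: no separate lm_head tensor is created - the output projection reuses tok_embd (see build_gemma)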
+
+                    const int64_t n_ff          = hparams.n_ff;
+                    const int64_t n_embd_head_k = hparams.n_embd_head_k;
+                    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+                    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
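+                        // Q and O are sized with n_embd_head_k * n_head, which for Gemma is not necessarily equal to n_embd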
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7366,6 +7425,113 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_gemma() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        cb(inpL, "inp_embd", -1);
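+        // Gemma scales the input embeddings by sqrt(n_embd)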
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
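+        // (the shift must use the same NEOX rope type that is applied to Q and K below)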
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_custom(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos,
+                        n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
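+                // the attention scale (1/sqrt(n_embd_head_k)) is applied to Q here, so llm_build_kv below is called with kq_scale = 1.0f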
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = ggml_rope_custom(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos,
+                        n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
+            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = llm_build_norm(ctx0, sa_out, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up, NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
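+        // (the output projection is tied to the token embeddings, so tok_embd is reused here)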
+        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph(
@@ -7474,6 +7640,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_minicpm();
             } break;
+        case LLM_ARCH_GEMMA:
+            {
+                result = llm.build_gemma();
+            } break;
         default:
             GGML_ASSERT(false);
     }