@@ -326,6 +326,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -416,6 +418,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,             "%s.pooling_type"             },
     { LLM_KV_LOGIT_SCALE,              "%s.logit_scale"              },
     { LLM_KV_DECODER_START_TOKEN_ID,   "%s.decoder_start_token_id"   },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,   "%s.attn_logit_softcapping"   },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,  "%s.final_logit_softcapping"  },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,     "%s.attention.head_count"     },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,  "%s.attention.head_count_kv"  },
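
The "%s" in these key templates is substituted with the architecture name when the metadata is looked up, so for Gemma 2 the GGUF keys end up as gemma2.attn_logit_softcapping and gemma2.final_logit_softcapping (assuming the usual LLM_ARCH_GEMMA2 -> "gemma2" name mapping). A minimal standalone sketch of that substitution, outside of llama.cpp:

#include <cstdio>

int main() {
    // Assumed architecture name string; llama.cpp maps LLM_ARCH_GEMMA2 to "gemma2".
    const char * arch = "gemma2";
    char key[128];

    std::snprintf(key, sizeof(key), "%s.attn_logit_softcapping", arch);
    std::printf("%s\n", key);   // gemma2.attn_logit_softcapping

    std::snprintf(key, sizeof(key), "%s.final_logit_softcapping", arch);
    std::printf("%s\n", key);   // gemma2.final_logit_softcapping
    return 0;
}
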
@@ -2127,6 +2131,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float f_attn_logit_softcapping  = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
@@ -2143,8 +2150,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;
 
-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -4822,6 +4830,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
@@ -7737,6 +7748,12 @@ static struct ggml_tensor * llm_build_kqv(
         kq = ggml_scale(ctx, kq, 30);
     }
 
+    if (hparams.attn_soft_cap) {
+        kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+        kq = ggml_tanh(ctx, kq);
+        kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+    }
+
     kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     cb(kq, "kq_soft_max_ext", il);
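
The three added ops implement attention logit soft-capping, cap * tanh(logits / cap): the scores are divided by the cap, squashed through tanh, and rescaled, so every pre-softmax logit is smoothly bounded to (-cap, cap), with cap defaulting to the f_attn_logit_softcapping value of 50 above. A standalone sketch of the same arithmetic (plain C++, not the ggml graph):

#include <cmath>
#include <cstdio>
#include <initializer_list>

// cap * tanh(x / cap): near-identity for |x| << cap, saturates smoothly at +/-cap.
static float soft_cap(float x, float cap) {
    return cap * std::tanh(x / cap);
}

int main() {
    const float cap = 50.0f;   // default f_attn_logit_softcapping from the hparams hunk
    for (float x : {1.0f, 10.0f, 50.0f, 200.0f}) {
        std::printf("x = %7.1f  ->  soft_cap(x) = %7.3f\n", x, soft_cap(x, cap));
    }
    return 0;
}
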
@@ -11197,7 +11214,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
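
This hunk changes the query scale from the head dimension to n_embd / n_head. For Gemma 2 these are not the same number: taking the 9B's published shapes (hidden size 3584, 16 heads, head dim 256; values from the model config, not from this diff), the divisor changes from sqrt(256) to sqrt(224), which matches the model's query_pre_attn_scalar. A quick arithmetic check under those assumed dimensions:

#include <cmath>
#include <cstdio>

int main() {
    // Assumed Gemma 2 9B dimensions (from the published config, not this diff).
    const int n_embd        = 3584;
    const int n_head        = 16;
    const int n_embd_head_k = 256;

    const float old_scale = 1.0f / std::sqrt((float) n_embd_head_k);     // 1/sqrt(256) = 0.0625
    const float new_scale = 1.0f / std::sqrt((float)(n_embd / n_head));  // 1/sqrt(224) ~= 0.0668

    std::printf("old scale (head dim)      : %f\n", old_scale);
    std::printf("new scale (n_embd/n_head) : %f\n", new_scale);
    return 0;
}
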
@@ -11264,6 +11281,12 @@ struct llm_build_context {
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
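
The same scale-tanh-rescale pattern is applied to the lm_head output, this time with f_final_logit_softcapping (default 30), and unconditionally in the Gemma 2 graph rather than behind the attn_soft_cap flag. Since tanh is strictly monotonic, the argmax token is unchanged; only extreme logits are compressed before sampling. A sketch of the effect, assuming the default cap of 30 and made-up logit values:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const float cap = 30.0f;                            // default f_final_logit_softcapping
    const std::vector<float> logits = { -4.0f, 12.5f, 45.0f, 90.0f };

    for (float x : logits) {
        const float capped = cap * std::tanh(x / cap);  // bounded to (-30, 30)
        std::printf("%8.2f -> %8.2f\n", x, capped);
    }
    return 0;
}
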
@@ -20022,6 +20045,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;