@@ -302,6 +302,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,             "%s.pooling_type"            },
     { LLM_KV_LOGIT_SCALE,              "%s.logit_scale"             },
     { LLM_KV_DECODER_START_TOKEN_ID,   "%s.decoder_start_token_id"  },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,   "%s.attn_logit_softcapping"  },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,  "%s.final_logit_softcapping" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,     "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,  "%s.attention.head_count_kv" },
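
Note: the "%s" in these format strings is filled in with the architecture name when GGUF metadata is read, so for Gemma 2 models the new keys should surface as gemma2.attn_logit_softcapping and gemma2.final_logit_softcapping (assuming the usual arch-prefix convention; the exact key names depend on what the conversion script writes).
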
@@ -2099,6 +2103,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float f_attn_logit_softcapping  = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;
 
-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -4766,6 +4774,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,  hparams.f_attn_logit_softcapping,  false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
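
The trailing false arguments appear to mark both keys as optional, so GGUF files converted before these keys existed simply fall back to the 50.0f / 30.0f defaults declared in llama_hparams above. attn_soft_cap is switched on only for LLM_ARCH_GEMMA2, which keeps the new branch in llm_build_kqv a no-op for every other architecture.
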
@@ -7655,6 +7666,12 @@ static struct ggml_tensor * llm_build_kqv(
         kq = ggml_scale(ctx, kq, 30);
     }
 
+    if (hparams.attn_soft_cap) {
+        kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+        kq = ggml_tanh(ctx, kq);
+        kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+    }
+
     kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     cb(kq, "kq_soft_max_ext", il);
 
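
The three ggml calls above implement attention logit soft-capping: each raw score is squashed into (-cap, +cap) via x -> cap * tanh(x / cap) before the masked softmax, expressed in the graph as scale, tanh, scale because ggml operates on whole tensors. A minimal scalar sketch of the same transform, runnable on its own (the softcap helper and the sample scores are illustrative, not part of the patch):

    #include <cmath>
    #include <cstdio>

    // Soft-cap one score: smooth, monotonic, bounded to (-cap, +cap).
    static float softcap(float x, float cap) {
        return cap * std::tanh(x / cap);
    }

    int main() {
        const float cap = 50.0f;  // default f_attn_logit_softcapping
        const float scores[] = { -120.0f, -5.0f, 0.0f, 5.0f, 120.0f };
        for (float s : scores) {
            // Small scores pass through almost unchanged; large ones saturate near +/-cap.
            std::printf("%8.1f -> %8.3f\n", s, softcap(s, cap));
        }
        return 0;
    }

Values near zero are left essentially untouched (tanh(x/50) is approximately x/50 for small x), while extreme scores saturate at roughly +/-50, which is the point of the cap.
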
@@ -11115,7 +11132,7 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
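
The query pre-scaling divisor also changes from the key head dimension (n_embd_head_k) to n_embd / n_head. For most architectures the two coincide, but Gemma 2 uses a head dimension that is not simply the hidden size divided by the head count, so the two scales differ. As a purely illustrative example with assumed shapes (not read from the patch): with n_embd = 4608, n_head = 32 and n_embd_head_k = 128, the scale moves from 1/sqrt(128), about 0.0884, to 1/sqrt(4608/32) = 1/sqrt(144), about 0.0833.
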
@@ -11182,6 +11199,12 @@ struct llm_build_context {
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
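
Final logit soft-capping reuses the same scale / tanh / scale pattern with f_final_logit_softcapping (default 30.0f), so every output logit is bounded to roughly (-30, 30) before sampling; in terms of the softcap sketch above, each element becomes softcap(logit, hparams.f_final_logit_softcapping). Unlike the attention path there is no flag guarding it, presumably because this block sits inside the Gemma 2 graph builder and never runs for other models.
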
@@ -19687,6 +19710,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
        params.flash_attn = false;
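
The new check mirrors the existing flash-attention guards: the fused flash-attention path computes softmax(QK^T)V in a single kernel, leaving no place to apply the tanh cap to the raw scores, so soft-capped models are forced back onto the regular KQ path where the capping added in llm_build_kqv can run. (That is the apparent rationale; the warning itself only states the incompatibility.)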