6 changes: 6 additions & 0 deletions convert-hf-to-gguf.py
@@ -2363,6 +2363,12 @@ def set_gguf_parameters(self):
self.gguf_writer.add_key_length(hparams["head_dim"])
self.gguf_writer.add_value_length(hparams["head_dim"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_attn_logit_softcapping(
self.hparams["attn_logit_softcapping"]
)
self.gguf_writer.add_final_logit_softcapping(
self.hparams["final_logit_softcapping"]
)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid  # unused
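For context, the two values the converter now forwards can be checked on a converted file with the gguf-py reader. A minimal sketch, assuming a Gemma-2 conversion (the file name is illustrative) and the ReaderField parts/data idiom from gguf-py:

    from gguf import GGUFReader

    reader = GGUFReader("gemma-2-9b-it.gguf")  # file name is illustrative
    for key in ("gemma2.attn_logit_softcapping", "gemma2.final_logit_softcapping"):
        field = reader.get_field(key)
        if field is not None:
            # float32 KVs keep their value in the last data part
            print(key, "=", float(field.parts[field.data[-1]][0]))
    # expected: 50.0 for the attention cap, 30.0 for the final cap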
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -50,6 +50,8 @@ class LLM:
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -516,6 +516,12 @@ def add_clamp_kqv(self, value: float) -> None:
def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

def add_attn_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

def add_final_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

def add_expert_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

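Usage of the two new methods mirrors the existing add_logit_scale. A short sketch (the file name and flow are illustrative; the arch string selects the "{arch}." key prefix):

    from gguf import GGUFWriter

    w = GGUFWriter("demo.gguf", "gemma2")
    w.add_attn_logit_softcapping(50.0)   # writes "gemma2.attn_logit_softcapping"
    w.add_final_logit_softcapping(30.0)  # writes "gemma2.final_logit_softcapping"
    w.write_header_to_file()
    w.write_kv_data_to_file()
    w.close()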
35 changes: 32 additions & 3 deletions src/llama.cpp
@@ -302,6 +302,8 @@ enum llm_kv {
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2099,6 +2103,9 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;

float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;

float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;

-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;

enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
case LLM_ARCH_GEMMA2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
hparams.attn_soft_cap = true;

switch (hparams.n_layer) {
case 42: model.type = e_model::MODEL_9B; break;
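The trailing false marks both keys as optional, so a GGUF converted before this change still loads and simply keeps the defaults from the hparams struct. The fallback semantics, sketched in Python with a plain dict standing in for the GGUF KV store:

    metadata = {}  # a pre-change conversion: neither soft-capping key present
    f_attn_logit_softcapping  = metadata.get("gemma2.attn_logit_softcapping", 50.0)
    f_final_logit_softcapping = metadata.get("gemma2.final_logit_softcapping", 30.0)
    print(f_attn_logit_softcapping, f_final_logit_softcapping)  # 50.0 30.0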
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
kq = ggml_scale(ctx, kq, 30);
}

if (hparams.attn_soft_cap) {
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
kq = ggml_tanh(ctx, kq);
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
}

kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);

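Numerically, the three ggml ops above compute softcap(x) = c * tanh(x / c): near zero it is almost the identity, and it saturates smoothly at +/- c. A standalone sketch of the behavior with the attention cap of 50:

    import math

    def softcap(x: float, cap: float) -> float:
        # cap * tanh(x / cap): ~identity near 0, saturating at +/- cap
        return cap * math.tanh(x / cap)

    for x in (1.0, 25.0, 100.0, 1000.0):
        print(f"softcap({x:g}, 50) = {softcap(x, 50.0):.2f}")
    # softcap(1, 50) = 1.00, softcap(25, 50) = 23.11,
    # softcap(100, 50) = 48.20, softcap(1000, 50) = 50.00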
@@ -11039,7 +11056,7 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);

-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
cb(Qcur, "Qcur_scaled", il);

Kcur = ggml_rope_ext(
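The scale change matters because Gemma-2 9B is a model where head_dim differs from n_embd / n_head: the reference implementation scales queries by 1/sqrt(query_pre_attn_scalar), which for the 9B model equals n_embd / n_head = 224 rather than head_dim = 256. A small sketch comparing the two, using the published Gemma-2 9B dimensions:

    import math

    n_embd, n_head, head_dim = 3584, 16, 256   # Gemma-2 9B
    print(1.0 / math.sqrt(head_dim))           # old scale: 1/sqrt(256) = 0.0625
    print(1.0 / math.sqrt(n_embd / n_head))    # new scale: 1/sqrt(224) ~ 0.0668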
@@ -11106,6 +11123,12 @@ struct llm_build_context {

// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);

// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
@eran-medan commented on Jul 2, 2024:

Total nitpick that probably should be ignored. I came here out of curiosity, and I know this is merged by now, so I have no real place to comment. But isn't this the same logic as lines 7594-7596?

While I'm a proponent of the "rule of three", I think there's merit in extracting it into something like a separate apply_softcap method, for educational purposes at least: it gives the opportunity to add docs explaining what it does (single responsibility principle and all that), and I know for sure that if I had to fix a bug in it, I'd fix it in one place and forget to update the other.

cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);

cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);
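For what the comment above proposes, a sketch of the extraction in Python/NumPy rather than C++/ggml (the helper name apply_softcap is the reviewer's suggestion; everything else here is illustrative):

    import numpy as np

    def apply_softcap(logits: np.ndarray, cap: float) -> np.ndarray:
        """Soft-cap logits into (-cap, cap) via cap * tanh(logits / cap)."""
        return cap * np.tanh(logits / cap)

    kq  = apply_softcap(np.array([10.0, 200.0]), cap=50.0)  # attention site
    cur = apply_softcap(np.array([5.0,  90.0]),  cap=30.0)  # final-logit site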
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}

if (params.flash_attn && model->hparams.attn_soft_cap) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
params.flash_attn = false;
    }

if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;