From 7df7530b8fa1b45af84e8151fa0027fbc9250486 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 30 Jun 2024 19:26:13 +0200
Subject: [PATCH 1/9] gemma2: add sliding window mask

---
 src/llama.cpp | 36 ++++++++++++++++++++++++++++++++++--
 1 file changed, 34 insertions(+), 2 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 2a4d73856fcd9..8e4e3137e5e41 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -287,6 +287,7 @@ enum llm_kv {
     LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
+    LLM_KV_CONTEXT_LENGTH_SWA,
     LLM_KV_EMBEDDING_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
@@ -379,6 +380,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
     { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
+    { LLM_KV_CONTEXT_LENGTH_SWA, "%s.context_length_swa" },
     { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
     { LLM_KV_BLOCK_COUNT, "%s.block_count" },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -2079,7 +2081,8 @@ struct llama_hparams {
     bool use_par_res;

     uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx_train;    // context size the model was trained on
+    int32_t  n_ctx_swa = -1; // context size for sliding window attention (SWA)
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2661,6 +2664,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]

+    // KQ mask per layer, used by sliding window attention (gemma 2)
+    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+
     // control vectors
     struct llama_control_vector cvec;
 };
@@ -4709,6 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.n_ctx_swa = 4096; // default value
+                ml.get_key(LLM_KV_CONTEXT_LENGTH_SWA, hparams.n_ctx_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -11029,9 +11037,16 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_SWA = build_inp_KQ_mask();
+        lctx.inp_KQ_mask_l.clear();

         for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
+            lctx.inp_KQ_mask_l.push_back(KQ_mask);
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -12671,6 +12686,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

             float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data_swa = nullptr;
+
+            if (lctx.model.arch == LLM_ARCH_GEMMA2) {
+                GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+                GGML_ASSERT(hparams.n_ctx_swa > 0);
+                data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
+                data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+            }

             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
@@ -12692,6 +12715,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa && f != -INFINITY) {
+                            const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
+                            if (pos - lctx.kv_self.cells[i].pos > n_keep) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
                     }
                 }

From ab2c3de9b308b80815ba5e5b9f459f56034874e2 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 30 Jun 2024 20:18:53 +0200
Subject: [PATCH 2/9] fix data_swa uninitialized

---
 src/llama.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 8e4e3137e5e41..6838a6fc7d712 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
             float * data_swa = nullptr;
+            const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;

             if (lctx.model.arch == LLM_ARCH_GEMMA2) {
                 GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
                 GGML_ASSERT(hparams.n_ctx_swa > 0);
                 data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
                 data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+                // because layer masks are alternate for gemma 2, we only need to take first 2 layers
             }

             // For causal attention, use only the previous KV cells
@@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;

                         // may need to cut off old tokens for sliding window
-                        if (data_swa && f != -INFINITY) {
-                            const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
-                            if (pos - lctx.kv_self.cells[i].pos > n_keep) {
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
                                 f = -INFINITY;
                             }
                             data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;

From 46b56e67685263f48d2824f2fafc5c8ea136a9e1 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 30 Jun 2024 22:27:47 +0200
Subject: [PATCH 3/9] better naming

---
 convert-hf-to-gguf.py       |  1 +
 gguf-py/gguf/constants.py   |  1 +
 gguf-py/gguf/gguf_writer.py |  3 +++
 src/llama.cpp               | 14 +++++++-------
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 3ef2f69e7c0df..27fc9eea6d69b 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2369,6 +2369,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 9bfa891d5dc52..e87c58266158a 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -66,6 +66,7 @@ class Attention:
         Q_LORA_RANK = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
+        SLIDING_WINDOW = "{arch}.attention.sliding_window"

     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"

diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 1aeb0d9b08685..75a8b2636a6a2 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -552,6 +552,9 @@ def add_kv_lora_rank(self, length: int) -> None:
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)

+    def add_sliding_window(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

diff --git a/src/llama.cpp b/src/llama.cpp
index 6838a6fc7d712..d8852cfe494af 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -287,7 +287,6 @@ enum llm_kv {
     LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
-    LLM_KV_CONTEXT_LENGTH_SWA,
     LLM_KV_EMBEDDING_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
@@ -318,6 +317,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -380,7 +380,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
     { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
-    { LLM_KV_CONTEXT_LENGTH_SWA, "%s.context_length_swa" },
     { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
     { LLM_KV_BLOCK_COUNT, "%s.block_count" },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -411,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
     { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },

     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@@ -2082,7 +2082,6 @@ struct llama_hparams {
     uint32_t n_vocab;
     uint32_t n_ctx_train;    // context size the model was trained on
-    int32_t  n_ctx_swa = -1; // context size for sliding window attention (SWA)
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2102,6 +2101,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float expert_weights_scale = 0.0;
+    uint32_t n_sliding = 0; // sliding window attention (SWA)

     float f_norm_eps;
     float f_norm_rms_eps;
@@ -4715,8 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                hparams.n_ctx_swa = 4096; // default value
-                ml.get_key(LLM_KV_CONTEXT_LENGTH_SWA, hparams.n_ctx_swa, false);
+                hparams.n_sliding = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -12687,11 +12687,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
             float * data_swa = nullptr;
-            const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;
+            const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;

             if (lctx.model.arch == LLM_ARCH_GEMMA2) {
                 GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
-                GGML_ASSERT(hparams.n_ctx_swa > 0);
+                GGML_ASSERT(hparams.n_sliding > 0);
                 data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
                 data     = (float *) lctx.inp_KQ_mask_l[1]->data;
                 // because layer masks are alternate for gemma 2, we only need to take first 2 layers

From 231dae4f68e68d5d77debf9a27af2fc111744815 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 30 Jun 2024 23:11:04 +0200
Subject: [PATCH 4/9] add co-author

Co-authored-by: Arlo Phoenix

From d09ecb84c8300504bb76794298c0ef47d541a733 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 30 Jun 2024 23:40:25 +0200
Subject: [PATCH 5/9] replace list with single tensor

---
 src/llama.cpp | 36 +++++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index d8852cfe494af..71b7ef622019e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2081,7 +2081,7 @@ struct llama_hparams {
     bool use_par_res;

     uint32_t n_vocab;
-    uint32_t n_ctx_train;    // context size the model was trained on
+    uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]

     // KQ mask per layer, used by sliding window attention (gemma 2)
-    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+    struct ggml_tensor * inp_KQ_mask_SWA;

     // control vectors
     struct llama_control_vector cvec;
@@ -7794,6 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy = nullptr;
         lctx.inp_s_mask = nullptr;
         lctx.inp_s_seq = nullptr;
+        lctx.inp_KQ_mask_SWA = nullptr;
     }

     void free() {
@@ -7946,15 +7947,18 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }

-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
+        struct ggml_tensor * KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(KQ_mask, "KQ_mask", -1);
+        ggml_set_input(KQ_mask);
+        if (sliding_window) {
+            lctx.inp_KQ_mask_SWA = KQ_mask;
         } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            lctx.inp_KQ_mask = KQ_mask;
         }
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
     }

     struct ggml_tensor * build_inp_mean() {
@@ -11038,14 +11042,12 @@ struct llm_build_context {

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_SWA = build_inp_KQ_mask();
-        lctx.inp_KQ_mask_l.clear();
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
+        struct ggml_tensor * KQ_mask_SWA = build_inp_KQ_mask(true, true);

         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
             struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
-            lctx.inp_KQ_mask_l.push_back(KQ_mask);

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -12685,15 +12687,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data     = (float *) lctx.inp_KQ_mask->data;
             float * data_swa = nullptr;
             const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;

             if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-                GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+                GGML_ASSERT(lctx.inp_KQ_mask_SWA);
                 GGML_ASSERT(hparams.n_sliding > 0);
-                data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
-                data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+                data     = (float *) lctx.inp_KQ_mask->data;
+                data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
                 // because layer masks are alternate for gemma 2, we only need to take first 2 layers
             }

From ed5496fb32e9888abeaa7672aaba4d4251671457 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Mon, 1 Jul 2024 12:35:47 +0200
Subject: [PATCH 6/9] update

---
 src/llama.cpp | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 71b7ef622019e..1f676357333ff 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2101,7 +2101,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float expert_weights_scale = 0.0;
-    uint32_t n_sliding = 0; // sliding window attention (SWA)
+    uint32_t n_swa = 0; // sliding window attention (SWA)

     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]

     // KQ mask per layer, used by sliding window attention (gemma 2)
-    struct ggml_tensor * inp_KQ_mask_SWA;
+    struct ggml_tensor * inp_KQ_mask_swa;

     // control vectors
     struct llama_control_vector cvec;
@@ -4715,8 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                hparams.n_sliding = 4096; // default value of gemma 2
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -7794,7 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy = nullptr;
         lctx.inp_s_mask = nullptr;
         lctx.inp_s_seq = nullptr;
-        lctx.inp_KQ_mask_SWA = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
     }

     void free() {
@@ -7954,7 +7954,7 @@ struct llm_build_context {
         cb(KQ_mask, "KQ_mask", -1);
         ggml_set_input(KQ_mask);
         if (sliding_window) {
-            lctx.inp_KQ_mask_SWA = KQ_mask;
+            lctx.inp_KQ_mask_swa = KQ_mask;
         } else {
             lctx.inp_KQ_mask = KQ_mask;
         }
@@ -12689,14 +12689,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data     = (float *) lctx.inp_KQ_mask->data;
             float * data_swa = nullptr;
-            const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;

-            if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-                GGML_ASSERT(lctx.inp_KQ_mask_SWA);
-                GGML_ASSERT(hparams.n_sliding > 0);
-                data     = (float *) lctx.inp_KQ_mask->data;
-                data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
-                // because layer masks are alternate for gemma 2, we only need to take first 2 layers
+            if (lctx.inp_KQ_mask_swa) {
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
             }

             // For causal attention, use only the previous KV cells
@@ -12722,7 +12717,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                 f = -INFINITY;
                             }
                             data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;

From ce711f6eae1079592e060e47ec23b6a795387ce6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 1 Jul 2024 18:26:24 +0300
Subject: [PATCH 7/9] llama : minor styling

---
 src/llama.cpp | 89 +++++++++++++++++++++++++++++++++++++++++++++-------------------------------------------
 1 file changed, 48 insertions(+), 41 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 1f676357333ff..eea532f6ac2ff 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2087,6 +2087,7 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;
@@ -2102,7 +2102,6 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float expert_weights_scale = 0.0;
-    uint32_t n_swa = 0; // sliding window attention (SWA)

     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2142,6 +2142,7 @@ struct llama_hparams {
         if (this->n_head_kv != other.n_head_kv) return true;
         if (this->n_layer != other.n_layer) return true;
         if (this->n_rot != other.n_rot) return true;
+        if (this->n_swa != other.n_swa) return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff != other.n_ff) return true;
@@ -2652,20 +2653,18 @@ struct llama_context {
     void * abort_callback_data = nullptr;

     // input tensors
-    struct ggml_tensor * inp_tokens;   // I32 [n_batch]
-    struct ggml_tensor * inp_embd;     // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;      // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;  // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;  // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;  // I32 [kv_size]
-    struct ggml_tensor * inp_mean;     // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;      // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;   // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;   // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;    // I32 [n_kv, n_batch]
-
-    // KQ mask per layer, used by sliding window attention (gemma 2)
-    struct ggml_tensor * inp_KQ_mask_swa;
+    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
+    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
+    struct ggml_tensor * inp_pos;         // I32 [n_batch]
+    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
+    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
+    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
+    struct ggml_tensor * inp_cls;         // I32 [n_batch]
+    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
+    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]

     // control vectors
     struct llama_control_vector cvec;
@@ -5427,6 +5426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
     LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
     LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
     LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
@@ -7783,18 +7783,18 @@ struct llm_build_context {

         ctx0 = ggml_init(params);

-        lctx.inp_tokens = nullptr;
-        lctx.inp_embd = nullptr;
-        lctx.inp_pos = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean = nullptr;
-        lctx.inp_cls = nullptr;
-        lctx.inp_s_copy = nullptr;
-        lctx.inp_s_mask = nullptr;
-        lctx.inp_s_seq = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
     }

     void free() {
@@ -7813,7 +7813,6 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);

-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
@@ -7947,18 +7946,26 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }

-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
-        struct ggml_tensor * KQ_mask = causal
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
+        lctx.inp_KQ_mask = causal
             ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
             : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        cb(KQ_mask, "KQ_mask", -1);
-        ggml_set_input(KQ_mask);
-        if (sliding_window) {
-            lctx.inp_KQ_mask_swa = KQ_mask;
-        } else {
-            lctx.inp_KQ_mask = KQ_mask;
-        }
-        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+    }
+
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
     }

     struct ggml_tensor * build_inp_mean() {
@@ -11042,12 +11049,12 @@ struct llm_build_context {

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
-        struct ggml_tensor * KQ_mask_SWA = build_inp_KQ_mask(true, true);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);

         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
-            struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -11084,7 +11091,7 @@ struct llm_build_context {

                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }

             cur = llm_build_norm(ctx0, cur, hparams,

From 7dc9cbf03fd7d3b3a977872f7bae0fe6cbfad5db Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 1 Jul 2024 18:38:24 +0300
Subject: [PATCH 8/9] convert : add sanity check for query_pre_attn_scalar

---
 convert-hf-to-gguf.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 27fc9eea6d69b..4a7f500ff7d5c 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2371,6 +2371,11 @@ def set_gguf_parameters(self):
         )
         self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])

+        # sanity check
+        attn_scalar = self.hparams["query_pre_attn_scalar"]
+        if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
+            raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

From e24328ea1ada2d298ea8a64a141d13d240a87319 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Mon, 1 Jul 2024 17:54:34 +0200
Subject: [PATCH 9/9] fix small typo in README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index c136d4a5cb9c9..daba70717312e 100644
--- a/README.md
+++ b/README.md
@@ -218,7 +218,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 **Tools:**

 - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
--[crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
+- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption

 ---
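
Note (not part of the patch series): a quick way to see what the final mask condition from PATCH 6 does is to print the two masks for a toy sequence. The sketch below is illustrative only and uses made-up toy sizes; '1' marks a KV cell the token at position pos may attend to, '.' a cell that would get -INFINITY. The SWA variant applies the same rule as llama_set_inputs above, masking any cell whose position trails pos by n_swa or more; in the Gemma 2 graph, layers with even il get this mask and the remaining layers get the full causal mask.

    // Standalone illustration of the two Gemma-2 masks (causal vs. causal + SWA).
    // Toy values: 8 KV cells where cell i holds position i, window n_swa = 4.
    #include <cstdio>

    int main() {
        const int n_kv  = 8; // number of cached KV cells
        const int n_swa = 4; // sliding window size (gemma 2 default is 4096)

        for (int pos = 0; pos < n_kv; ++pos) {       // query position
            std::printf("pos %d | causal:", pos);
            for (int i = 0; i < n_kv; ++i) {
                std::printf(" %c", i <= pos ? '1' : '.');
            }
            std::printf("  causal+swa:");
            for (int i = 0; i < n_kv; ++i) {
                const bool causal = i <= pos;
                // mirrors: pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa -> f = -INFINITY
                const bool cut_by_swa = (pos - i) >= n_swa;
                std::printf(" %c", (causal && !cut_by_swa) ? '1' : '.');
            }
            std::printf("\n");
        }
        return 0;
    }

For example, at pos = 5 with n_swa = 4 the SWA mask keeps cells 2..5 and drops cells 0..1, while the full causal mask keeps cells 0..5.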