From 21196da114d88f40bda32227a8f24f5eb6ac33d4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 30 Oct 2023 10:44:07 +0200 Subject: [PATCH 1/8] examples : add passkey test --- examples/CMakeLists.txt | 1 + examples/batched/batched.cpp | 1 + examples/passkey/CMakeLists.txt | 5 + examples/passkey/passkey.cpp | 263 ++++++++++++++++++++++++++++++++ 4 files changed, 270 insertions(+) create mode 100644 examples/passkey/CMakeLists.txt create mode 100644 examples/passkey/passkey.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4cc13d6e99ce1..0c71cbdf72a65 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -31,6 +31,7 @@ else() add_subdirectory(quantize-stats) add_subdirectory(save-load-state) add_subdirectory(simple) + add_subdirectory(passkey) add_subdirectory(speculative) add_subdirectory(lookahead) add_subdirectory(lookup) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 22a4265df77c0..b1775e0b0e8d6 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -69,6 +69,7 @@ int main(int argc, char ** argv) { std::vector tokens_list; tokens_list = ::llama_tokenize(model, params.prompt, true); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; // initialize the context diff --git a/examples/passkey/CMakeLists.txt b/examples/passkey/CMakeLists.txt new file mode 100644 index 0000000000000..3161bf3ef9a45 --- /dev/null +++ b/examples/passkey/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET passkey) +add_executable(${TARGET} passkey.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp new file mode 100644 index 0000000000000..e6e5fc4b687e0 --- /dev/null +++ b/examples/passkey/passkey.cpp @@ -0,0 +1,263 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +int main(int argc, char ** argv) { + gpt_params params; + + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH N_JUNK SEED\n" , argv[0]); + return 1 ; + } + + int seed = -1; + + int n_junk = 250; // number of times to repeat the junk text + int n_keep = 32; // number of tokens in the prompt prefix + + if (argc >= 2) { + params.model = argv[1]; + } + + if (argc >= 3) { + n_junk = std::stoi(argv[2]); + } + + if (argc >= 4) { + seed = std::stoi(argv[3]); + } + + const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; + + if (seed == -1) { + seed = time(NULL); + } + + srand(seed); + + // generate junk text + params.prompt = prompt_prefix; + + const int n_insert = rand() % n_junk; + const int passkey = rand() % 50000 + 1; + + for (int i = 0; i < n_junk; i++) { + if (i % n_junk == n_insert) { + params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; + } + + params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."; + } + + params.prompt += " What is the pass key? The pass key is"; + + // init LLM + + llama_backend_init(params.numa); + + // initialize the model + + llama_model_params model_params = llama_model_default_params(); + + model_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = seed; + ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep; + ctx_params.n_batch = 512; + ctx_params.n_threads = params.n_threads; + ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // tokenize the prefix and use it as a sink + const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); + + // tokenize the prompt + std::vector tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + // we leave a margin of 16 tokens for the generated text - it should contain just the passkey + const int n_predict = 16; + + // total length of the sequences including the prompt + const int n_len = tokens_list.size() + n_predict; + + const int n_ctx = llama_n_ctx(ctx) - n_keep; + const int n_kv_req = llama_n_ctx(ctx); + const int n_batch = ctx_params.n_batch; + + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); + + // print the prompt token-by-token + + LOG_TEE("\n"); + LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); + LOG_TEE("prompt tokens: %d\n", (int) tokens_list.size()); + //LOG_TEE("prompt: %s\n", params.prompt.c_str()); + + llama_batch batch = llama_batch_init(512, 0, 1); + + // fill the KV cache + for (int i = 0; i < n_ctx; i += n_batch) { + llama_batch_clear(batch); + + for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { + llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false); + } + + if (i + n_batch >= (int) tokens_list.size()) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(ctx, batch) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); + + if (i + n_batch >= (int) tokens_list.size()) { + break; + } + } + + for (int i = n_ctx; i < (int) tokens_list.size(); i += n_batch) { + const int n_discard = n_batch; + + LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + + llama_batch_clear(batch); + + for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { + llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false); + } + + if (i + n_batch >= (int) tokens_list.size()) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(ctx, batch) != 0) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); + } + + int n_past = batch.pos[batch.n_tokens - 1]; + + { + const int n_discard = n_past - n_ctx + n_predict; + + if (n_discard > 0) { + LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + + n_past -= n_discard; + } + } + + LOG_TEE("\n"); + + // main loop + + int n_cur = tokens_list.size(); + int n_decode = 0; + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { + // sample the next token + { + auto n_vocab = llama_n_vocab(model); + auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + + std::vector candidates; + candidates.reserve(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // sample the most likely token + const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of stream? + if (new_token_id == llama_token_eos(model) || n_cur == n_len) { + LOG_TEE("\n"); + + break; + } + + LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); + fflush(stdout); + + n_decode += 1; + n_past += 1; + + // prepare the next batch + llama_batch_clear(batch); + + // push this new token for next evaluation + llama_batch_add(batch, new_token_id, n_past, { 0 }, true); + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + LOG_TEE("\n"); + LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk); + + LOG_TEE("\n"); + + const auto t_main_end = ggml_time_us(); + + LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + llama_print_timings(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} From fbb999f592289a50e67f4b614a0123800ea72d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 30 Oct 2023 11:13:44 +0200 Subject: [PATCH 2/8] passkey : better prints --- examples/passkey/passkey.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index e6e5fc4b687e0..f3b83d5097d9e 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -32,6 +32,7 @@ int main(int argc, char ** argv) { } const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; + const std::string prompt_suffix = " What is the pass key? The pass key is"; if (seed == -1) { seed = time(NULL); @@ -53,7 +54,7 @@ int main(int argc, char ** argv) { params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."; } - params.prompt += " What is the pass key? The pass key is"; + params.prompt += prompt_suffix; // init LLM @@ -182,6 +183,8 @@ int main(int argc, char ** argv) { } } + LOG_TEE("\n"); + LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk); LOG_TEE("\n"); // main loop @@ -189,6 +192,9 @@ int main(int argc, char ** argv) { int n_cur = tokens_list.size(); int n_decode = 0; + LOG_TEE("%s", prompt_suffix.c_str()); + fflush(stdout); + const auto t_main_start = ggml_time_us(); while (n_cur <= n_len) { @@ -238,9 +244,6 @@ int main(int argc, char ** argv) { } } - LOG_TEE("\n"); - LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk); - LOG_TEE("\n"); const auto t_main_end = ggml_time_us(); From bda3f2c89260b69d738dc772a0e11ee70d88e9b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 7 Jan 2024 14:48:09 +0200 Subject: [PATCH 3/8] passkey : select pass key pos from CLI --- examples/passkey/passkey.cpp | 54 +++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index f3b83d5097d9e..682f90e83202a 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -10,7 +10,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH N_JUNK SEED\n" , argv[0]); + printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]); return 1 ; } @@ -18,6 +18,7 @@ int main(int argc, char ** argv) { int n_junk = 250; // number of times to repeat the junk text int n_keep = 32; // number of tokens in the prompt prefix + int i_pos = -1; // position of the passkey in the junk text if (argc >= 2) { params.model = argv[1]; @@ -28,11 +29,12 @@ int main(int argc, char ** argv) { } if (argc >= 4) { - seed = std::stoi(argv[3]); + i_pos = std::stoi(argv[3]); } - const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; - const std::string prompt_suffix = " What is the pass key? The pass key is"; + if (argc >= 5) { + seed = std::stoi(argv[4]); + } if (seed == -1) { seed = time(NULL); @@ -40,14 +42,20 @@ int main(int argc, char ** argv) { srand(seed); + if (i_pos == -1) { + i_pos = rand() % n_junk; + } + + const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; + const std::string prompt_suffix = " What is the pass key? The pass key is"; + // generate junk text params.prompt = prompt_prefix; - const int n_insert = rand() % n_junk; - const int passkey = rand() % 50000 + 1; + const int passkey = rand() % 50000 + 1; for (int i = 0; i < n_junk; i++) { - if (i % n_junk == n_insert) { + if (i % n_junk == i_pos) { params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; } @@ -90,18 +98,20 @@ int main(int argc, char ** argv) { return 1; } - // tokenize the prefix and use it as a sink - const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); - // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); + // tokenize the prefix and use it as a sink + const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); + + const int n_tokens_all = tokens_list.size(); + // we leave a margin of 16 tokens for the generated text - it should contain just the passkey const int n_predict = 16; // total length of the sequences including the prompt - const int n_len = tokens_list.size() + n_predict; + const int n_len = n_tokens_all + n_predict; const int n_ctx = llama_n_ctx(ctx) - n_keep; const int n_kv_req = llama_n_ctx(ctx); @@ -113,7 +123,7 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); - LOG_TEE("prompt tokens: %d\n", (int) tokens_list.size()); + LOG_TEE("prompt tokens: %d\n", n_tokens_all); //LOG_TEE("prompt: %s\n", params.prompt.c_str()); llama_batch batch = llama_batch_init(512, 0, 1); @@ -122,11 +132,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < n_ctx; i += n_batch) { llama_batch_clear(batch); - for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false); } - if (i + n_batch >= (int) tokens_list.size()) { + if (i + n_batch >= n_tokens_all) { batch.logits[batch.n_tokens - 1] = true; } @@ -135,14 +145,14 @@ int main(int argc, char ** argv) { return 1; } - LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); - if (i + n_batch >= (int) tokens_list.size()) { + if (i + n_batch >= n_tokens_all) { break; } } - for (int i = n_ctx; i < (int) tokens_list.size(); i += n_batch) { + for (int i = n_ctx; i < n_tokens_all; i += n_batch) { const int n_discard = n_batch; LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); @@ -152,11 +162,11 @@ int main(int argc, char ** argv) { llama_batch_clear(batch); - for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) { + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false); } - if (i + n_batch >= (int) tokens_list.size()) { + if (i + n_batch >= n_tokens_all) { batch.logits[batch.n_tokens - 1] = true; } @@ -165,7 +175,7 @@ int main(int argc, char ** argv) { return 1; } - LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size())); + LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); } int n_past = batch.pos[batch.n_tokens - 1]; @@ -184,12 +194,12 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk); + LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); LOG_TEE("\n"); // main loop - int n_cur = tokens_list.size(); + int n_cur = n_tokens_all; int n_decode = 0; LOG_TEE("%s", prompt_suffix.c_str()); From f2c9800dfb70369f8436c583d2c361b811690abb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 7 Jan 2024 17:52:12 +0200 Subject: [PATCH 4/8] passkey : simplify n_past logic --- examples/passkey/passkey.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 682f90e83202a..862dde996e7e2 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -128,12 +128,14 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(512, 0, 1); + int n_past = 0; + // fill the KV cache for (int i = 0; i < n_ctx; i += n_batch) { llama_batch_clear(batch); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false); + llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); } if (i + n_batch >= n_tokens_all) { @@ -160,10 +162,12 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + n_past -= n_discard; + llama_batch_clear(batch); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false); + llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); } if (i + n_batch >= n_tokens_all) { @@ -178,8 +182,6 @@ int main(int argc, char ** argv) { LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); } - int n_past = batch.pos[batch.n_tokens - 1]; - { const int n_discard = n_past - n_ctx + n_predict; @@ -236,13 +238,12 @@ int main(int argc, char ** argv) { fflush(stdout); n_decode += 1; - n_past += 1; // prepare the next batch llama_batch_clear(batch); // push this new token for next evaluation - llama_batch_add(batch, new_token_id, n_past, { 0 }, true); + llama_batch_add(batch, new_token_id, n_past++, { 0 }, true); } n_cur += 1; From 2f40c9f6c58d06682e235b510dfae35b29a9b42a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 7 Jan 2024 16:16:19 +0200 Subject: [PATCH 5/8] llama : "self-extend"-like context extension --- examples/passkey/passkey.cpp | 32 +++++++++++++++++++++++++------- llama.cpp | 34 ++++++++++++++++++++++++++++++++++ llama.h | 7 +++++++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 862dde996e7e2..815a92f93775b 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -10,7 +10,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]); + printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]); return 1 ; } @@ -18,6 +18,7 @@ int main(int argc, char ** argv) { int n_junk = 250; // number of times to repeat the junk text int n_keep = 32; // number of tokens in the prompt prefix + int n_grp = 1; // if more than 1 - perform LongLM SelfExtend int i_pos = -1; // position of the passkey in the junk text if (argc >= 2) { @@ -29,11 +30,15 @@ int main(int argc, char ** argv) { } if (argc >= 4) { - i_pos = std::stoi(argv[3]); + n_grp = std::stoi(argv[3]); } if (argc >= 5) { - seed = std::stoi(argv[4]); + i_pos = std::stoi(argv[4]); + } + + if (argc >= 6) { + seed = std::stoi(argv[5]); } if (seed == -1) { @@ -86,11 +91,13 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); ctx_params.seed = seed; - ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep; + ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; ctx_params.n_batch = 512; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); + llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { @@ -113,9 +120,10 @@ int main(int argc, char ** argv) { // total length of the sequences including the prompt const int n_len = n_tokens_all + n_predict; - const int n_ctx = llama_n_ctx(ctx) - n_keep; - const int n_kv_req = llama_n_ctx(ctx); - const int n_batch = ctx_params.n_batch; + const int n_ctx = llama_n_ctx(ctx) - n_keep; + const int n_kv_req = llama_n_ctx(ctx); + const int n_batch = ctx_params.n_batch; + const int n_batch_grp = ctx_params.n_batch/n_grp; LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); @@ -132,6 +140,16 @@ int main(int argc, char ** argv) { // fill the KV cache for (int i = 0; i < n_ctx; i += n_batch) { + if (i > 0 && n_grp > 1) { + const int ib = i/n_batch - 1; + const int bd = n_batch_grp*(n_grp - 1); + + llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + + n_past -= bd; + } + llama_batch_clear(batch); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { diff --git a/llama.cpp b/llama.cpp index 91aa3f8e79191..63853d1c3cdae 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1903,6 +1903,28 @@ static void llama_kv_cache_seq_shift( cache.head = new_head != cache.size ? new_head : 0; } +static void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + + { + llama_pos p_old = cache.cells[i].pos; + cache.cells[i].pos /= d; + cache.cells[i].delta += cache.cells[i].pos - p_old; + } + } + } +} + // // model loading and saving // @@ -10140,9 +10162,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { } void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); +} + // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. diff --git a/llama.h b/llama.h index 461d4604a1b54..5305de90be5c1 100644 --- a/llama.h +++ b/llama.h @@ -484,6 +484,13 @@ extern "C" { llama_pos p1, llama_pos delta); + LLAMA_API void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + // // State / sessions // From f64cddc76d9a445140b949392c8e3e69b8f0d392 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 7 Jan 2024 16:37:02 +0200 Subject: [PATCH 6/8] passkey : add comment --- examples/passkey/passkey.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 815a92f93775b..5c0022832146b 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { const int n_batch = ctx_params.n_batch; const int n_batch_grp = ctx_params.n_batch/n_grp; - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch); // print the prompt token-by-token @@ -141,6 +141,7 @@ int main(int argc, char ** argv) { // fill the KV cache for (int i = 0; i < n_ctx; i += n_batch) { if (i > 0 && n_grp > 1) { + // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); From ea12921826ee6951318e7812a2058fecbb57712b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 7 Jan 2024 22:22:44 +0200 Subject: [PATCH 7/8] main : add Self-Extend support --- common/common.cpp | 18 +++++++++ common/common.h | 2 + examples/main/main.cpp | 83 +++++++++++++++++++++++++++++++----------- 3 files changed, 81 insertions(+), 22 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index eacaee18e0907..6b4913a656573 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_ctx = std::stoi(argv[i]); + } else if (arg == "--grp-attn-n" || arg == "-gan") { + if (++i >= argc) { + invalid_param = true; + break; + } + + params.grp_attn_n = std::stoi(argv[i]); + } else if (arg == "--grp-attn-w" || arg == "-gaw") { + if (++i >= argc) { + invalid_param = true; + break; + } + + params.grp_attn_w = std::stoi(argv[i]); } else if (arg == "--rope-freq-base") { if (++i >= argc) { invalid_param = true; @@ -904,6 +918,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" Not recommended since this is both slower and uses more VRAM.\n"); #endif // GGML_USE_CUBLAS #endif + printf(" -gan N, --grp-attn-n N\n"); + printf(" group-attention factor (default: %d)\n", params.grp_attn_n); + printf(" -gat N, --grp-attn-w N\n"); + printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w); printf(" --verbose-prompt print prompt before generation\n"); printf(" -dkvc, --dump-kv-cache\n"); printf(" verbose print of the KV cache\n"); diff --git a/common/common.h b/common/common.h index 9659aa0453ff8..e2bbfc258b646 100644 --- a/common/common.h +++ b/common/common.h @@ -62,6 +62,8 @@ struct gpt_params { int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_beams = 0; // if non-zero then use beam search of given width. + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c096f110b32c5..5ea67051f3654 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -439,6 +439,21 @@ int main(int argc, char ** argv) { LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + + // group-attention state + // number of grouped KV tokens so far (used only if params.grp_attn_n > 1) + int ga_i = 0; + + const int ga_n = params.grp_attn_n; + const int ga_w = params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT + GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT + //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT + //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT + LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w); + } LOG_TEE("\n\n"); if (params.interactive) { @@ -500,37 +515,61 @@ int main(int argc, char ** argv) { fflush(stdout); } - // infinite text generation via context swapping - // if we run out of context: - // - take the n_keep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { - if (params.n_predict == -2) { - LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); - break; - } + if (ga_n == 1) { + // infinite text generation via context shifting + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { + if (params.n_predict == -2) { + LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + break; + } - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left/2; + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", - n_past, n_left, n_ctx, params.n_keep, n_discard); + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - n_past -= n_discard; + n_past -= n_discard; - if (ctx_guidance) { - n_past_guidance -= n_discard; + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + LOG("clear session path\n"); + path_session.clear(); } + } else { + // context extension via Self-Extend + while (n_past >= ga_i + ga_w) { + const int ib = (ga_n*ga_i)/ga_w; + const int bd = (ga_w/ga_n)*(ga_n - 1); + const int dd = (ga_w/ga_n) - ib*bd - ga_w; - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + LOG("\n"); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); + LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); - LOG("clear session path\n"); - path_session.clear(); + n_past -= bd; + + ga_i += ga_w/ga_n; + + LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); + } } // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) From 82048d4750781dbbe657f2036f5b55dd21581165 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Jan 2024 11:17:00 +0200 Subject: [PATCH 8/8] llama : add comment about llama_kv_cache_seq_div --- llama.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llama.h b/llama.h index 5305de90be5c1..869ff0acf525a 100644 --- a/llama.h +++ b/llama.h @@ -484,6 +484,10 @@ extern "C" { llama_pos p1, llama_pos delta); + // Integer division of the positions by factor of `d > 1` + // If the KV cache is RoPEd, the KV data is updated accordingly + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id,