From cae8f50b1a240fa7c1216a34ab4b0b2e604012a3 Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Mon, 4 Dec 2023 21:52:17 +0100 Subject: [PATCH 1/8] initial commit, going through initializations --- examples/lookup/CMakeLists.txt | 5 ++ examples/lookup/README.md | 0 examples/lookup/lookup.cpp | 113 +++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 examples/lookup/CMakeLists.txt create mode 100644 examples/lookup/README.md create mode 100644 examples/lookup/lookup.cpp diff --git a/examples/lookup/CMakeLists.txt b/examples/lookup/CMakeLists.txt new file mode 100644 index 0000000000000..c060b8f56d436 --- /dev/null +++ b/examples/lookup/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET lookup) +add_executable(${TARGET} lookup.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/lookup/README.md b/examples/lookup/README.md new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp new file mode 100644 index 0000000000000..e7174ffe42922 --- /dev/null +++ b/examples/lookup/lookup.cpp @@ -0,0 +1,113 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include + +/* +def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10): + input_length = input_ids.size(1) + + for ngram_size in range(max_ngram_size, 0, -1): + # Extract the last n tokens as our search ngram + ngram = input_ids[0, -ngram_size:].tolist() + + # Create sliding windows of size ngram_size + windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) + + # Convert ngram to a tensor for comparison + ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0) + + # Find where the windows match the ngram + matches = (windows == ngram_tensor).all(dim=2) + + # Get the indices of matches + match_indices = matches.nonzero(as_tuple=True)[1] + + # Iterate through match indices to find a valid continuation + for idx in match_indices: + start_idx = idx + ngram_size + end_idx = start_idx + num_pred_tokens + # Ensure we don't go beyond the length of input_ids and avoid self-match + if end_idx <= input_length and start_idx < input_length - ngram_size: + return input_ids[0, start_idx:end_idx] + + # If no match is found, return an empty tensor + return torch.tensor([], dtype=torch.long, device=input_ids.device) +*/ + +int main(int argc, char ** argv){ + gpt_params params; + + if(gpt_params_parse(argc, argv, params) == false){ + return 1; + } + + // maximum n-grams to search for in prompt + const int max_ngram_size = 3; + + // length of the candidate sequence, if match is found + const int num_pred_tokens = 10; + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("lookup", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); +#endif // LOG_DISABLE_LOGS + + // init llama.cpp + llama_backend_init(params.numa); + + llama_model * model = NULL; + llama_context * ctx = NULL; + + // load the model + std::tie(model, ctx) = llama_init_from_gpt_params(params); + + // tokenize the prompt + const bool add_bos = llama_should_add_bos_token(model); + LOG("add_bos tgt: %d\n", add_bos); + + std::vector inp; + inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) inp.size() > max_tokens_list_size) { + 
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + return 1; + } + + fprintf(stderr, "\n\n"); + + for (auto id : inp) { + fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + + const auto t_enc_end = ggml_time_us(); + + int n_accept = 0; + + int n_past = inp.size(); + + bool has_eos = false; + + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + + const auto t_dec_start = ggml_time_us(); + + + +} \ No newline at end of file From 0ec5fdb5ceff119dd3364702a2ea1ebc0ec3e5e6 Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Sun, 10 Dec 2023 20:20:01 +0100 Subject: [PATCH 2/8] main loop finished, starting to debug --- .gitignore | 1 + Makefile | 5 +- common/common.h | 4 +- examples/CMakeLists.txt | 1 + examples/lookup/lookup.cpp | 164 +++++++++++++++++++++++++++++-------- 5 files changed, 138 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 58c4839940b49..5464f08cca4e3 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,7 @@ models-mnt /llama-bench /llava-cli /lookahead +/lookup /main /metal /perplexity diff --git a/Makefile b/Makefile index 3cc932a2e2822..4fefa4e0da67d 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BUILD_TARGETS = \ main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ - speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead tests/test-c.o + speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o # Binaries only useful for tests TEST_TARGETS = \ @@ -664,6 +664,9 @@ parallel: examples/parallel/parallel.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) +lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + ifdef LLAMA_METAL metal: examples/metal/metal.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) diff --git a/common/common.h b/common/common.h index 2f6fe48ab53d3..3acdfdd742d93 100644 --- a/common/common.h +++ b/common/common.h @@ -75,10 +75,10 @@ struct gpt_params { // // sampling parameters struct llama_sampling_params sparams; - std::string model = "models/7B/ggml-model-f16.gguf"; // model path + std::string model = "models/7B/ggml-model-q4_0.gguf"; // model path std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias - std::string prompt = ""; + std::string prompt = "Hello my name is"; std::string prompt_file = ""; // store the external prompt file name std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6744944fd8b99..4cc13d6e99ce1 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -33,6 +33,7 @@ else() add_subdirectory(simple) add_subdirectory(speculative) 
add_subdirectory(lookahead) + add_subdirectory(lookup) add_subdirectory(train-text-from-scratch) if (LLAMA_METAL) add_subdirectory(metal) diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index e7174ffe42922..3f3ad10b9358a 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -6,38 +6,6 @@ #include #include -/* -def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10): - input_length = input_ids.size(1) - - for ngram_size in range(max_ngram_size, 0, -1): - # Extract the last n tokens as our search ngram - ngram = input_ids[0, -ngram_size:].tolist() - - # Create sliding windows of size ngram_size - windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) - - # Convert ngram to a tensor for comparison - ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0) - - # Find where the windows match the ngram - matches = (windows == ngram_tensor).all(dim=2) - - # Get the indices of matches - match_indices = matches.nonzero(as_tuple=True)[1] - - # Iterate through match indices to find a valid continuation - for idx in match_indices: - start_idx = idx + ngram_size - end_idx = start_idx + num_pred_tokens - # Ensure we don't go beyond the length of input_ids and avoid self-match - if end_idx <= input_length and start_idx < input_length - ngram_size: - return input_ids[0, start_idx:end_idx] - - # If no match is found, return an empty tensor - return torch.tensor([], dtype=torch.long, device=input_ids.device) -*/ - int main(int argc, char ** argv){ gpt_params params; @@ -48,8 +16,8 @@ int main(int argc, char ** argv){ // maximum n-grams to search for in prompt const int max_ngram_size = 3; - // length of the candidate sequence, if match is found - const int num_pred_tokens = 10; + // length of the candidate / draft sequence, if match is found + const int n_draft = 10; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("lookup", "log")); @@ -98,6 +66,8 @@ int main(int argc, char ** argv){ const auto t_enc_end = ggml_time_us(); + int n_predict = 0; + int n_drafted = 0; int n_accept = 0; int n_past = inp.size(); @@ -106,8 +76,134 @@ int main(int argc, char ** argv){ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + std::vector draft(n_draft); + + llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); + const auto t_dec_start = ggml_time_us(); + while(true){ + // print current draft sequence + LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str()); + + int i_dft = 0; + while (true) { + //LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + + // sample from the target model + llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + + llama_sampling_accept(ctx_sampling, ctx, id, true); + + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + + const std::string token_str = llama_token_to_piece(ctx, id); + + printf("%s", token_str.c_str()); + fflush(stdout); + + if (id == llama_token_eos(model)) { + has_eos = true; + } + + ++n_predict; + + // check if the target token matches the draft + if (i_dft < (int) draft.size() && id == draft[i_dft]) { + LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); + ++n_accept; + ++n_past; + ++i_dft; + + continue; + } + + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + + 
draft.clear(); + draft.push_back(id); + // drafts[0].i_batch_tgt.push_back(0); + + // llama_batch_clear(batch_dft); + // llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); + + // llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + // // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); + // llama_decode (ctx_dft, batch_dft); + + // ++n_past_dft; + break; + } + + if (n_predict > params.n_predict || has_eos) { + break; + } + + llama_batch_clear(batch_tgt); + llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); + + bool match = false; + // generate n_pred tokens through prompt lookup + for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ + if (match){ + break; + } + const auto & prev = ctx_sampling->prev; + int prev_size = prev.size(); + const llama_token * ngram = &prev[prev_size - ngram_size]; + + for (int i = 0; i <= (int) prev_size - (ngram_size * 2); ++i) { + if (prev[i] == ngram[0] && prev[i + 1] == ngram[1] && prev[i + 2] == ngram[2]) { + const int startIdx = i + ngram_size; + const int endIdx = startIdx + n_draft; + if (endIdx < prev_size){ + match = true; + for (int j = startIdx; j < endIdx; ++j) { + LOG(" - draft candidate %d: %d\n", j, prev[j]); + draft.push_back(prev[j]); + llama_batch_add(batch_tgt, prev[j], n_past + j + 1, { 1 }, true); + ++n_drafted; + } + } + } + } + } + + llama_decode(ctx, batch_tgt); + ++n_past; + + draft.erase(draft.begin()); + + // we have our draft! + } + + auto t_dec_end = ggml_time_us(); + + LOG_TEE("\n\n"); + + LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_TEE("\n"); + LOG_TEE("n_draft = %d\n", n_draft); + LOG_TEE("n_predict = %d\n", n_predict); + LOG_TEE("n_drafted = %d\n", n_drafted); + LOG_TEE("n_accept = %d\n", n_accept); + LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + + LOG_TEE("\ntarget:\n"); + llama_print_timings(ctx); + + llama_sampling_free(ctx_sampling); + llama_batch_free(batch_tgt); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + fprintf(stderr, "\n\n"); + return 0; } \ No newline at end of file From 1665ad8bf1780de80e1a5e7d99761a3ebcbde4ee Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Fri, 15 Dec 2023 14:14:17 +0100 Subject: [PATCH 3/8] BUG: generates gibberish/repeating tokens after a while --- common/common.h | 6 ++-- examples/lookup/lookup.cpp | 61 +++++++++++++++++++------------------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/common/common.h b/common/common.h index 3acdfdd742d93..8c73da247bc6d 100644 --- a/common/common.h +++ b/common/common.h @@ -75,10 +75,10 @@ struct gpt_params { // // sampling parameters struct llama_sampling_params sparams; - std::string model = "models/7B/ggml-model-q4_0.gguf"; // model path + std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias - std::string prompt = "Hello my name is"; + std::string prompt = ""; std::string prompt_file = ""; // store the external prompt file name std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with @@ -228,4 +228,4 @@ void 
dump_non_result_info_yaml( void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); +void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); \ No newline at end of file diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 3f3ad10b9358a..a9347e51e0515 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -122,16 +122,7 @@ int main(int argc, char ** argv){ draft.clear(); draft.push_back(id); - // drafts[0].i_batch_tgt.push_back(0); - - // llama_batch_clear(batch_dft); - // llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); - - // llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); - // // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - // llama_decode (ctx_dft, batch_dft); - - // ++n_past_dft; + inp.push_back(id); break; } @@ -142,33 +133,41 @@ int main(int argc, char ** argv){ llama_batch_clear(batch_tgt); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); - bool match = false; // generate n_pred tokens through prompt lookup - for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ - if (match){ - break; - } - const auto & prev = ctx_sampling->prev; - int prev_size = prev.size(); - const llama_token * ngram = &prev[prev_size - ngram_size]; - - for (int i = 0; i <= (int) prev_size - (ngram_size * 2); ++i) { - if (prev[i] == ngram[0] && prev[i + 1] == ngram[1] && prev[i + 2] == ngram[2]) { - const int startIdx = i + ngram_size; - const int endIdx = startIdx + n_draft; - if (endIdx < prev_size){ - match = true; - for (int j = startIdx; j < endIdx; ++j) { - LOG(" - draft candidate %d: %d\n", j, prev[j]); - draft.push_back(prev[j]); - llama_batch_add(batch_tgt, prev[j], n_past + j + 1, { 1 }, true); - ++n_drafted; + auto prompt_lookup = [&]() -> void { + int inp_size = inp.size(); + for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ + const llama_token * ngram = &inp[inp_size - ngram_size]; + + for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) { + bool match = true; + for (int j = 0; j < ngram_size; ++j) { + if (inp[i + j] != ngram[j]) { + match = false; + break; + } + } + + if (match) { + const int startIdx = i + ngram_size; + const int endIdx = startIdx + n_draft; + if (endIdx < inp_size){ + for (int j = startIdx; j < endIdx; ++j) { + LOG(" - draft candidate %d: %d\n", j, inp[j]); + draft.push_back(inp[j]); + llama_batch_add(batch_tgt, inp[j], n_past + j + 1, { 0 }, true); + ++n_drafted; + } + return; } } } } - } + return; + }; + prompt_lookup(); + llama_decode(ctx, batch_tgt); ++n_past; From 21431197a1c7d04cd95fae4360667be1177d1dd9 Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Sat, 16 Dec 2023 12:12:33 +0100 Subject: [PATCH 4/8] kv_cache management --- examples/lookup/lookup.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index a9347e51e0515..28b9c2c950be0 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -130,6 +130,10 @@ int main(int argc, char ** argv){ break; } + // KV cache management + // clean the cache of draft tokens that weren't accepted + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + llama_batch_clear(batch_tgt); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); From 1b26d7151abbc8bb41ba154335aa537e0cb4f08c Mon Sep 17 00:00:00 2001 From: 
Leon Ericsson Date: Sun, 17 Dec 2023 13:04:46 +0100 Subject: [PATCH 5/8] Added colors to distinguish drafted tokens (--color). Updated README --- examples/lookup/README.md | 13 +++++++++++++ examples/lookup/lookup.cpp | 28 +++++++++++++++++----------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/examples/lookup/README.md b/examples/lookup/README.md index e69de29bb2d1d..03a772c45a45e 100644 --- a/examples/lookup/README.md +++ b/examples/lookup/README.md @@ -0,0 +1,13 @@ +# llama.cpp/examples/lookup + +Demonstration of Prompt Lookup Decoding + +https://github.com/apoorvumang/prompt-lookup-decoding + +The two key parameters for lookup decoding are `max_ngram_size` and `n_draft`. The first, determines how many ngrams to use when searching through the prompt for a match and the second specifies how many subsequent tokens to draft if a match is found. + +More info: + +https://github.com/ggerganov/llama.cpp/pull/4484 +https://github.com/ggerganov/llama.cpp/issues/4226 + diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 28b9c2c950be0..db97d241c7c5a 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -88,19 +88,16 @@ int main(int argc, char ** argv){ int i_dft = 0; while (true) { - //LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); llama_sampling_accept(ctx_sampling, ctx, id, true); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); - const std::string token_str = llama_token_to_piece(ctx, id); - printf("%s", token_str.c_str()); - fflush(stdout); + if (!params.use_color) { + printf("%s", token_str.c_str()); + } if (id == llama_token_eos(model)) { has_eos = true; @@ -114,9 +111,21 @@ int main(int argc, char ** argv){ ++n_accept; ++n_past; ++i_dft; - + inp.push_back(id); + + if (params.use_color) { + // color accepted draft token + printf("\033[34m%s\033[0m", token_str.c_str()); + fflush(stdout); + } continue; + } + + if (params.use_color) { + printf("%s", token_str.c_str()); } + fflush(stdout); + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); @@ -176,9 +185,6 @@ int main(int argc, char ** argv){ ++n_past; draft.erase(draft.begin()); - - // we have our draft! - } auto t_dec_end = ggml_time_us(); From 5b27975479c4d9a0cffc3528af54b8d01ae46bfb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Dec 2023 16:47:26 +0200 Subject: [PATCH 6/8] lookup : fix token positions in the draft batch --- common/common.h | 3 ++- examples/lookup/lookup.cpp | 41 ++++++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/common/common.h b/common/common.h index ef2a61de6c942..875e012a21829 100644 --- a/common/common.h +++ b/common/common.h @@ -239,4 +239,5 @@ void dump_non_result_info_yaml( void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). 
-void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); \ No newline at end of file +void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); + diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index db97d241c7c5a..6b4eb957af5e5 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -19,6 +19,8 @@ int main(int argc, char ** argv){ // length of the candidate / draft sequence, if match is found const int n_draft = 10; + const bool dump_kv_cache = params.dump_kv_cache; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("lookup", "log")); LOG_TEE("Log start\n"); @@ -37,7 +39,7 @@ int main(int argc, char ** argv){ // tokenize the prompt const bool add_bos = llama_should_add_bos_token(model); LOG("add_bos tgt: %d\n", add_bos); - + std::vector inp; inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); @@ -69,24 +71,33 @@ int main(int argc, char ** argv){ int n_predict = 0; int n_drafted = 0; int n_accept = 0; - + int n_past = inp.size(); bool has_eos = false; struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); - std::vector draft(n_draft); + std::vector draft; llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); + // debug + struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); + const auto t_dec_start = ggml_time_us(); - while(true){ + while (true) { + // debug + if (dump_kv_cache) { + llama_kv_cache_view_update(ctx, &kvc_view); + dump_kv_cache_view_seqs(kvc_view, 40); + } + // print current draft sequence LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str()); - int i_dft = 0; + int i_dft = 0; while (true) { // sample from the target model llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); @@ -120,13 +131,13 @@ int main(int argc, char ** argv){ } continue; } - + if (params.use_color) { printf("%s", token_str.c_str()); - } + } fflush(stdout); - + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); draft.clear(); @@ -135,7 +146,7 @@ int main(int argc, char ** argv){ break; } - if (n_predict > params.n_predict || has_eos) { + if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) { break; } @@ -149,9 +160,9 @@ int main(int argc, char ** argv){ // generate n_pred tokens through prompt lookup auto prompt_lookup = [&]() -> void { int inp_size = inp.size(); - for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ + for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ const llama_token * ngram = &inp[inp_size - ngram_size]; - + for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) { bool match = true; for (int j = 0; j < ngram_size; ++j) { @@ -164,11 +175,11 @@ int main(int argc, char ** argv){ if (match) { const int startIdx = i + ngram_size; const int endIdx = startIdx + n_draft; - if (endIdx < inp_size){ + if (endIdx < inp_size) { for (int j = startIdx; j < endIdx; ++j) { LOG(" - draft candidate %d: %d\n", j, inp[j]); draft.push_back(inp[j]); - llama_batch_add(batch_tgt, inp[j], n_past + j + 1, { 0 }, true); + llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true); ++n_drafted; } return; @@ -180,7 +191,7 @@ int main(int argc, char ** argv){ }; prompt_lookup(); - + llama_decode(ctx, batch_tgt); ++n_past; @@ -215,4 +226,4 @@ int main(int argc, char ** argv){ fprintf(stderr, "\n\n"); return 0; -} \ No newline at end of file +} From 
d8ed670c6ce276301c085bcc6a9f589ef0916502 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Dec 2023 20:06:41 +0200 Subject: [PATCH 7/8] lookup : use n_draft from CLI params --- common/common.h | 2 +- examples/lookup/lookup.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.h b/common/common.h index 875e012a21829..9659aa0453ff8 100644 --- a/common/common.h +++ b/common/common.h @@ -51,7 +51,7 @@ struct gpt_params { int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 16; // number of tokens to draft during speculative decoding + int32_t n_draft = 8; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 6b4eb957af5e5..ab1be0a327c33 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -17,7 +17,7 @@ int main(int argc, char ** argv){ const int max_ngram_size = 3; // length of the candidate / draft sequence, if match is found - const int n_draft = 10; + const int n_draft = params.n_draft; const bool dump_kv_cache = params.dump_kv_cache; From 50ea1ef7c8639609593017215604c3a3a84987d4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Dec 2023 18:04:30 +0200 Subject: [PATCH 8/8] lookup : final touches --- examples/lookup/README.md | 2 +- examples/lookup/lookup.cpp | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/lookup/README.md b/examples/lookup/README.md index 03a772c45a45e..5bfb0de936041 100644 --- a/examples/lookup/README.md +++ b/examples/lookup/README.md @@ -4,7 +4,7 @@ Demonstration of Prompt Lookup Decoding https://github.com/apoorvumang/prompt-lookup-decoding -The two key parameters for lookup decoding are `max_ngram_size` and `n_draft`. The first, determines how many ngrams to use when searching through the prompt for a match and the second specifies how many subsequent tokens to draft if a match is found. +The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. The latter specifies how many subsequent tokens to draft if a match is found. 
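As an aside (this sketch is not part of the patch itself), the way `ngram_min`, `ngram_max` and `n_draft` interact can be illustrated with a small standalone function. Token IDs are treated as plain `int32_t` values and the function name and signature are made up for the example:

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch only -- names and signature are not from the PR.
// Search the token history for an earlier occurrence of its own last
// `ngram_size` tokens, trying sizes from ngram_max down to ngram_min, and
// return the tokens that followed that earlier occurrence as a draft of
// n_draft tokens. An empty result means "nothing to draft".
static std::vector<int32_t> lookup_draft(const std::vector<int32_t> & inp,
                                         int ngram_min, int ngram_max, int n_draft) {
    const int inp_size = (int) inp.size();

    for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
        if (inp_size < 2*ngram_size) {
            continue; // not enough history to match this n-gram size without self-overlap
        }

        // the pattern: the last ngram_size tokens of the sequence
        const int32_t * ngram = &inp[inp_size - ngram_size];

        // slide over earlier windows, stopping early enough to avoid matching the pattern with itself
        for (int i = 0; i <= inp_size - 2*ngram_size; ++i) {
            bool match = true;
            for (int j = 0; j < ngram_size; ++j) {
                if (inp[i + j] != ngram[j]) {
                    match = false;
                    break;
                }
            }
            if (!match) {
                continue;
            }

            // match found: draft the tokens that followed the earlier occurrence,
            // provided a full n_draft continuation exists inside the history
            const int start = i + ngram_size;
            const int end   = start + n_draft;
            if (end < inp_size) {
                return std::vector<int32_t>(inp.begin() + start, inp.begin() + end);
            }
        }
    }

    return {}; // no usable match
}
```

Intuitively, larger n-grams give more specific matches (so drafts are more likely to be accepted by the target model), while `n_draft` caps how many tokens are speculated per target-model evaluation.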
More info: diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index ab1be0a327c33..d8de7dd387273 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -9,12 +9,13 @@ int main(int argc, char ** argv){ gpt_params params; - if(gpt_params_parse(argc, argv, params) == false){ + if (!gpt_params_parse(argc, argv, params)) { return 1; } - // maximum n-grams to search for in prompt - const int max_ngram_size = 3; + // max/min n-grams size to search for in prompt + const int ngram_max = 4; + const int ngram_min = 1; // length of the candidate / draft sequence, if match is found const int n_draft = params.n_draft; @@ -160,7 +161,7 @@ int main(int argc, char ** argv){ // generate n_pred tokens through prompt lookup auto prompt_lookup = [&]() -> void { int inp_size = inp.size(); - for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ + for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){ const llama_token * ngram = &inp[inp_size - ngram_size]; for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
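The other half of the example, spread across the patches above, is the verification loop: the drafted tokens are evaluated by the target model in a single batch (`batch_tgt`), then sampled one position at a time and compared against the draft; matching tokens are accepted, and the first mismatch throws away the rest of the draft and reseeds it with the freshly sampled token. A minimal sketch of just that bookkeeping, with a made-up `sample_target()` standing in for `llama_sampling_sample` and all llama.cpp state (sampling context, batch, KV cache) omitted:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Placeholder for the target model's sampler (llama_sampling_sample in the
// example); it returns a fixed token so this sketch is self-contained.
static int32_t sample_target() { return 42; }

// Accept drafted tokens for as long as they agree with what the target model
// samples. Every sampled token is appended to the output; on the first
// mismatch (or when the draft runs out) the rest of the draft is dropped and
// the sampled token becomes the seed for the next prompt lookup.
static int accept_draft(std::vector<int32_t> & inp, std::vector<int32_t> & draft,
                        int & n_past, int & n_accept) {
    int i_dft = 0;
    while (true) {
        const int32_t id = sample_target(); // target model's choice at this position
        inp.push_back(id);                  // sampled token is always emitted

        if (i_dft < (int) draft.size() && id == draft[i_dft]) {
            ++n_accept; // drafted token confirmed by the target model
            ++n_past;
            ++i_dft;
            continue;   // move on to the next drafted position
        }

        draft.clear();        // mismatch or draft exhausted: discard the rest
        draft.push_back(id);  // reseed the draft with the sampled token
        return i_dft;         // number of drafted tokens that were accepted
    }
}

int main() {
    std::vector<int32_t> inp   = {1, 2, 3};   // pretend prompt tokens
    std::vector<int32_t> draft = {42, 42, 7}; // pretend lookup produced this draft
    int n_past = (int) inp.size();
    int n_accept = 0;

    const int accepted = accept_draft(inp, draft, n_past, n_accept);
    std::printf("accepted %d drafted tokens (n_accept = %d, n_past = %d)\n",
                accepted, n_accept, n_past);
    return 0;
}
```

In the real example the accepted count feeds the `n_accept`/`n_drafted` statistics printed at the end, and `llama_kv_cache_seq_rm` removes the cache entries of the rejected draft positions before the next batch is built.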