@@ -729,11 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
729729
730730// Function to tokenize the prompt
731731static int tokenize_prompt (const llama_vocab * vocab, const std::string & prompt,
732- std::vector<llama_token> & prompt_tokens) {
733- const int n_prompt_tokens = -llama_tokenize (vocab, prompt.c_str (), prompt.size (), NULL , 0 , true , true );
732+ std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data,
733+ const bool is_first) {
734+ const int n_prompt_tokens = -llama_tokenize (vocab, prompt.c_str (), prompt.size (), NULL , 0 , is_first, true );
734735 prompt_tokens.resize (n_prompt_tokens);
735- if (llama_tokenize (vocab, prompt.c_str (), prompt.size (), prompt_tokens.data (), prompt_tokens.size (), true ,
736- true ) < 0 ) {
736+ if (llama_tokenize (vocab, prompt.c_str (), prompt.size (), prompt_tokens.data (), prompt_tokens.size (),
737+ llama_get_kv_cache_used_cells (llama_data. context . get ()) == 0 , true ) < 0 ) {
737738 printe (" failed to tokenize the prompt\n " );
738739 return -1 ;
739740 }
@@ -774,11 +775,11 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
774775}
775776
776777// helper function to evaluate a prompt and generate a response
777- static int generate (LlamaData & llama_data, const std::string & prompt, std::string & response) {
778+ static int generate (LlamaData & llama_data, const std::string & prompt, std::string & response, const bool is_first ) {
778779 const llama_vocab * vocab = llama_model_get_vocab (llama_data.model .get ());
779780
780781 std::vector<llama_token> tokens;
781- if (tokenize_prompt (vocab, prompt, tokens) < 0 ) {
782+ if (tokenize_prompt (vocab, prompt, tokens, llama_data, is_first ) < 0 ) {
782783 return 1 ;
783784 }
784785
@@ -852,13 +853,13 @@ static int read_user_input(std::string & user_input) {
852853
853854// Function to generate a response based on the prompt
854855static int generate_response (LlamaData & llama_data, const std::string & prompt, std::string & response,
855- const bool stdout_a_terminal) {
856+ const bool stdout_a_terminal, const int prev_len ) {
856857 // Set response color
857858 if (stdout_a_terminal) {
858859 printf (" \033 [33m" );
859860 }
860861
861- if (generate (llama_data, prompt, response)) {
862+ if (generate (llama_data, prompt, response, prev_len == 0 )) {
862863 printe (" failed to generate response\n " );
863864 return 1 ;
864865 }
@@ -948,7 +949,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
948949
949950 std::string prompt (llama_data.fmtted .begin () + prev_len, llama_data.fmtted .begin () + new_len);
950951 std::string response;
951- if (generate_response (llama_data, prompt, response, stdout_a_terminal)) {
952+ if (generate_response (llama_data, prompt, response, stdout_a_terminal, prev_len )) {
952953 return 1 ;
953954 }
954955
0 commit comments