
Commit 0cc8e02

run : fix BOS being added to each message

Porting the fix from simple-chat.

Signed-off-by: Eric Curtin <[email protected]>

1 parent b9daaff
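Note: before this change, tokenize_prompt called llama_tokenize with add_special hard-coded to true, so every chat turn was prefixed with a BOS token; chat templates typically expect BOS exactly once, at the very start of the conversation. Illustrative token streams (layout invented for this note):

    before: <BOS> [turn 1] <BOS> [turn 2] <BOS> [turn 3] ...
    after:  <BOS> [turn 1] [turn 2] [turn 3] ...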

File tree: 1 file changed, +10 -9 lines

examples/run/run.cpp

Lines changed: 10 additions & 9 deletions
@@ -729,11 +729,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
 
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
-                           std::vector<llama_token> & prompt_tokens) {
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data,
+                           const bool is_first) {
+    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(),
+                       llama_get_kv_cache_used_cells(llama_data.context.get()) == 0, true) < 0) {
         printe("failed to tokenize the prompt\n");
         return -1;
     }
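The hunk above is the core of the fix: the add_special argument of llama_tokenize, previously hard-coded to true, is now derived from whether anything has been decoded yet (the sizing call uses the is_first flag passed down from the caller, while the second call re-derives it from the KV cache occupancy). A minimal sketch of the pattern, using only llama.h calls that already appear in this diff; the helper name tokenize_message is invented for illustration:

    #include "llama.h"

    #include <string>
    #include <vector>

    // Tokenize one chat message, adding BOS only while the KV cache is
    // still empty, i.e. only for the very first message of the conversation.
    static std::vector<llama_token> tokenize_message(const llama_vocab * vocab, llama_context * ctx,
                                                     const std::string & text) {
        const bool add_special = llama_get_kv_cache_used_cells(ctx) == 0;

        // With a NULL output buffer, llama_tokenize returns the negated token count.
        const int n_tokens = -llama_tokenize(vocab, text.c_str(), text.size(), NULL, 0, add_special, true);

        std::vector<llama_token> tokens(n_tokens);
        if (llama_tokenize(vocab, text.c_str(), text.size(), tokens.data(), tokens.size(), add_special, true) < 0) {
            tokens.clear(); // tokenization failed; return an empty vector
        }
        return tokens;
    }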
@@ -774,11 +775,11 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
 }
 
 // helper function to evaluate a prompt and generate a response
-static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
+static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response, const bool is_first) {
     const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
 
     std::vector<llama_token> tokens;
-    if (tokenize_prompt(vocab, prompt, tokens) < 0) {
+    if (tokenize_prompt(vocab, prompt, tokens, llama_data, is_first) < 0) {
         return 1;
     }

@@ -852,13 +853,13 @@ static int read_user_input(std::string & user_input) {
 
 // Function to generate a response based on the prompt
 static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response,
-                             const bool stdout_a_terminal) {
+                             const bool stdout_a_terminal, const int prev_len) {
     // Set response color
     if (stdout_a_terminal) {
         printf("\033[33m");
     }
 
-    if (generate(llama_data, prompt, response)) {
+    if (generate(llama_data, prompt, response, prev_len == 0)) {
         printe("failed to generate response\n");
         return 1;
     }
@@ -948,7 +949,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
 
         std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
         std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        if (generate_response(llama_data, prompt, response, stdout_a_terminal, prev_len)) {
             return 1;
         }

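Taken together, the call chain is chat_loop -> generate_response -> generate -> tokenize_prompt, and prev_len == 0 stands in for "first message": prev_len is the length of chat-template output already consumed, so it is zero exactly once, before the first generation. A condensed, hypothetical skeleton of the loop (run.cpp's real chat_loop also reads user input and records the assistant reply; add_user_message below is a placeholder):

    int prev_len = 0; // formatted chars already sent to the model
    while (true) {
        add_user_message();                                        // placeholder: read and store user input
        const int new_len = apply_chat_template(llama_data, true); // re-format with the new message

        // only the newly formatted tail becomes the next prompt
        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);

        // prev_len == 0 holds only on the first iteration, so only the first
        // prompt is tokenized with BOS added
        std::string response;
        if (generate_response(llama_data, prompt, response, stdout_a_terminal, prev_len)) {
            return 1;
        }

        prev_len = apply_chat_template(llama_data, false); // advance past the assistant reply
    }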