@@ -317,10 +317,32 @@ struct llama_server_context
         return true;
     }
 
+    void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
+        const int n_left = n_ctx - params.n_keep;
+        const int n_block_size = n_left / 2;
+        const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
+
+        // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
+        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+
+        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+        LOG_VERBOSE("input truncated", {
+            {"n_ctx", n_ctx},
+            {"n_keep", params.n_keep},
+            {"n_left", n_left},
+            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            {"num_prompt_tokens", new_tokens.size()}
+        });
+
+        truncated = true;
+        prompt_tokens = new_tokens;
+    }
+
     void loadInfill()
     {
         bool suff_rm_leading_spc = true;
-        if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
             params.input_suffix.erase(0, 1);
             suff_rm_leading_spc = false;
         }
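For reference, a rough standalone sketch of the block-erasure arithmetic used by the new truncatePrompt helper above (the n_ctx, n_keep, and prompt-length values are made up for illustration): the first n_keep tokens are always kept, and whole blocks of n_block_size = (n_ctx - n_keep) / 2 tokens are dropped from the front of the remainder, leaving a result strictly below n_ctx, which is the condition the GGML_ASSERT added further down checks.

// Sketch only: mirrors truncatePrompt's arithmetic with plain ints instead of
// llama_token vectors. The concrete numbers are made up for illustration.
#include <cstdio>

int main() {
    const int n_ctx   = 512;  // illustrative context size
    const int n_keep  = 32;   // tokens always kept from the start of the prompt
    const int n_input = 700;  // pretend the tokenized prompt has 700 tokens

    const int n_left        = n_ctx - n_keep;                                   // 480
    const int n_block_size  = n_left / 2;                                       // 240
    const int erased_blocks = (n_input - n_keep - n_block_size) / n_block_size; // 1

    // Survivors: the first n_keep tokens plus everything after the erased blocks.
    const int n_new = n_keep + (n_input - n_keep - erased_blocks * n_block_size); // 460

    printf("erased %d block(s) of %d tokens -> %d tokens, fits in n_ctx=%d: %s\n",
           erased_blocks, n_block_size, n_new, n_ctx, n_new < n_ctx ? "yes" : "no");
    return 0;
}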
@@ -336,6 +358,7 @@ struct llama_server_context
         prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
         prefix_tokens.push_back(llama_token_middle(ctx));
+
         auto prompt_tokens = prefix_tokens;
 
         num_prompt_tokens = prompt_tokens.size();
@@ -347,31 +370,18 @@ struct llama_server_context
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
         {
-            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
-            // todo we probably want to cut from both sides
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
 
-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", params.n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -409,29 +419,18 @@ struct llama_server_context
         params.n_keep = std::min(n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
        {
-            const int n_left = (n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
 
-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -542,7 +541,7 @@ struct llama_server_context
             result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
         }
 
-        llama_sampling_accept(ctx_sampling, ctx, result.tok);
+        llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
 
         if (tg) {
             num_tokens_predicted++;
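As a side note on the sampling changes: both prompt-loading paths now feed tokens through llama_sampling_accept instead of writing into ctx_sampling->prev by hand, passing false so grammar state is not advanced for prompt tokens, while the freshly sampled token in the last hunk passes true. A minimal sketch of that pattern is below; it assumes llama.cpp's common sampling header from this revision, and replay_prompt is a made-up helper name, not part of the change.

// Sketch, not part of the patch: replay an already-tokenized prompt into the
// sampling context so repetition penalties see it, without touching grammar state.
#include <vector>

#include "sampling.h" // assumed include path for llama.cpp's common/sampling.h

static void replay_prompt(llama_sampling_context * ctx_sampling,
                          llama_context          * ctx,
                          const std::vector<llama_token> & prompt_tokens) {
    for (const llama_token token : prompt_tokens) {
        llama_sampling_accept(ctx_sampling, ctx, token, false); // prompt: no grammar
    }
    // A freshly sampled token would instead be accepted with the last argument set
    // to true, as in the final hunk above, so grammar constraints advance as well.
}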