@@ -3541,7 +3541,8 @@ struct server_context {
                     slot.n_past = 0;
                 }

-                const auto n_swa = llama_model_n_swa(model);
+                // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                const auto n_swa = std::max(1, llama_model_n_swa(model));

                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3552,15 +3553,16 @@ struct server_context {

                     const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

-                    if (pos_min > pos_min_thold + 1) {
+                    if (pos_min > pos_min_thold) {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

                         // search for a context checkpoint
                         const auto it = std::find_if(
                             slot.ctx_checkpoints.rbegin(),
                             slot.ctx_checkpoints.rend(),
                             [&](const auto & cur) {
-                                return cur.pos_min <= pos_min_thold;
+                                // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                return cur.pos_min < pos_min_thold;
                             }
                         );

@@ -3577,7 +3579,7 @@ struct server_context {
                                 do_reset = true;
                                 // printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                             } else {
-                                slot.n_past = std::min(slot.n_past, it->pos_max);
+                                slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
                                 SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                             }
                         }
@@ -3586,25 +3588,23 @@ struct server_context {
                             SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                     "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                             slot.n_past = 0;
-                            slot.ctx_checkpoints.clear();
                         }
                     }
-                }
-
-                if (n_swa > 0) {
-                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

-                    // erase any checkpoints with pos_min > pos_min_thold
-                    for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
-                        const auto & cur = slot.ctx_checkpoints[i];
-                        if (cur.pos_min > pos_min_thold) {
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
-                            SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                    {
+                        // erase any checkpoints with pos_min > pos_min_thold
+                        for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                            const auto & cur = slot.ctx_checkpoints[i];
+                            if (cur.pos_min > pos_min_thold) {
+                                SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
+                            }
                         }
                     }
                 }
             }

+            // [TAG_PROMPT_LOGITS]
             if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
                 SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                 slot.n_past--;
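To make the intent of the two changed comparisons easier to follow in isolation, here is a small standalone sketch (not part of the server code; the checkpoint struct, window size, and prompt length are made-up values) of why the strict `cur.pos_min < pos_min_thold` test together with the `std::max(it->pos_min + 1, it->pos_max)` clamp always leaves at least one prompt token to re-evaluate after a checkpoint restore:

// Standalone illustration of the checkpoint-selection invariant above.
// Hypothetical values; simplified checkpoint struct instead of the server's slot state.
#include <algorithm>
#include <cstdio>
#include <vector>

struct checkpoint {
    int pos_min; // oldest position still present in the checkpointed SWA/recurrent state
    int pos_max; // newest position covered by the checkpoint
};

int main() {
    const int n_swa    = 4;  // hypothetical sliding-window size
    const int n_prompt = 16; // hypothetical prompt length
    int n_past         = 12; // tokens matched against the cache so far

    std::vector<checkpoint> checkpoints = { {0, 3}, {5, 9}, {8, 11} };

    const int pos_min_thold = std::max(0, n_past - n_swa);

    // pick the most recent checkpoint that is strictly older than the threshold,
    // mirroring the reverse find_if in the diff
    auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
                           [&](const checkpoint & c) { return c.pos_min < pos_min_thold; });

    if (it != checkpoints.rend()) {
        // clamp n_past so that at least the token following pos_min gets re-processed
        n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
        printf("restored checkpoint [%d, %d], n_past = %d\n", it->pos_min, it->pos_max, n_past);
    } else {
        // no usable checkpoint: full re-processing
        n_past = 0;
        printf("no checkpoint found, n_past = %d\n", n_past);
    }

    // the server additionally enforces this invariant ([TAG_PROMPT_LOGITS])
    if (n_past == n_prompt && n_past > 0) {
        n_past--;
    }

    printf("tokens left to evaluate: %d\n", n_prompt - n_past);
    return 0;
}

With these example numbers, pos_min_thold = 8, so the checkpoint [5, 9] is chosen (the newer [8, 11] is rejected by the strict comparison) and n_past is clamped to 9, leaving 7 of the 16 prompt tokens to be evaluated.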