
Commit 2e1b88f

server : fix checkpoint logic to support recurrent caches
1 parent e1b68d8 commit 2e1b88f

File tree

1 file changed: +15 −15 lines changed


tools/server/server.cpp

Lines changed: 15 additions & 15 deletions
@@ -3541,7 +3541,8 @@ struct server_context {
                     slot.n_past = 0;
                 }
 
-                const auto n_swa = llama_model_n_swa(model);
+                // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                const auto n_swa = std::max(1, llama_model_n_swa(model));
 
                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3552,15 +3553,16 @@ struct server_context {
 
                     const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
-                    if (pos_min > pos_min_thold + 1) {
+                    if (pos_min > pos_min_thold) {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
                         // search for a context checkpoint
                         const auto it = std::find_if(
                             slot.ctx_checkpoints.rbegin(),
                             slot.ctx_checkpoints.rend(),
                             [&](const auto & cur) {
-                                return cur.pos_min <= pos_min_thold;
+                                // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                return cur.pos_min < pos_min_thold;
                             }
                         );
 
@@ -3577,7 +3579,7 @@ struct server_context {
                                 do_reset = true;
                                 //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                             } else {
-                                slot.n_past = std::min(slot.n_past, it->pos_max);
+                                slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
                                 SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                             }
                         }
@@ -3586,25 +3588,23 @@ struct server_context {
                             SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                     "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                             slot.n_past = 0;
-                            slot.ctx_checkpoints.clear();
                         }
                     }
-                }
-
-                if (n_swa > 0) {
-                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
-                    // erase any checkpoints with pos_min > pos_min_thold
-                    for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
-                        const auto & cur = slot.ctx_checkpoints[i];
-                        if (cur.pos_min > pos_min_thold) {
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
-                            SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                    {
+                        // erase any checkpoints with pos_min > pos_min_thold
+                        for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                            const auto & cur = slot.ctx_checkpoints[i];
+                            if (cur.pos_min > pos_min_thold) {
+                                SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
+                            }
                         }
                     }
                 }
             }
 
+            // [TAG_PROMPT_LOGITS]
             if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
                 SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                 slot.n_past--;
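
The two changed comparisons work together: a checkpoint is only selected if its pos_min lies strictly below pos_min_thold, and after the restore n_past is clamped to std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)), echoing the diff's own comment that a restored checkpoint must "result in at least one token being processed" [TAG_PROMPT_LOGITS]. Below is a minimal standalone sketch of that invariant; Checkpoint, pick_checkpoint, and restore_n_past are hypothetical names for illustration, not the actual server types, and the degenerate pos_max == pos_min checkpoint is an assumption about what the recurrent caches named in the commit title can produce.

// Standalone sketch of the checkpoint-restore invariant (hypothetical helper
// names; not the llama.cpp server API). Build with: g++ -std=c++17 sketch.cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

struct Checkpoint {
    int pos_min; // first position covered by the checkpoint state
    int pos_max; // last position covered by the checkpoint state
};

// Newest-to-oldest search, mirroring the diff's std::find_if over
// rbegin()/rend(). The strict '<' (changed from '<=') rejects checkpoints
// whose pos_min sits exactly at the threshold.
const Checkpoint * pick_checkpoint(const std::vector<Checkpoint> & cps, int pos_min_thold) {
    const auto it = std::find_if(cps.rbegin(), cps.rend(),
        [&](const Checkpoint & cur) { return cur.pos_min < pos_min_thold; });
    return it == cps.rend() ? nullptr : &*it;
}

// Clamp n_past as in the diff: std::max(pos_min + 1, pos_max) keeps the result
// strictly above pos_min even for a degenerate checkpoint with
// pos_max == pos_min, so at least one token remains to be evaluated.
int restore_n_past(int n_past, const Checkpoint & cp) {
    return std::min(n_past, std::max(cp.pos_min + 1, cp.pos_max));
}

int main() {
    const std::vector<Checkpoint> cps = { {0, 10}, {20, 20}, {40, 55} };

    const int n_past = 60;              // tokens matched against the cache
    const int n_swa  = std::max(1, 0);  // n_swa == 0 is treated as a window of 1

    const int pos_min_thold = std::max(0, n_past - n_swa);

    if (const Checkpoint * cp = pick_checkpoint(cps, pos_min_thold)) {
        const int restored = restore_n_past(n_past, *cp);
        assert(restored > cp->pos_min); // at least one token will be processed
        std::printf("restored n_past = %d (checkpoint pos_min = %d, pos_max = %d)\n",
                    restored, cp->pos_min, cp->pos_max);
    }
    return 0;
}

With the sample data above, the checkpoint {40, 55} is picked and n_past is clamped from 60 to 55, leaving positions 55..59 to be re-evaluated; the assert holds for any picked checkpoint, since pos_min < pos_min_thold <= n_past and the clamp target is at least pos_min + 1.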
