
Conversation

@ddh0 (Contributor) commented Oct 2, 2025

This PR generalizes the SWA checkpointing logic (ref #15293) to also create checkpoints for recurrent and hybrid models such as Mamba, Jamba, etc.

  • SWA-specific parts of the code are generalized:
    • the --swa-checkpoints CLI arg is renamed to --ctx-checkpoints
    • the internal LLAMA_STATE_SEQ_FLAGS_SWA_ONLY flag is renamed to LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY
  • adds llama_model_is_hybrid to llama-model.cpp and llama.h

This removes the need to re-process the entire context in the majority of cases.

Would resolve #15677 and #14625
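
For reference, a rough sketch of how the new query could be used alongside the existing model queries. The exact prototype of llama_model_is_hybrid is an assumption here, mirroring the existing llama_model_is_recurrent declaration in llama.h:

// sketch only: llama_model_is_hybrid is assumed to be declared as
//     LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);

#include "llama.h"

static bool wants_context_checkpoints(const llama_model * model, bool swa_full) {
    // recurrent and hybrid models cannot roll their state back to an earlier
    // position, and SWA models lose tokens outside the window unless swa_full
    // is enabled, so all of them benefit from context checkpoints
    return llama_model_is_recurrent(model) ||
           llama_model_is_hybrid(model)    ||
           (llama_model_n_swa(model) > 0 && !swa_full);
}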


ddh0 and others added 5 commits October 1, 2025 23:14
include/llama.h (outdated)

     size_t * n_token_count_out);

-#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY        1
+#define LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY 1
Member:

We need a bit better name than this. The old name does not work, but the proposed new name is confusing.

The purpose of this flag is to indicate that we want to save only the "small" caches, such as SWA, "recr", etc. But I can't think of a good name to call it.

Contributor (PR author):

I see what you mean. I can't think of anything better at the moment.
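
For context, a rough sketch of where this flag ends up: it is passed to the llama_state_seq_*_ext calls (the API added with the SWA checkpoint work in #15293) to snapshot only the small caches for one sequence. The exact signatures below are assumptions on my part:

#include "llama.h"
#include <cstdint>
#include <vector>

// save a checkpoint of only the non-rollbackable caches (SWA window,
// recurrent state) for one sequence, not the full KV cache
static std::vector<uint8_t> save_checkpoint(llama_context * ctx, llama_seq_id seq_id) {
    const llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY;

    std::vector<uint8_t> data(llama_state_seq_get_size_ext(ctx, seq_id, flags));
    llama_state_seq_get_data_ext(ctx, data.data(), data.size(), seq_id, flags);

    return data;
}

// later, roll the same sequence back to the saved checkpoint
static void restore_checkpoint(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
    llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
}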

Comment on lines 3856 to 3862
// make a checkpoint of the parts of memory that cannot be rolled back.
// checkpoints are needed only if:
// - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
(llama_model_n_swa(model) > 0 && !params_base.swa_full);

Member:

I'm a bit torn on this logic for determining when to do checkpoints. It should be centred around the memory module or the context, rather than the model.

Just making a note for the future - no need to change anything in this PR.
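
Purely as an illustration of that direction, not existing API — a hypothetical query on the memory module rather than on the model:

// hypothetical sketch: llama_memory_can_rollback does not exist; the idea is
// that the memory module attached to the context reports whether it can be
// rolled back to an arbitrary position, instead of inferring this from
// model-level properties
const bool do_checkpoint = !llama_memory_can_rollback(llama_get_memory(ctx));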

@ddh0 (PR author) commented Oct 2, 2025

Hmm, maybe I should mark this as a draft, because sometimes it works and sometimes I still see this:

srv  log_server_r: request: POST /v1/chat/completions 192.168.68.69 200
srv  params_from_: Chat format: Content-only
slot get_availabl: id  0 | task 595 | selected slot by lcs similarity, lcs_len = 872, similarity = 0.395 (> 0.100 thold)
slot launch_slot_: id  0 | task 1859 | processing task
slot update_slots: id  0 | task 1859 | new prompt, n_ctx_slot = 262144, n_keep = 0, n_prompt_tokens = 4084
slot update_slots: id  0 | task 1859 | n_past = 872, cache_tokens.size() = 2205, seq_id = 0, pos_min = 2204, n_swa = 0
slot update_slots: id  0 | task 1859 | forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
slot update_slots: id  0 | task 1859 | n_past = 0
slot update_slots: id  0 | task 1859 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.501469
slot update_slots: id  0 | task 1859 | n_past = 2048
slot update_slots: id  0 | task 1859 | prompt processing progress, n_past = 4084, n_tokens = 2036, progress = 1.000000
slot update_slots: id  0 | task 1859 | prompt done, n_past = 4084, n_tokens = 2036
slot update_slots: id  0 | task 1859 | saved context checkpoint 1 of 32 (pos_min = 4083, pos_max = 4083, size = 16.626 MiB)
slot      release: id  0 | task 1859 | stop processing: n_past = 4121, truncated = 0
slot print_timing: id  0 | task 1859 | 
prompt eval time =    9546.70 ms /  4084 tokens (    2.34 ms per token,   427.79 tokens per second)
       eval time =    4865.95 ms /    38 tokens (  128.05 ms per token,     7.81 tokens per second)
      total time =   14412.65 ms /  4122 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 192.168.68.69 200

Or maybe this is unavoidable in some cases; I'm not sure.

@ddh0 (PR author) commented Oct 2, 2025

FYI, I am using this Q8_0 quant of Jamba Mini to test this PR.

@ggerganov (Member):

> saved context checkpoint 1 of 32

I don't think there were any previous checkpoints in this case. But it's unclear since we don't see the preceding logs.

@ddh0 (PR author) commented Oct 2, 2025

While testing with a multi-turn conversation, the model seems to get increasingly confused as time goes on. I don't know what could be causing it, but it sure seems like something is broken. If you'd like, I can mark this as a draft.

Here are the full console logs from that conversation: full_jamba_mini_console_logs.txt

@ddh0 (PR author) commented Oct 2, 2025

> saved context checkpoint 1 of 32
>
> I don't think there were any previous checkpoints in this case. But it's unclear since we don't see the preceding logs.

In that case I believe there was, but as soon as "forcing full prompt re-processing due to lack of cache data" is triggered, it invalidates all the checkpoints anyway.

@ggerganov (Member):

I think the checkpointing logic is good, but maybe there is an issue with saving/restoring the recurrent state. Try to adapt/run llama-save-load-state and see if it runs correctly with the Mamba/Jamba architectures, in order to confirm that state save/load works correctly.
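
For anyone following along, the kind of round trip such a test exercises looks roughly like this (a sketch only; the actual llama-save-load-state example does more, e.g. comparing sampled tokens before and after the restore):

#include "llama.h"
#include <cstdio>
#include <vector>

// ctx is a llama_context * that has already decoded some tokens on seq 0
static bool roundtrip_seq_state(llama_context * ctx) {
    std::vector<uint8_t> state(llama_state_seq_get_size(ctx, 0));
    llama_state_seq_get_data(ctx, state.data(), state.size(), 0);

    // ... decode a few more tokens here, then roll back ...

    const size_t n_read = llama_state_seq_set_data(ctx, state.data(), state.size(), 0);
    if (n_read == 0) {
        fprintf(stderr, "failed to restore sequence state\n");
        return false;
    }

    // if save/load of the recurrent state is correct, re-decoding the same
    // tokens after the restore should reproduce identical logits
    return true;
}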

@pwilkin (Collaborator) commented Oct 2, 2025

I guess fixes #15677 ? :>

@ggerganov (Member):

@ddh0 Try to apply this patch on top of this PR:

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index db1f6d1aa..0edffa22f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3552,7 +3552,7 @@ struct server_context {
 
                                 const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
-                                if (pos_min > pos_min_thold) {
+                                if (pos_min > pos_min_thold + 1) {
                                     SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
                                     // search for a context checkpoint

I think this works. It is not perfectly optimal in the sense that it will always reprocess the last response. But at least it will keep all the conversation up to the penultimate response.

@ddh0 (PR author) commented Oct 2, 2025

> I think this works. It is not perfectly optimal in the sense that it will always reprocess the last response. But at least it will keep all the conversation up to the penultimate response.

Thank you, I've applied this change now.

> I guess fixes #15677 ? :>

Thanks, I've added it as well as #14625 to the main post.

@ggerganov (Member):

@ddh0 Can you confirm the quality is good now?

@ddh0 (PR author) commented Oct 2, 2025

> Can you confirm the quality is good now?

The quality is fine until I re-generate from the same point in the context (by pressing the 🔄 button in llama-server). Very often (but not every single time), when I re-generate, it's like the model gets confused and gives an irrelevant response.

Example:

User: What is the square root of 81?
AI: The square root of 81 is 9.
User: Please repeat that?
AI (initial response): The square root of 81 is 9.
AI (🔄1): Hello there, I'm your personal AI assistant. How can I assist you?
AI (🔄2): The square root of 81 is 9.

To find the square root of a number, you can follow these steps:

  1. Look at the number and see if it's a perfect square. For 81, you can check if it's 9 times 9.
  2. If it's not a perfect square, you can use a calculator to find the square root.
  3. Alternatively, you can use a method like long division or factorization to find the square root.

In this case, 81 is a perfect square, so the square root is 9.

AI (🔄3): The square root of 81 is 9.

There must be a bug somewhere still. Let me look over the code again and see if I can find anything suspicious.

(This is as of commit 126e08a)

@ggerganov (Member):

Thanks, I think I see the issue and have a fix. But there is a bug in the WebUI which is making this too difficult to debug: #16385

Let's come back to this after it is resolved.

@allozaur (Collaborator) commented Oct 3, 2025

> Thanks, I think I see the issue and have a fix. But there is a bug in the WebUI which is making this too difficult to debug: #16385
>
> Let's come back to this after it is resolved.

@ggerganov I've pushed #16402, which fixes #16385.

@ggerganov force-pushed the mamba-checkpoints-3 branch from 2e1b88f to 829c701 on October 3, 2025 at 08:15
@ggerganov (Member):

@ddh0 I just force pushed a patch to your branch that I think should work correctly. Let me know if you spot any problems.

@ddh0 (PR author) commented Oct 3, 2025

Here are the console logs for a simple chat with Jamba Mini, only 3 messages, without any re-generations or editing.

As you mentioned, it seems like it always needs to re-process the last response. And for the very first message, it can still trigger "forcing full prompt re-processing due to lack of cache data", but it looks like the problem goes away for all subsequent messages in the chat.

I will do some more testing between this and master to see if I can find any difference, as well as testing with longer conversations. But this seems to be working! 🥳

@ggerganov (Member):

@ddh0 Thanks for implementing and testing.

I'll push a few more changes to this branch if that's ok and merge it.

@ddh0 (PR author) commented Oct 3, 2025

> Thanks for implementing and testing.

Of course :) Please let me know if you'd like me to do any more testing before it's merged in.

Successfully merging this pull request may close these issues:

Eval bug: Nemotron v2 Nano always reprocesses prompt