implement context checkpointing for hybrid and recurrent models #16382

Conversation
This extends `llama-server`'s SWA checkpointing logic to include hybrid/recurrent models such as Jamba and Granite.
include/llama.h (outdated)

     size_t * n_token_count_out);

-#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+#define LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY 1
We need a bit better name than this. The old name does not work, but the proposed new name is confusing.
The purpose of this flag is to indicate that we want to save only the "small" caches such as SWA, "recr", etc. But I can't think of a good name to call it.
I see what you mean. I can't think of anything better at the moment
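For context, a minimal sketch of how a caller might use the renamed flag through the per-sequence state API (assuming the `_ext` variants of the state functions introduced alongside SWA checkpointing; the exact signatures are not confirmed here and may differ):

```cpp
#include "llama.h"

#include <cstdint>
#include <vector>

// Sketch: save and restore only the checkpoint-able ("small") caches of one
// sequence -- the SWA / recurrent states -- rather than the full KV cache.
static std::vector<uint8_t> save_ctx_checkpoint(llama_context * ctx, llama_seq_id seq_id) {
    const uint32_t flags = LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY;

    // query the size of the reduced per-sequence state
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, flags);

    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, flags);

    return buf;
}

static void load_ctx_checkpoint(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
    // roll the sequence back to the state captured by the checkpoint
    llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
}
```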
tools/server/server.cpp (outdated)

    // make a checkpoint of the parts of memory that cannot be rolled back.
    // checkpoints are needed only if:
    //  - the model uses SWA and we are not using `swa_full`
    //  - the model architecture is marked as recurrent or hybrid
    bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
                         (llama_model_n_swa(model) > 0 && !params_base.swa_full);
I'm a bit torn on this logic for determining when to do checkpoints. It should be centred around the memory module or the context, rather than the model.
Just making a note for the future - no need to change anything in this PR.
Co-authored-by: Georgi Gerganov <[email protected]>
Hmm, maybe I should mark this as draft, because sometimes it works and sometimes I still see this:

srv log_server_r: request: POST /v1/chat/completions 192.168.68.69 200
srv params_from_: Chat format: Content-only
slot get_availabl: id 0 | task 595 | selected slot by lcs similarity, lcs_len = 872, similarity = 0.395 (> 0.100 thold)
slot launch_slot_: id 0 | task 1859 | processing task
slot update_slots: id 0 | task 1859 | new prompt, n_ctx_slot = 262144, n_keep = 0, n_prompt_tokens = 4084
slot update_slots: id 0 | task 1859 | n_past = 872, cache_tokens.size() = 2205, seq_id = 0, pos_min = 2204, n_swa = 0
slot update_slots: id 0 | task 1859 | forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
slot update_slots: id 0 | task 1859 | n_past = 0
slot update_slots: id 0 | task 1859 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.501469
slot update_slots: id 0 | task 1859 | n_past = 2048
slot update_slots: id 0 | task 1859 | prompt processing progress, n_past = 4084, n_tokens = 2036, progress = 1.000000
slot update_slots: id 0 | task 1859 | prompt done, n_past = 4084, n_tokens = 2036
slot update_slots: id 0 | task 1859 | saved context checkpoint 1 of 32 (pos_min = 4083, pos_max = 4083, size = 16.626 MiB)
slot release: id 0 | task 1859 | stop processing: n_past = 4121, truncated = 0
slot print_timing: id 0 | task 1859 |
prompt eval time = 9546.70 ms / 4084 tokens ( 2.34 ms per token, 427.79 tokens per second)
eval time = 4865.95 ms / 38 tokens ( 128.05 ms per token, 7.81 tokens per second)
total time = 14412.65 ms / 4122 tokens
srv update_slots: all slots are idle
srv log_server_r: request: POST /v1/chat/completions 192.168.68.69 200
Or maybe this is unavoidable in some cases, I'm not sure.
FYI I am using this Q8_0 quant of Jamba Mini to test this PR.
I don't think there were any previous checkpoints in this case. But it's unclear since we don't see the preceding logs.
While testing with a multi-turn conversation, the model seems to get increasingly confused as time goes on. I don't know what could be causing it, but it sure seems like something is broken. If you'd like I can mark this as draft. Here are the full console logs from that conversation: full_jamba_mini_console_logs.txt
In that case I believe there was, but as soon as …
I think the checkpointing logic is good, but maybe there is an issue with saving/restoring the recurrent state. Try to adapt/run the …
I guess this fixes #15677? :>
@ddh0 Try to apply this patch on top of this PR:

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index db1f6d1aa..0edffa22f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3552,7 +3552,7 @@ struct server_context {
     const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-    if (pos_min > pos_min_thold) {
+    if (pos_min > pos_min_thold + 1) {
         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
         // search for a context checkpoint

I think this works. It is not perfectly optimal in the sense that it will always reprocess the last response. But at least it will keep all the conversation up to the penultimate response.
Co-authored-by: Georgi Gerganov <[email protected]>
Thank you, I've applied this change now.
Thanks, I've added this PR as well as #14625 to the main post.
@ddh0 Can you confirm the quality is good now?
The quality is fine until I re-generate from the same point in the context (by pressing the 🔄 button in llama-server). Very often (but not every single time) when I re-generate it's like the model gets confused and gives an irrelevant response. Example:
There must be a bug somewhere still. Let me look over the code again and see if I can find anything suspicious. (This is as of commit 126e08a)
Thanks, I think I see the issue and have a fix. But there is a bug in the WebUI which is making this too difficult to debug: #16385. Let's come back to this after it is resolved.
@ggerganov I've pushed #16402 which fixes #16385.
Force-pushed from 2e1b88f to 829c701
@ddh0 I just force-pushed a patch to your branch that I think should work correctly. Let me know if you spot any problems.
Here are the console logs for a simple chat with Jamba Mini, only 3 messages, without any re-generations or editing. As you mentioned, it seems like it always needs to re-process the last response. And for the very first message, it can still trigger …. I will do some more testing between this and ….
@ddh0 Thanks for implementing and testing. I'll push a few more changes to this branch if that's ok and merge it.
Of course :) please let me know if you'd like me to do any more testing before it's merged in.
This PR generalizes the SWA checkpointing logic (ref #15293) to also create checkpoints for recurrent and hybrid models such as Mamba, Jamba, etc.

- The `--swa-checkpoints` CLI arg is renamed to `--ctx-checkpoints`
- The `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY` flag is renamed to `LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY`
- `llama_model_is_hybrid` is added to `llama-model.cpp` and `llama.h`

This removes the need to re-process the entire context in the majority of cases.

Would resolve #15677 and #14625
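For illustration, a minimal sketch of how downstream code might decide whether context checkpoints are worthwhile, mirroring the server heuristic shown earlier in this thread (`needs_ctx_checkpoints` and the `swa_full` parameter are hypothetical names used here, not part of the library API):

```cpp
#include "llama.h"

// Sketch: checkpoints help when parts of the memory cannot be rolled back,
// i.e. for recurrent/hybrid models, and for SWA models not keeping the full window.
static bool needs_ctx_checkpoints(const llama_model * model, bool swa_full) {
    // recurrent and hybrid architectures have state that cannot be rewound
    if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
        return true;
    }

    // SWA models discard old KV entries unless the full window is retained
    return llama_model_n_swa(model) > 0 && !swa_full;
}
```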