feat: implemented sampling for MTP #1
Conversation
Hi, thanks for this! I think this is similar to, though probably a cleaner version of, what I had before I changed the MTP sampler to a simple argmax. I'll tell you what I think; let me know to what extent this does or does not agree with your understanding:
|
Hi @F1LM1, thanks for the feedback! My bad for not noticing your earlier commits with a proper sampling implementation; based on the discussion, I thought it was still on your to-do list. After digging into the code, I think I have a clearer picture.

Regarding `common_sampler_accept`: I see my mistake. In standard speculative mode with a separate draft model (`ctx_dft`), the accept call inside the draft function maintains the state of the draft context for sequential token generation. But for MTP, since we share a single context and sampler, calling it prematurely would indeed pollute the main sampler's state before verification.

Regarding the "Modify Logits + Greedy" strategy: I'm already drafting an idea for that, which involves applying the sampler chain to the MTP logits and then taking the first candidate. This is my concept:

```cpp
// In common/speculative.cpp
llama_token mtp_speculative_gen_draft(
        struct common_sampler * smpl,
        struct llama_context  * ctx,
        llama_token             id_last,
        int32_t                 n_past,
        int32_t                 last_tok_idx) {
    if (!smpl) {
        return -1;
    }

    // run the MTP head on the last accepted token
    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id_last, n_past, {0}, true);
    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_vocab = llama_n_vocab(vocab);

    // fill the sampler's candidate array with the full MTP logits
    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
    const float * logits = llama_get_logits_ith(ctx, last_tok_idx);

    cur_p->size = n_vocab;
    for (int i = 0; i < n_vocab; ++i) {
        cur_p->data[i].id    = i;
        cur_p->data[i].logit = logits[i];
    }
    cur_p->sorted = false;

    // apply the sampler chain (penalties, DRY, temperature, ...) without
    // accepting the token, then take the top candidate as the draft
    common_sampler_apply_chain(smpl, cur_p);
    const llama_token id = cur_p->data[0].id;
    ...
```

```cpp
// In common/sampling.cpp
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
    llama_sampler_apply(gsmpl->chain, cur_p);
}
```

I'll test this out a bit, but in the meantime I'm open to feedback. I also have another question: I'm looking at how to implement the multi-token (multi-head) drafting case.
|
This seems reasonable to me; with luck it will show up as better acceptance rates when rep penalties are turned on :)
Haven't really started thinking about this. When I have some free time I plan to focus on seeing if we can do some basic optimizations like graph reuse and such. You're definitely welcome to work on this!
Frankly I don't know exactly what the multi-head case for MTP would look like, but my impression is that you cannot MTP-draft N tokens simply by autoregressively predicting with a single MTP head the way you can with a typical draft model. Rather, I believe the number of MTP heads is a fixed feature of the model/weights, so that if you wanted to draft say N = 5 tokens at once, the model would have to have at least N = 5 MTP layers/heads that would all produce outputs in a single forward pass of the full model (including MTP). I would've guessed that each MTP layer takes as input the previous layer's output embedding and its sampled token (concatenated into an input embedding the way we do for the single MTP head here), rather than having to run a single head autoregressively N times. But if you find material to the contrary, I would absolutely love to see it.
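To make the data flow I'm imagining concrete, here's a rough, purely hypothetical sketch (none of these helper names exist; `mtp_head_forward` and `sample_greedy` are stand-ins) of N heads chaining within one pass:

```cpp
// Conceptual sketch only -- not llama.cpp API. Each of the N MTP heads consumes
// the previous head's output embedding plus its sampled token, all within what
// would be a single forward pass of the full model.
#include <cstdio>
#include <vector>

struct embedding { std::vector<float> v; };

struct mtp_head_out { embedding hidden; std::vector<float> logits; };

// hypothetical per-head forward pass (stubbed out)
static mtp_head_out mtp_head_forward(int head_idx, const embedding & prev_hidden, int prev_tok) {
    (void) head_idx; (void) prev_tok;
    return { prev_hidden, std::vector<float>(32000, 0.0f) };
}

// greedy pick over the head's logits
static int sample_greedy(const std::vector<float> & logits) {
    int best = 0;
    for (size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) best = (int) i;
    }
    return best;
}

int main() {
    const int n_heads = 5; // to draft N = 5 tokens, the model would need >= 5 MTP heads
    embedding hidden   = { std::vector<float>(4096, 0.0f) }; // main model's last hidden state
    int       last_tok = 1;                                  // last accepted token

    std::vector<int> draft;
    for (int h = 0; h < n_heads; ++h) {
        mtp_head_out out = mtp_head_forward(h, hidden, last_tok);
        last_tok = sample_greedy(out.logits);
        hidden   = out.hidden;
        draft.push_back(last_tok);
    }
    printf("drafted %zu tokens in one pass\n", draft.size());
    return 0;
}
```

Again, that's just my mental model of it, not how any implementation actually works.
|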
Hey! I tried with 52 requests, each with between 6k and 38k tokens of context, and got an acceptance rate of ~0.5931 +/- 0.041, compared with the ~0.51 I reported before. This was with the same settings (temp=1.0, DRY enabled) for writing. The latest commit includes these changes.
I was unable to find proper documentation or even discussion, only suggestions to look at SGLang and vLLM, so I looked at how vLLM implements it. Their approach is as follows:

```python
self.num_mtp_layers = config.num_nextn_predict_layers
self.layers = torch.nn.ModuleDict({
    str(idx): Glm4MoeMultiTokenPredictorLayer(...)
    for idx in range(self.mtp_start_layer_idx,
                     self.mtp_start_layer_idx + self.num_mtp_layers)
})
...
def forward(..., spec_step_idx: int = 0):
    ...
    current_step_idx = (spec_step_idx % self.num_mtp_layers)
    return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](...)
```

So if we pass, for example, a `spec_step_idx` greater than zero on a model with a single MTP layer, the modulo wraps back to that same layer; with more than one MTP layer it cycles through them. My proposed plan now shifts to two steps:
|
Great, I've been away the last couple of days, but I'll give this a spin as well; sounds promising!
If I'm reading this correctly, it looks like if num_mtp_layers = 1 then it will run the one MTP layer autoregressively, but if num_mtp_layers = 2, for example, then it will alternate between the layers? That seems... odd, but I agree it can't hurt to match their implementation until we have an example of a model with num_mtp_layers > 1 to see if it works. Hopefully we'll see decent draft acceptance at least for the "easy" cases (coding), and even if not, it's easy enough to just recommend choosing the N that ends up working best.
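Just to spell out my reading of that modulo (a toy illustration only; the 46 here is a made-up layer index):

```cpp
// Toy illustration of vLLM's selection rule above:
//   layer = mtp_start_layer_idx + (spec_step_idx % num_mtp_layers)
#include <cstdio>

int main() {
    const int mtp_start_layer_idx = 46; // hypothetical index of the first MTP layer
    const int configs[] = {1, 2};

    for (int c = 0; c < 2; ++c) {
        const int num_mtp_layers = configs[c];
        printf("num_mtp_layers = %d:", num_mtp_layers);
        for (int spec_step_idx = 0; spec_step_idx < 4; ++spec_step_idx) {
            printf("  step %d -> layer %d", spec_step_idx,
                   mtp_start_layer_idx + (spec_step_idx % num_mtp_layers));
        }
        printf("\n");
    }
    return 0;
}
```

So with one MTP layer it's just the same head reused every step, and with two it ping-pongs between them.
|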
Yes, the alternating-layer logic seems odd; I felt the same way, especially since we've only seen models with a single MTP head, and the earlier layers don't have the MTP weights anyway. For reference, this is how vLLM's Eagle proposer generates the remaining draft tokens:

```python
# in vllm/spec_decode/eagle.py
class EagleProposer:
    ...
    def propose(self, ...):
        ...
        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        for _ in range(self.num_speculative_tokens - 1):
            # The input for this iteration is the token generated in the previous one.
            input_ids = draft_token_ids_list[-1].int()
            # Runs the model for a single step
            last_hidden_states, hidden_states = self.model(...)
            # Calculates logits and samples the next token (with argmax)
            logits = self.model.compute_logits(last_hidden_states[:batch_size], None)
            draft_token_ids = logits.argmax(dim=-1)
            # Appends the new token for the next iteration
            draft_token_ids_list.append(draft_token_ids)
```

This is essentially what our `mtp_speculative_gen_draft` would be doing if we looped it autoregressively for more than one draft token.
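As a sketch of what that could look like on our side, reusing the `mtp_speculative_gen_draft` signature from earlier in this thread (the `mtp_draft_n` wrapper and `n_draft` parameter are just placeholders, and the KV-cache bookkeeping is deliberately hand-waved):

```cpp
// Sketch only: autoregressively reuse the single MTP head to draft n_draft tokens,
// feeding each drafted token back in as id_last. Assumes the llama.cpp / common
// headers and the mtp_speculative_gen_draft shown earlier; verification against
// the main model and proper KV-cache handling are omitted.
#include <vector>

std::vector<llama_token> mtp_draft_n(
        struct common_sampler * smpl,
        struct llama_context  * ctx,
        llama_token             id_last,
        int32_t                 n_past,
        int32_t                 last_tok_idx,
        int32_t                 n_draft) {
    std::vector<llama_token> draft;
    for (int32_t i = 0; i < n_draft; ++i) {
        const llama_token id = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);
        if (id < 0) {
            break; // no sampler available or drafting failed
        }
        draft.push_back(id);
        id_last = id;  // the drafted token becomes the next input
        n_past += 1;   // and occupies the next position
    }
    return draft;
}
```

Whether the existing single-layer KV cache slot can support positions past the first drafted token is exactly the open question, of course.
|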
As I commented on the other PR, I suspect that supporting KV cache for multi-token MTP drafts is going to be a significant step up in complexity, while for the one-token case we can piggyback on the existing KV cache system (since it already sets aside cache for the single MTP layer). I'll get a chance tomorrow to spin up this PR, but I think this should represent the optimal sampling subroutine. If you're eager to finish it off, maybe start thinking about how we can make the setup more efficient by reusing stuff where possible (memory ctx? graphs? >1-size batches?), since we're basically recreating a bunch of stuff from scratch for every token. |
Okay, I'll take a look at what you suggested and find ways to store the state for the context and graph. Regarding the batch size part, if I understand correctly, you're referring to fixing the alternation between draft and main model tokens in the server's main loop. I agree that would be a great optimization, but it seems like it would take a while, change a lot of the server logic, and require extensive testing. I feel like that would be a good follow-up PR. It's more of a general feature to improve not only MTP but drafting in general, and giving it a separate PR would allow us to merge the core MTP implementation first. |
Nah, I meant some form of batching when we do the MTP-layer prompt-processing step, since we're likely going to process hundreds of tokens at once using the same graphs/memory context/etc. Right now we're building all size-1 batches, which just feels wrong. AFAIK, the alternation thing might be deceptively easy to fix. I suspect it could be as simple as just making sure we only do the non-speculative llama_decode step exactly once, i.e. immediately after prompt processing. I'll need to find my notes on this, but I'm pretty sure everything else is always correctly synced, at least for the MTP case.
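Roughly what I have in mind for the prompt-processing side, as a sketch only (the `mtp_process_prompt_batch` name is made up, and a batched variant of `llama_build_and_execute_mtp_graph` is an assumption, so it's left as a commented-out call):

```cpp
// Sketch: feed a whole prompt chunk to the MTP layer as one batch instead of
// building a size-1 batch per token. Uses the llama_batch helpers that already
// appear in this PR; assumes the usual llama.cpp / common headers.
#include <vector>

static void mtp_process_prompt_batch(
        struct llama_context           * ctx,
        const std::vector<llama_token> & prompt_tokens,
        int32_t                          n_past_start) {
    const int32_t n_tokens = (int32_t) prompt_tokens.size();

    llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    for (int32_t i = 0; i < n_tokens; ++i) {
        // only the last token of the chunk needs logits for drafting
        const bool need_logits = (i == n_tokens - 1);
        common_batch_add(batch, prompt_tokens[i], n_past_start + i, {0}, need_logits);
    }

    // one graph build/execute for the whole chunk instead of one per token;
    // a batched variant of the helper used elsewhere in this PR is assumed here:
    // llama_build_and_execute_mtp_graph(ctx, batch, /*...*/);
    (void) ctx;

    llama_batch_free(batch);
}
```
|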
Ah, thanks for clarifying the batching part! I was looking at the main generation loop, but batching the MTP prompt processing makes sense. I'll keep that on my list of things to look into.
That's a great point about the alternation fix potentially being simple; I'm looking forward to your notes on that when you find them. In the meantime, I've been focused on your first suggestion: reusing resources during single-token draft generation so we stop recreating everything from scratch.
My latest attempt was to build the MTP graph once and reuse it across draft calls, but it breaks down after the first token: by the next call the scheduler and the memory context it allocated against are gone, so the cached graph no longer has valid allocations. It feels like a chicken-and-egg problem: to properly reuse the graph, we'd need to persist the scheduler, but to persist the scheduler, we'd need to persist the memory context, which is the core of the KV cache problem.
I'm probably missing something due to my limited knowledge of the llama.cpp internals. Any pointers would be a huge help. |
I finally got a chance to test the improved sampler. It works well in my testing, raising the draft acceptance rate in some "hard" writing scenarios by more than 10 percentage points on average, which is a clear and large gain. Ironically, it ends up being slower in actual tok/s generation, presumably because the sampling chain as-is is inefficient, but let's see what we can do about that in follow-ups. Re: the graph reuse questions you mentioned above, I'll fire up the project again this weekend and see what I find. It's been a while since I dove in. |
Hi @F1LM1, I've been following your PR and decided to tackle one of the to-dos you mentioned: implementing proper sampling for the MTP draft model.
I've successfully implemented a solution that retrieves the full logits from the MTP and passes them to the sampler for the draft token generation. The code seems stable and is ready for your review.
Here are the key results from my testing:
Interestingly, when compiled in release mode, I achieved an average acceptance rate of 0.51 for creative tasks, up from the ~0.4 you mentioned in your PR.
I tried to preserve your existing code, and I'm open to any suggestions you have for improvement.