Conversation

SamuelOliveirads

Hi @F1LM1, I've been following your PR and decided to tackle one of the to-dos you mentioned: implementing proper sampling for the MTP draft model.

I've successfully implemented a solution that retrieves the full logits from the MTP and passes them to the sampler for the draft token generation. The code seems stable and is ready for your review.

Here are the key results from my testing:

  • Greedy Decoding (Temp 0.8): My implementation performed identically to the original PR, with an acceptance rate of 0.756.
  • Creative Sampling (Temp 2.0): The performance in my debug tests was initially worse for creativity (acceptance rate of 0.566 vs. your original 0.670). I believe this is because allowing the MTP more freedom to choose a token via sampling can increase the divergence from the main model, leading to lower acceptance.

Interestingly, when compiled in release mode, I achieved an average acceptance rate of 0.51 for creative tasks, whereas you mentioned in your PR that it was around 0.4.

I tried to preserve your code and I'm open to suggestions you have for improvement.

@F1LM1
Owner

F1LM1 commented Sep 4, 2025

Hi, thanks for this! I think this is similar to, though probably a cleaner version than, what I had before I changed the MTP sampler to a simple argmax. I'll tell you what I think, let me know to what extent this does/does not agree with your understanding:

  • We probably do not want to call common_sampler_accept(smpl, ...) on the draft token, my understanding is that this causes us to modify our sampler -- which is the same object as the base model's sampler! -- as if the token was actually sampled, but we don't actually know if the draft token will be accepted yet. Plus, if the draft token actually gets accepted, then we'll call common_sampler_accept(smpl, ...) on it again and double-count the token
  • The optimal sampling strat on the MTP logits is to use samplers that transform the input logits (act pre-softmax), but then to do a greedy sampling after that. Let's say the MTP head was "perfect," i.e. on step N it returns the exact logits that the main model would return on step N+1. A full stochastic sampling is just inserting noise, by definition greedy sampling on the correct post-transformation logits gets you the single token that has the highest probability of being sampled by the main model on step N+1. This is why I switched to using a greedy sampling in my code. But sampling the pre-transform (from things like repetition, presence, DRY, etc.) is also incorrect. I haven't really looked into the sampler code so I don't know if there's a way to do this... maybe a hacky method would be to set the sampler temp to 0, then call common_sampler_sample, then set the temp back. We probably do want to use the same sampler object as the main model so that we carry over the correct states for stuff like repetition, presence, DRY.
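
Something like this, conceptually (just a sketch, not wired into the actual sampler code): apply whatever logit transforms you want to the candidate array first, then take the argmax instead of drawing a sample.

// Sketch only: given a candidate array whose logits have already been
// transformed (repetition/presence/DRY, temperature, ...), return the single
// best token instead of sampling stochastically.
#include <cmath>
#include "llama.h"

static llama_token greedy_pick(const llama_token_data_array * cur_p) {
    llama_token best_id    = -1;
    float       best_logit = -INFINITY;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].logit > best_logit) {
            best_logit = cur_p->data[i].logit;
            best_id    = cur_p->data[i].id;
        }
    }
    return best_id;
}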

@SamuelOliveirads
Author

Hi, thanks for this! I think this is similar to, though probably a cleaner version than, what I had before I changed the MTP sampler to a simple argmax. I'll tell you what I think, let me know to what extent this does/does not agree with your understanding:

  • We probably do not want to call common_sampler_accept(smpl, ...) on the draft token, my understanding is that this causes us to modify our sampler -- which is the same object as the base model's sampler! -- as if the token was actually sampled, but we don't actually know if the draft token will be accepted yet. Plus, if the draft token actually gets accepted, then we'll call common_sampler_accept(smpl, ...) on it again and double-count the token
  • The optimal sampling strat on the MTP logits is to use samplers that transform the input logits (act pre-softmax), but then to do a greedy sampling after that. Let's say the MTP head was "perfect," i.e. on step N it returns the exact logits that the main model would return on step N+1. A full stochastic sampling is just inserting noise, by definition greedy sampling on the correct post-transformation logits gets you the single token that has the highest probability of being sampled by the main model on step N+1. This is why I switched to using a greedy sampling in my code. But sampling the pre-transform (from things like repetition, presence, DRY, etc.) is also incorrect. I haven't really looked into the sampler code so I don't know if there's a way to do this... maybe a hacky method would be to set the sampler temp to 0, then call common_sampler_sample, then set the temp back. We probably do want to use the same sampler object as the main model so that we carry over the correct states for stuff like repetition, presence, DRY.

Hi @F1LM1, thanks for the feedback! My bad for missing your earlier commits with a proper sampling implementation; based on the discussion, I thought this was still on your to-do list.

After digging into the code, I think I have a clearer picture.

Regarding common_sampler_accept: I see my mistake. In the standard speculative mode with a separate draft model (ctx_dft), the accept call inside the draft function is used to maintain the state of the draft context for sequential token generation. But for MTP, since we share a single context and sampler, calling it prematurely would indeed pollute the main sampler's state before verification.
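
Just to check my understanding of the flow, schematically (this is not the PR's code; the index handling is simplified and the helper name is mine): the shared sampler should only see common_sampler_accept for tokens the main model actually keeps during verification, which is why the draft path must not call it.

// Schematic only -- illustrates when common_sampler_accept should run.
#include <vector>
#include "sampling.h"

static size_t verify_draft(struct common_sampler * smpl,
                           struct llama_context  * ctx,
                           const std::vector<llama_token> & draft) {
    size_t n_accepted = 0;
    for (size_t i = 0; i < draft.size(); ++i) {
        // sample from the main model's logits at draft position i
        const llama_token tok = common_sampler_sample(smpl, ctx, (int) i);
        // the sampler state is updated here, exactly once per kept token
        common_sampler_accept(smpl, tok, /* accept_grammar */ true);
        if (tok != draft[i]) {
            break; // first mismatch: the rest of the draft is thrown away
        }
        n_accepted++;
    }
    return n_accepted;
}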

Regarding the "Modify Logits + Greedy" Strategy: I'm already drafting an idea for that, which involves accessing the sample method to modify the logits and then getting the first candidate.

This is my concept:

// In common/speculative.cpp
llama_token mtp_speculative_gen_draft(
    struct common_sampler* smpl,
    struct llama_context* ctx,
    llama_token id_last,
    int32_t n_past,
    int32_t last_tok_idx) {

    if (!smpl) {
        return -1;
    }

    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id_last, n_past, {0}, true);

    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_vocab = llama_n_vocab(vocab);

    // re-populate the shared sampler's candidate array with the raw MTP logits
    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
    const float * logits = llama_get_logits_ith(ctx, last_tok_idx);

    cur_p->size = n_vocab;
    for (int i = 0; i < n_vocab; ++i) {
        cur_p->data[i].id    = i;
        cur_p->data[i].logit = logits[i];
    }
    cur_p->sorted = false;

    // apply only the logit-transforming samplers (rep/presence/DRY, temp, ...)
    // without drawing a token and without touching the sampler's accept state
    common_sampler_apply_chain(smpl, cur_p);

    // greedy pick: assumes the chain leaves the candidates sorted by logit,
    // so data[0] is the highest-probability token after the transforms
    const llama_token id = cur_p->data[0].id;

...
// In common/sampling.cpp
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
    llama_sampler_apply(gsmpl->chain, cur_p);
}

I'll test this out a bit, but in the meantime, I'm open to feedback.

I also have another question: I'm looking at how to implement the n+2 tokens for MTP and I don't know how far you are on that. If you'd like, I can also try some concepts. In my mind, this is my plan:

  • a new function, probably in llama-context.cpp, called llama_mtp_kv_cache_update_token to run the MTP graph and generate only one token to ensure that the KV Cache is updated.
  • Refactor mtp_update_kv_cache to call llama_mtp_kv_cache_update_token.
  • Refactor mtp_speculative_gen_draft to loop the function llama_build_and_execute_mtp_graph, apply the sampling, and get the temp 0 draft. The concept of the loop could be similar to common_speculative_gen_draft.

@F1LM1
Owner

F1LM1 commented Sep 5, 2025

Regarding the "Modify Logits + Greedy" Strategy: I'm already drafting an idea for that, which involves accessing the sample method to modify the logits and then getting the first candidate.
...
I'll test this out a bit, but in the meantime, I'm open to feedback.

This seems reasonable to me, with luck it will show up as better acceptance rates when rep penalties are turned on :)

I also have another question: I'm looking at how to implement the n+2 tokens for MTP and I don't know how far you are on that. If you'd like, I can also try some concepts. In my mind, this is my plan:

Haven't really started thinking about this. When I have some free time I plan to focus on seeing if we can do some basic optimizations like graph reuse and such. You're definitely welcome to work on this!

  • a new function, probably in llama-context.cpp, called llama_mtp_kv_cache_update_token to run the MTP graph and generate only one token to ensure that the KV Cache is updated.
  • Refactor mtp_update_kv_cache to call llama_mtp_kv_cache_update_token.
  • Refactor mtp_speculative_gen_draft to loop the function llama_build_and_execute_mtp_graph, apply the sampling, and get the temp 0 draft. The concept of the loop could be similar to common_speculative_gen_draft.

Frankly I don't know exactly what the multi-head case for MTP would look like, but my impression is that you cannot MTP-draft N tokens simply by autoregressively predicting with a single MTP head the way you can for a typical draft model. Rather, I believe that the number of MTP heads is a fixed feature of the model/weights, so that if you wanted to draft, say, N = 5 tokens at once, the model would have to have at least N = 5 MTP layers/heads that would all produce outputs in a single forward pass of the full model (including MTP). I would've guessed that each MTP layer takes as input the previous layer's output embedding and its sampled token (to concatenate as an input embedding the way we do for the single MTP head here), and rather than having to run llama_build_and_execute_mtp_graph in a loop you should only need to run it once and just collect the N outputs.

But if you find material to the contrary, I would absolutely love to see it.

@SamuelOliveirads
Author

This seems reasonable to me, with luck it will show up as better acceptance rates when rep penalties are turned on :)

Hey! I tried 52 requests, with between 6k and 38k tokens of context, and got an acceptance rate of ~0.5931 +/- 0.041, compared with the ~0.51 I reported before. This was with the same settings (temp=1.0, DRY enabled) for writing. The latest commit includes these changes.

Frankly I don't know exactly what the multi-head case for MTP would look like, but my impression is that you cannot MTP-draft N tokens simply by autoregressively predicting with a single MTP head the way you can for a typical draft model. Rather, I believe that the number of MTP heads is a fixed feature of the model/weights, so that if you wanted to draft, say, N = 5 tokens at once, the model would have to have at least N = 5 MTP layers/heads that would all produce outputs in a single forward pass of the full model (including MTP). I would've guessed that each MTP layer takes as input the previous layer's output embedding and its sampled token (to concatenate as an input embedding the way we do for the single MTP head here), and rather than having to run llama_build_and_execute_mtp_graph in a loop you should only need to run it once and just collect the N outputs.

But if you find material to the contrary, I would absolutely love to see it.

I was unable to find proper documentation or even discussion, only suggestions to look at SGLang and vLLM, so I looked at how vLLM implements it; their approach is as follows:

self.num_mtp_layers = config.num_nextn_predict_layers
self.layers = torch.nn.ModuleDict({
    str(idx): Glm4MoeMultiTokenPredictorLayer(...)
    for idx in range(self.mtp_start_layer_idx,
                     self.mtp_start_layer_idx + self.num_mtp_layers)
})

...

def forward(..., spec_step_idx: int = 0):
    ...
    current_step_idx = (spec_step_idx % self.num_mtp_layers)
    return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](...)

So if we pass, for example, num_nextn_predict_layers = 3, it will create 3 heads. To generate the draft, they use an autoregressive loop, but each step uses a different head for prediction. SGLang is similar: it has a Glm4MoeDecoderLayer with a boolean called is_nextn to choose whether the layer will be MTP or not. So yeah, you are right about just running llama_build_and_execute_mtp_graph once but using different heads.

My proposed plan now shifts to two steps:

  1. Modify build_mtp_graph to loop through the N MTP heads (model.layers[n_layer - num_mtp_layers + i]). It will build a single, larger computation graph that chains the output of MTP_head_i as the input for MTP_head_i+1.
  2. Now that the graph has all the logits outputs, llama_build_and_execute_mtp_graph will execute the graph only once and will copy the N logits to the main context's logit buffer.

mtp_speculative_gen_draft will continue to do the same as before: call the graph, collect the logits, and apply the logits + greedy sampling approach.

@F1LM1
Owner

F1LM1 commented Sep 8, 2025

This seems reasonable to me, with luck it will show up as better acceptance rates when rep penalties are turned on :)

Hey! I tried 52 requests, with between 6k and 38k tokens of context, and got an acceptance rate of ~0.5931 +/- 0.041, compared with the ~0.51 I reported before. This was with the same settings (temp=1.0, DRY enabled) for writing. The latest commit includes these changes.

Great, I've been away the last couple of days but I'll give this a spin as well, sounds promising!

I was unable to find proper documentation or even discussion, only suggestions to look at SGLang and vLLM, so I looked at how vLLM implements it; their approach is as follows:

self.num_mtp_layers = config.num_nextn_predict_layers
self.layers = torch.nn.ModuleDict({
    str(idx): Glm4MoeMultiTokenPredictorLayer(...)
    for idx in range(self.mtp_start_layer_idx,
                     self.mtp_start_layer_idx + self.num_mtp_layers)
})

...

def forward(..., spec_step_idx: int = 0):
    ...
    current_step_idx = (spec_step_idx % self.num_mtp_layers)
    return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](...)

So if we pass, for example, num_nextn_predict_layers = 3, it will create 3 heads. To generate the draft, they use an autoregressive loop, but each step uses a different head for prediction. SGLang is similar: it has a Glm4MoeDecoderLayer with a boolean called is_nextn to choose whether the layer will be MTP or not. So yeah, you are right about just running llama_build_and_execute_mtp_graph once but using different heads.

My proposed plan now shifts to two steps:

  1. Modify build_mtp_graph to loop through the N MTP heads (model.layers[n_layer - num_mtp_layers + i]). It will build a single, larger computation graph that chains the output of MTP_head_i as the input for MTP_head_i+1.
  2. Now that the graph has all the logits outputs, llama_build_and_execute_mtp_graph will execute the graph only once and will copy the N logits to the main context's logit buffer.

mtp_speculative_gen_draft will continue to do the same as before: call the graph, collect the logits, and apply the logits + greedy sampling approach.

If I'm reading this correctly it looks like if num_mtp_layers = 1 then it will run the one MTP layer autoregressively, but if num_mtp_layers = 2 for example then it will alternate between the layers? That seems... odd, but I agree it can't hurt to match their implementation until we have an example of a model with num_mtp_layers > 1 to see if it works. Hopefully we'll see decent draft acceptance at least for the "easy" cases (coding), and even if not, it's easy enough to just recommend choosing the N that ends up working best.

@SamuelOliveirads
Author

If I'm reading this correctly it looks like if num_mtp_layers = 1 then it will run the one MTP layer autoregressively, but if num_mtp_layers = 2 for example then it will alternate between the layers? That seems... odd, but I agree it can't hurt to match their implementation until we have an example of a model with num_mtp_layers > 1 to see if it works. Hopefully we'll see decent draft acceptance at least for the "easy" cases (coding), and even if not, it's easy enough to just recommend choosing the N that ends up working best.

Yes, the alternating layer logic seems odd. I felt the same way, especially since we've only seen models with a single MTP head, and the previous layers don't have the nextN weights.
That feeling made me dive deeper into the vLLM and SGLang implementations. It turns out both effectively use an Eagle-style "proposer" worker (which is compatible with DeepSeek-like models) to generate the draft tokens.
In my previous example, if a model had num_mtp_layers = 3, the spec_step_idx % 3 logic would indeed alternate between three different MTP layers. But since the models we have only have one MTP layer, it effectively just loops on the same head.
The key was finding the code that calls the forward function. The actual autoregressive generation doesn't happen inside the model graph, but in the worker that drives it. For example, in vLLM's Eagle proposer, you find this explicit loop:

# in vllm/spec_decode/eagle.py
class EagleProposer:
...
    def propose(self, ...):
        ...
        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        for _ in range(self.num_speculative_tokens - 1):
            # The input for this iteration is the token generated in the previous one.
            input_ids = draft_token_ids_list[-1].int()
            # Runs the model for a single step
            last_hidden_states, hidden_states = self.model(...)
            # Calculates logits and samples the next token (with argmax)
            logits = self.model.compute_logits(last_hidden_states[:batch_size], None)
            draft_token_ids = logits.argmax(dim=-1)
            # Appends the new token for the next iteration
            draft_token_ids_list.append(draft_token_ids)

This is essentially what our mtp_speculative_gen_draft does: it calls the graph executor (llama_build_and_execute_mtp_graph), then control returns to the CPU side to sample from the logits, and the loop feeds the result into the next iteration.
The bottleneck is the dependency on the sampled token: we have to exit the graph to make that decision on the CPU.
This discovery meant I had to pivot from my original plan of a single large graph. The new approach is to loop within mtp_speculative_gen_draft, creating and executing a small graph for each draft token. I've put together a draft PR to show how this works in practice. I'd really appreciate your thoughts and feedback when you have a moment.
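
In rough C++ terms, the loop looks something like this (sketch only: the function name is illustrative, output-index and KV-cache handling are glossed over, and it reuses the transform-then-greedy routine from the single-token path):

// In common/speculative.cpp -- sketch of the per-token draft loop (assumes the
// same headers as mtp_speculative_gen_draft, plus <vector>). Each iteration
// builds and runs a small MTP graph, then control returns to the CPU to pick
// the next input token.
static std::vector<llama_token> mtp_draft_loop(
        struct common_sampler * smpl,
        struct llama_context  * ctx,
        llama_token             id_last,
        int32_t                 n_past,
        int32_t                 last_tok_idx,
        int                     n_draft) {
    std::vector<llama_token> drafted;

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_vocab = llama_n_vocab(vocab);

    llama_token cur = id_last;
    for (int i = 0; i < n_draft; ++i) {
        llama_batch batch = llama_batch_init(1, 0, 1);
        common_batch_add(batch, cur, n_past + i, {0}, true);

        llama_build_and_execute_mtp_graph(ctx, batch, cur, n_past + i, last_tok_idx);

        // same transform-then-greedy routine as the single-token path
        llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
        const float * logits = llama_get_logits_ith(ctx, last_tok_idx);
        cur_p->size = n_vocab;
        for (int j = 0; j < n_vocab; ++j) {
            cur_p->data[j].id    = j;
            cur_p->data[j].logit = logits[j];
        }
        cur_p->sorted = false;
        common_sampler_apply_chain(smpl, cur_p);

        cur = cur_p->data[0].id;
        drafted.push_back(cur);

        llama_batch_free(batch);
    }
    return drafted;
}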

@F1LM1
Owner

F1LM1 commented Sep 10, 2025

As I commented on the other PR, I suspect that supporting KV cache for multi-token MTP drafts is going to be a significant step up in complexity, while the one token case we can piggyback on the existing KV cache system (since it sets aside cache for the single MTP layer already).

I'll get a chance tomorrow to spin up this PR but I think this should represent the optimal sampling subroutine. If you're eager to finish it off, maybe start thinking about how we can make the setup more efficient by reusing stuff where possible (memory ctx? graphs? >1 size batches?), since we're basically recreating a bunch of stuff from scratch for every token.

@SamuelOliveirads
Author

I'll get a chance tomorrow to spin up this PR but I think this should represent the optimal sampling subroutine. If you're eager to finish it off, maybe start thinking about how we can make the setup more efficient by reusing stuff where possible (memory ctx? graphs? >1 size batches?), since we're basically recreating a bunch of stuff from scratch for every token.

Okay, I'll take a look at what you suggested and find ways to store the state for the context and graph.

Regarding the batch size part, if I understand correctly, you're referring to fixing the alternation between draft and main model tokens in the server's main loop. I agree that would be a great optimization, but it seems like it would take a while, change a lot of the server logic, and require extensive testing.

I feel like that would be a good follow-up PR. It's more of a general feature to improve not only MTP but drafting in general, and giving it a separate PR would allow us to merge the core MTP implementation first.

@F1LM1
Owner

F1LM1 commented Sep 11, 2025

Regarding the batch size part, if I understand correctly, you're referring to fixing the alternation between draft and main model tokens in the server's main loop. I agree that would be a great optimization, but it seems like it would take a while, change a lot of the server logic, and require extensive testing.

Nah, I meant some form of batching when we do the MTP layer prompt processing step, since we're likely going to process hundreds of tokens at once using the same graphs/memory context/etc. Right now we're building all 1-size batches, which just feels wrong.
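
Roughly what I have in mind, sketch only (the batched MTP execution itself is hypothetical -- right now the helper only ever sees 1-token batches):

// Sketch: feed the MTP layer its prompt-processing work as one batch instead
// of n_prompt separate 1-token batches (assumes common.h / llama.h headers).
#include <vector>
#include "common.h"

static void mtp_warm_kv_batched(struct llama_context * ctx,
                                const std::vector<llama_token> & prompt_tokens,
                                int32_t n_past) {
    const int n_prompt = (int) prompt_tokens.size();
    llama_batch batch = llama_batch_init(n_prompt, 0, 1);
    for (int i = 0; i < n_prompt; ++i) {
        // no logits needed here; we only want the MTP KV cache warmed up
        common_batch_add(batch, prompt_tokens[i], n_past + i, {0}, false);
    }
    // ... run the MTP graph once over the whole batch (hypothetical) ...
    llama_batch_free(batch);
}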

AFAIK, the alternation thing might be deceptively easy to fix. I suspect it could be as simple as just making sure we only do the non-speculative llama_decode step exactly once, i.e. immediately after prompt processing. I'll need to find my notes on this, but I'm pretty sure everything else is always correctly synced, at least for the MTP case.

@SamuelOliveirads
Author

Nah, I meant some form of batching when we do the MTP layer prompt processing step, since we're likely going to process hundreds of tokens at once using the same graphs/memory context/etc. Right now we're building all 1-size batches, which just feels wrong.

Ah, thanks for clarifying the batching part! I was looking at the main generation loop, but batching the MTP prompt processing makes sense. I'll keep that on my list of things to look into.

AFAIK, the alternation thing might be deceptively easy to fix. I suspect it could be as simple as just making sure we only do the non-speculative llama_decode step exactly once, i.e. immediately after prompt processing. I'll need to find my notes on this, but I'm pretty sure everything else is always correctly synced, at least for the MTP case.

That's a great point about the alternation fix potentially being simple; I'm looking forward to your notes on that when you find them.


In the meantime, I've been focused on your first suggestion: reusing resources during the single-token draft generation to avoid recreating everything from scratch. I've been experimenting with a llama_mtp_state object to persist the graph and context, but I've run into problems.

My latest attempt was to build the graph once inside llama_build_and_execute_mtp_graph and store the ggml_cgraph in the llama_mtp_state. The setup works for the first call, but it crashes on subsequent calls inside ggml_backend_sched_alloc_graph.

My diagnosis is that the ggml_cgraph created by model->build_mtp_graph seems to have a strong dependency on the temporary mctx and params (including the sched) that were used to build it. When we try to reuse that persistent graph with a new, temporary sched and mctx on the next call, ggml detects an incompatibility and fails. It seems the graph and the scheduler/context are tightly coupled.

It feels like a chicken-and-egg problem. To properly reuse the graph, we'd need to persist the scheduler, but to persist the scheduler, we'd need to persist the memory context, which is the core of the KV cache problem.

I'm probably missing something due to my limited knowledge of the ggml backend internals. Is there a good way to attack this? For example, is it possible to "re-bind" a persistent ggml_cgraph to a new scheduler or memory context on each execution? Or perhaps we could create a "template" graph with dummy inputs that can then be used with real contexts later?
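
For reference, the state object I've been experimenting with is roughly this shape (member names are mine; the commented-out part is exactly what I don't know how to persist):

// Rough shape of the persistent state I tried (build once, reuse on later calls).
#include "ggml-backend.h"

struct llama_mtp_state {
    ggml_cgraph *        gf    = nullptr; // graph from model->build_mtp_graph on the first call
    ggml_backend_sched_t sched = nullptr; // scheduler the graph was allocated against
    // + the memory context (mctx) the graph was built with; reusing gf with a
    //   fresh sched/mctx on the next call is what makes
    //   ggml_backend_sched_alloc_graph fail
};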

Any pointers would be a huge help.

@F1LM1 F1LM1 merged commit c6237c7 into F1LM1:glm4-moe-mtp Sep 13, 2025
@F1LM1
Owner

F1LM1 commented Sep 13, 2025

I finally got a chance to test the improved sampler. It works well in my testing, raising draft acceptance rate in some "hard" writing scenarios by more than 10 percentage points on average, which is a clear and large gain. Ironically it ends up being slower in actual tok/s generation, presumably because the sampling chain as-is is inefficient, but let's see what we can do about that in follow-ups.

Re: the graph reuse questions you mentioned above, I'll fire up the project again this weekend and see what I find. It's been a while since I dove in.

@SamuelOliveirads SamuelOliveirads deleted the glm4-moe-mtp branch September 13, 2025 20:36