
Conversation

SamuelOliveirads

For now, this is just a proof of concept. I tried to replicate the workflow from SGLang because even vLLM, which has a proper implementation, does not draft more than one token, as stated in this comment on their PR.

The drafting process must be an autoregressive loop to accommodate the CPU-side sampling required between each token generation. In case you are curious, the loop in SGLang looks like this:

# in sglang/srt/speculative/eagle_worker.py
class EAGLEWorker:
    ...
    def draft_forward(self, forward_batch: ForwardBatch):
        ...
        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            # Selects the input tokens for this iteration
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(...)
            
            # ... (appends the results to score_list, token_list, parents_list)

            # Stop generating after the last required step
            if i == self.speculative_num_steps - 1:
                break

            # Prepares the inputs for the next model run
            forward_batch.input_ids = input_ids
            # ... (other batch preparations)

            # Executes the draft model for ONE step
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            
            # Samples the logits to obtain tokens for the NEXT iteration
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

With that in mind, I implemented a loop in mtp_speculative_gen_draft driven by how many draft tokens you want at once, which follows the same idea.
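
To make the shape of that change concrete, here is a rough C++ sketch of the loop, not the actual PR code: mtp_forward_one() and sample_greedy() are hypothetical placeholders for the MTP forward pass and the CPU-side sampling between steps, while n_mtp_draft_target mirrors the hardcoded constant mentioned below.

// Rough sketch only (not the PR code): mtp_forward_one() and sample_greedy()
// are hypothetical placeholders for the MTP forward pass and the CPU-side
// sampling that forces drafting to be an autoregressive loop.
#include <cstdint>
#include <vector>

using llama_token = int32_t;                        // matches the typedef in llama.h

const float * mtp_forward_one(llama_token cur);     // placeholder: one MTP forward step
llama_token   sample_greedy(const float * logits);  // placeholder: CPU-side sampling

std::vector<llama_token> mtp_draft_loop(llama_token id_last) {
    const int n_mtp_draft_target = 5;   // hardcoded draft depth, as in the PR

    std::vector<llama_token> drafted;
    llama_token cur = id_last;          // last token accepted by the target model

    for (int i = 0; i < n_mtp_draft_target; ++i) {
        const float * logits = mtp_forward_one(cur);  // logits for the next token
        cur = sample_greedy(logits);                  // sample on the CPU between steps
        drafted.push_back(cur);
        // limitation discussed below: the KV written by this step is not
        // persisted, so later steps draft from an incomplete state
    }

    return drafted;
}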

The major problem is KV cache management: the KV cache is not persisted between draft steps, so draft quality degrades as more tokens are drafted.

Here are some preliminary results showing the drop in acceptance rate (the depth is hardcoded for now to 5 drafts at once via the line const int n_mtp_draft_target = 5):

  • 1 draft: draft acceptance rate = 0.64045 ( 57 accepted / 89 generated)
  • 2 drafts: draft acceptance rate = 0.41765 ( 71 accepted / 170 generated)
  • 3 drafts: draft acceptance rate = 0.31481 ( 68 accepted / 216 generated)
  • 4 drafts: draft acceptance rate = 0.25309 ( 82 accepted / 324 generated)
  • 5 drafts: draft acceptance rate = 0.16667 ( 55 accepted / 330 generated)

If my concept is correct, then we "just" need to fix the KV cache, and here lies the nightmare you have probably walked through before. A proper solution would likely involve creating a persistent KV cache context for the duration of the draft loop.
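
As a very rough picture of what "persistent" could mean here, the sketch below tags the speculative tokens with a dedicated sequence id so each step can attend to the previously drafted tokens, and only clears them after verification. decode_draft_step() and clear_draft_kv() are hypothetical placeholders; the latter would map onto llama.cpp's KV sequence-removal API (llama_kv_cache_seq_rm or its newer equivalent, depending on the revision).

// Sketch only: keep the KV entries written by each draft step alive for the
// whole loop by tagging them with a dedicated sequence id, then drop them
// once the target model has verified the draft.
// decode_draft_step() and clear_draft_kv() are hypothetical placeholders.
#include <cstdint>
#include <vector>

using llama_token  = int32_t;
using llama_seq_id = int32_t;
using llama_pos    = int32_t;

struct llama_context;   // opaque, from llama.h

// placeholder: run one MTP step for `tok` at `pos` under `seq`, return the sampled token
llama_token decode_draft_step(llama_context * ctx, llama_token tok, llama_pos pos, llama_seq_id seq);
// placeholder: remove the KV entries of `seq` in the position range [p0, p1)
void clear_draft_kv(llama_context * ctx, llama_seq_id seq, llama_pos p0, llama_pos p1);

std::vector<llama_token> draft_with_persistent_kv(llama_context * ctx, llama_token id_last,
                                                  llama_pos n_past, int n_draft) {
    const llama_seq_id seq_draft = 1;   // dedicated sequence for the speculative tokens

    std::vector<llama_token> drafted;
    llama_token cur = id_last;

    for (int i = 0; i < n_draft; ++i) {
        // each step appends its KV at position n_past + i under seq_draft,
        // so step i + 1 can attend to everything drafted in steps 0..i
        cur = decode_draft_step(ctx, cur, n_past + i, seq_draft);
        drafted.push_back(cur);
    }

    // after the target model has verified the draft, drop the speculative entries
    clear_draft_kv(ctx, seq_draft, n_past, n_past + n_draft);

    return drafted;
}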

I can see two options to follow:

  1. Continue down this path and fix the KV cache (I would recommend doing that only after your "server: implement GLM-style MTP" PR is merged).
  2. Deliver this simplified version with only 1 draft token. This would align with vLLM's current safe implementation for GLM-4.5 and still provide a performance benefit.

F1LM1 (Owner) commented Sep 10, 2025

Great work! 👍

regarding KV cache management

I suspect this will wind up being even more awkward than the existing MTP KV cache management setup, since KV caching is done per-layer (in the sense that you can index into each layer's KV cache). One possible solution might be to add "layers" when N draft tokens is > 1. The other solution, possibly cleaner and more consistent with existing llama.cpp behavior, would be to revert to storing an explicit draft context as in the current speculative module. I don't love either solution right now: the former because it's hacky, the latter because IIRC the current speculative draft KV cache isn't really maintained that well. I agree that it's best to deliver the simplified version first before reverting here, since it would add a good deal of complexity.
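
For reference, a minimal sketch of the "explicit draft context" option, in the spirit of the existing speculative module: keep a second context for drafting, commit accepted tokens to it, and roll its cache back after each verification round. decode_one() and rollback_kv() are hypothetical wrappers around llama_decode (plus sampling) and the KV sequence-removal API; nothing here reflects actual PR code.

// Sketch only: the "explicit draft context" option, roughly mirroring how the
// existing speculative module keeps a separate context for the draft side.
// decode_one() and rollback_kv() are hypothetical wrappers.
#include <cstdint>
#include <vector>

using llama_token = int32_t;
using llama_pos   = int32_t;

struct llama_context;   // opaque, from llama.h

// placeholder: decode `tok` at `pos` on `ctx` and return the sampled next token
llama_token decode_one(llama_context * ctx, llama_token tok, llama_pos pos);
// placeholder: drop all KV entries at positions >= p0
void rollback_kv(llama_context * ctx, llama_pos p0);

struct mtp_draft_state {
    llama_context * ctx_draft;   // context owned by the draft side
    llama_pos       n_past = 0;  // tokens currently committed to ctx_draft
};

// One round: commit the newly accepted tokens (their KV is permanent), draft
// n_draft speculative tokens on top, then roll back to the verified position
// so the speculative KV never leaks into the next round.
// `accepted` must end with the last token emitted by the target model.
std::vector<llama_token> draft_round(mtp_draft_state & st,
                                     const std::vector<llama_token> & accepted, int n_draft) {
    llama_token cur = 0;
    for (llama_token tok : accepted) {
        cur = decode_one(st.ctx_draft, tok, st.n_past++);   // cur = draft's next-token guess
    }

    std::vector<llama_token> drafted;
    for (int i = 0; i < n_draft; ++i) {
        drafted.push_back(cur);
        if (i + 1 < n_draft) {
            cur = decode_one(st.ctx_draft, drafted.back(), st.n_past + i);  // temporary KV
        }
    }

    rollback_kv(st.ctx_draft, st.n_past);   // discard the speculative entries
    return drafted;
}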
