perf: Optimization for Min-p sampling implementation #42248
Conversation
Very interesting! This looks better than the original already. I wonder if we could compute this even more directly, though - really all we're doing is keeping all the elements larger than `cutoff_value`:

```python
cutoff_value = torch.quantile(probs, ...)  # Set the quantile so that we get the n-th largest value for each sequence
cutoff_value.clamp(min=scaled_min_p)  # Clip the cutoff value if it's smaller than scaled_min_p
scores_processed = scores.masked_fill(probs < cutoff_value, self.filter_value)
```

and totally avoid computing … I'm not sure if this works, but it might be worth a try! If you don't feel like benchmarking it, though, I can just review and merge your existing code, which is already a nice speedup.

EDIT: I think …
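(Not part of the review thread: a minimal sketch of one way the quantile in the suggestion above could be chosen, assuming `probs` has shape `[batch, vocab]` and `n = min_tokens_to_keep`. The helper name and the `torch.minimum`-based clamp direction are mine, not the reviewer's.)

```python
import torch

def quantile_cutoff(probs: torch.Tensor, min_p: float, min_tokens_to_keep: int) -> torch.Tensor:
    """Hypothetical helper: per-row cutoff = min(scaled_min_p, n-th largest probability)."""
    vocab_size = probs.shape[-1]
    # The n-th largest element sits at ascending position (vocab_size - n); with
    # q = (vocab_size - n) / (vocab_size - 1), torch.quantile lands on exactly that
    # position (up to floating-point rounding of q, negligible in practice).
    q = (vocab_size - min_tokens_to_keep) / (vocab_size - 1)
    # torch.quantile does not accept half/bfloat16 inputs, hence the float() cast
    nth_prob = torch.quantile(probs.float(), q, dim=-1, keepdim=True).to(probs.dtype)
    scaled_min_p = min_p * probs.amax(dim=-1, keepdim=True)
    # keep at least min_tokens_to_keep tokens by lowering the cutoff to the n-th prob if needed
    return torch.minimum(scaled_min_p, nth_prob)
```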
Hey @Rocketknight1, thanks for the suggestion. I tried with `quantile()` and … You said we could totally avoid computing …, but I'd still have to compute it to get …
Based on your suggestion I tried this alternative:

```python
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    # Convert logits to probabilities
    probs = torch.softmax(scores, dim=-1)
    topk_probs = probs.topk(self.min_tokens_to_keep, dim=-1).values
    # get top prob and n-th prob (keep the last dim so both broadcast against probs)
    top_prob = topk_probs.amax(dim=-1, keepdim=True)
    nth_prob = topk_probs[..., -1:]
    # Calculate the actual min_p threshold by scaling min_p with the top token's probability
    scaled_min_p = top_prob * self.min_p
    # keep tokens that are >= scaled_min_p, but always keep at least the top min_tokens_to_keep (>= nth_prob)
    cutoff_value = torch.minimum(scaled_min_p, nth_prob)
    # or: cutoff_value = nth_prob.clamp(max=scaled_min_p)
    scores_processed = scores.masked_fill(probs < cutoff_value, self.filter_value)
    return scores_processed
```

But I didn't see meaningful improvements compared to the first optimization to justify choosing one over the other. Both are faster than the original or than using `quantile()`. I guess it's up to you guys on what style/preference you want to go with.
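(Not part of the conversation: a rough sketch of how such an isolated comparison could be set up with `torch.utils.benchmark`, run once on `main` and once on this branch and the printed timings compared; the batch size, vocab size, and `min_p` value below are arbitrary choices.)

```python
import torch
from torch.utils import benchmark
from transformers import MinPLogitsWarper

batch, vocab = 8, 151936  # arbitrary example sizes
scores = torch.randn(batch, vocab, dtype=torch.bfloat16)
input_ids = torch.zeros(batch, 16, dtype=torch.long)

warper = MinPLogitsWarper(min_p=0.05, min_tokens_to_keep=1)

timer = benchmark.Timer(
    stmt="warper(input_ids, scores)",
    globals={"warper": warper, "input_ids": input_ids, "scores": scores},
    label="MinPLogitsWarper",
    description="bf16 8x151936",
)
# run on each branch and compare the reported medians
print(timer.blocked_autorange(min_run_time=1.0))
```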
Yes, thank you for checking! I'm happy with your implementation in that case. One final edge case I thought of, though: if … Maybe we keep your implementation, but clip the …
Ah yes, that's a good find 👍 Updated to take this edge case into account.
Rocketknight1 left a comment
Yes, code looks good to me now. Thanks for a nice PR!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
tldr: This PR aims to optimize the `MinPLogitsWarper` implementation that was done in #30639.

Hello,

While I was re-implementing Min-P for my own sake and later compared it with the current HF/official implementation, I noticed that we could use a single `torch.topk` call to retrieve the `min_tokens_to_keep` top probabilities instead of chaining `argsort` + `gather` + masking. Also, minor: I simplified to `torch.amax` instead of `torch.max` for getting only `top_probs`. Overall, besides being faster, I believe readability is also improved.
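(For illustration only, a rough, hypothetical sketch of that idea rather than the exact diff in this PR: the removal mask is built from the scaled min-p threshold, and a single `topk` call recovers the positions of the `min_tokens_to_keep` most likely tokens so they can be excluded from the mask, with no full-vocabulary sort or gather/scatter of a sorted mask.)

```python
import torch

def min_p_mask_sketch(scores: torch.Tensor, min_p: float, min_tokens_to_keep: int,
                      filter_value: float = -float("inf")) -> torch.Tensor:
    # scores: [batch, vocab] raw logits
    probs = torch.softmax(scores, dim=-1)
    # scale min_p by the top token's probability (amax avoids computing indices we don't need)
    scaled_min_p = min_p * probs.amax(dim=-1, keepdim=True)
    tokens_to_remove = probs < scaled_min_p
    # one topk call finds the min_tokens_to_keep most likely tokens, which are never removed
    keep_indices = probs.topk(min_tokens_to_keep, dim=-1).indices
    tokens_to_remove.scatter_(-1, keep_indices, False)
    return scores.masked_fill(tokens_to_remove, filter_value)
```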
Benchmarking in isolation, the proposed optimization is noticeably faster in bf16, across single and multiple batches and varying vocab sizes.

In a real situation, generating with a `Qwen3-0.6B` model and timing the `model.generate()` call (on average) with `max_new_tokens=200`: …

All tests passed. There are no extra changes needed, I believe; it's just an internal code optimization.
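(Again not from the PR: a minimal sketch of the kind of end-to-end timing described above, assuming a CUDA device, the `Qwen/Qwen3-0.6B` hub id, and an arbitrary `min_p=0.05`; run it on `main` and on this branch to compare averages.)

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # hub id assumed from the model named above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

inputs = tokenizer("Explain min-p sampling in one paragraph.", return_tensors="pt").to("cuda")

# warmup, then average a few generate() calls with min-p sampling enabled
model.generate(**inputs, do_sample=True, min_p=0.05, max_new_tokens=200)
runs = 5
start = time.perf_counter()
for _ in range(runs):
    model.generate(**inputs, do_sample=True, min_p=0.05, max_new_tokens=200)
print(f"avg generate() time: {(time.perf_counter() - start) / runs:.3f}s")
```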
Min-P is a very popular sampling alternative after top-k and top-p, if I'm not mistaken, so if this can help slightly speed up inference, I thought: why not?
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Let me know what you think @gante
Also tagging @menhguin, as he is a co-author of Min-P sampling and participated in the original PR.