
Conversation

@casinca (Contributor) commented Nov 17, 2025

What does this PR do?

tl;dr: This PR optimizes the MinPLogitsWarper implementation introduced in #30639.

Hello,

While re-implementing Min-P for my own practice and later comparing it against the current HF/official implementation, I noticed that a single torch.topk call can retrieve the min_tokens_to_keep tokens instead of chaining argsort + gather + masking.
As a minor simplification, I also switched from torch.max to torch.amax to get only top_probs, since the indices aren't needed.

Overall, besides being faster, I believe readability is also improved.
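
For illustration, here is a minimal sketch of the direction (not the exact diff in this PR; variable names follow the existing MinPLogitsWarper attributes):

probs = torch.softmax(scores, dim=-1)
top_probs = probs.amax(dim=-1, keepdim=True)  # only the values are needed, so amax instead of torch.max
scaled_min_p = self.min_p * top_probs
tokens_to_remove = probs < scaled_min_p
# a single topk call yields the k-th largest probability, used to protect the top min_tokens_to_keep tokens
kth_prob = probs.topk(self.min_tokens_to_keep, dim=-1).values[..., -1:]
tokens_to_remove &= probs < kth_prob
scores_processed = scores.masked_fill(tokens_to_remove, self.filter_value)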


Benchmarked in isolation (bf16, varying single/multiple batches and vocab sizes), the proposed optimization is noticeably faster (see the timing sketch after these lists):

  • on CPU, 2x up to 10x faster
  • on CUDA, 20% up to 100% faster

In a real situation, generating with a Qwen3-0.6B model:

  • on CPU, I'm getting ~2s faster per model.generate() call (on average) with max_new_tokens=200
  • on CUDA, there's no meaningful difference
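
For reference, an isolated timing like the above could be reproduced with a harness along these lines; the sizes, dtype, and setup below are illustrative assumptions, not the exact benchmark that was run:

import torch
import torch.utils.benchmark as benchmark
from transformers.generation.logits_process import MinPLogitsWarper

batch_size, vocab_size = 8, 151_936  # illustrative sizes, roughly a Qwen3-scale vocab
scores = torch.randn(batch_size, vocab_size, dtype=torch.bfloat16)
input_ids = torch.zeros(batch_size, 16, dtype=torch.long)
warper = MinPLogitsWarper(min_p=0.05, min_tokens_to_keep=1)

timer = benchmark.Timer(
    stmt="warper(input_ids, scores)",
    globals={"warper": warper, "input_ids": input_ids, "scores": scores},
)
print(timer.blocked_autorange())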

All tests pass. I don't believe any extra changes are needed; it's purely an internal code optimization.

Min-P is, if I'm not mistaken, a very popular sampling alternative after top_k and top_p, so if this can slightly speed up inference, I thought: why not?
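
(For context, min_p is exposed as a generation parameter, so end users hit this code path with something like the following; the model id and values are illustrative:)

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
inputs = tokenizer("Hello", return_tensors="pt")
# min_p is a GenerationConfig sampling parameter; it routes through MinPLogitsWarper
outputs = model.generate(**inputs, do_sample=True, min_p=0.05, max_new_tokens=200)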

 

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

 

Who can review?

Let me know what you think, @gante.
Also tagging @menhguin, as he is a co-author of Min-P sampling and participated in the original PR.

@Rocketknight1 (Member) commented Nov 18, 2025

Very interesting! This looks better than the original already. I wonder if we could compute this even more directly, though - really all we're doing is keeping all the elements larger than scaled_min_p and the n-th largest probability, where n is self.min_tokens_to_keep. So we could just do something like

cutoff_value = torch.quantile(probs, ...)  # Set the quantile so that we get the n-th largest value for each sequence
cutoff_value = cutoff_value.clamp(min=scaled_min_p)  # Clip the cutoff value if it's smaller than scaled_min_p
scores_processed = scores.masked_fill(probs < cutoff_value, self.filter_value)

and totally avoid computing top_probs, sorted_indices or tokens_to_remove. The behaviour might change slightly in cases where a lot of tokens have probability exactly equal to the cutoff value, but that seems like an edge case where either behaviour is acceptable.

I'm not sure if this works, but it might be worth a try! If you don't feel like benchmarking it, though, I can just review and merge your existing code, which is already a nice speedup.

EDIT: I think torch.quantile() is actually not very optimized, so this approach probably won't work. Finding the n-th largest element might be quicker with topk(sorted=True)
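
For example, grabbing the n-th largest probability per row with topk would look roughly like this (n and probs are placeholders):

# n-th largest probability of each row; topk returns values sorted in descending order by default
nth_prob = probs.topk(n, dim=-1).values[..., -1:]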

@casinca (Contributor, Author) commented Nov 18, 2025

Hey @Rocketknight1, thanks for the suggestion.

I tried quantile() with q = 1 - (self.min_tokens_to_keep / vocab_size); you had the right intuition, because it was way worse than using topk.

You said

and totally avoid computing top_probs

I'd still have to compute it to get scaled_min_p, if I'm not mistaken?
If you meant only cutoff_value.clamp(min=scaled_min_p), it could fail in cases where we also need to guarantee at least self.min_tokens_to_keep tokens.
nth_prob.clamp(max=scaled_min_p) works (more readable, imo, is torch.minimum(scaled_min_p, nth_prob)).

 

Based on your suggestion I tried this alternative:

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    # Convert logits to probabilities
    probs = torch.softmax(scores, dim=-1)

    topk_probs = probs.topk(self.min_tokens_to_keep, dim=-1).values
    # get top prob and n-th prob
    top_prob = topk_probs.amax(dim=-1, keepdim=True)
    nth_prob = topk_probs[..., -1:]  # keep the last dim so it broadcasts with scaled_min_p (batch, 1)
    # Calculate the actual min_p threshold by scaling min_p with the top token's probability
    scaled_min_p = top_prob * self.min_p
    # keep tokens that are >= scaled_min_p, or >= nth_prob (guarantees at least min_tokens_to_keep)
    cutoff_value = torch.minimum(scaled_min_p, nth_prob)
    # or equivalently: cutoff_value = nth_prob.clamp(max=scaled_min_p)

    scores_processed = scores.masked_fill(probs < cutoff_value, self.filter_value)
    return scores_processed
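
As a quick sanity check of the min_tokens_to_keep guarantee, here is a toy example (hypothetical numbers; MinPLogitsWarper stands for the existing class, with which either implementation should agree on this case):

import torch
from transformers.generation.logits_process import MinPLogitsWarper

# top prob 0.6 and min_p=0.5 give a threshold of 0.3; the 0.25 token falls below it
# but survives because min_tokens_to_keep=2 lowers the cutoff to the 2nd-largest prob
warper = MinPLogitsWarper(min_p=0.5, min_tokens_to_keep=2)
scores = torch.log(torch.tensor([[0.6, 0.25, 0.1, 0.05]]))
filtered = warper(torch.zeros(1, 1, dtype=torch.long), scores)
print(filtered)  # expected: log(0.6) and log(0.25) kept, the last two replaced with -inf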

But I didn't see meaningful improvements over the first optimization that would justify choosing one over the other. Both are faster than the original and than torch.quantile(), quite pronounced on CPU and to a lesser extent on CUDA.

I guess it's up to you guys which style/preference you want to go with.

@Rocketknight1 (Member) commented

Yes, thank you for checking! I'm happy with your implementation in that case. One final edge case I thought of, though: If min_tokens_to_keep is larger than the vocab size, topk will throw an error, but the existing code will tolerate this because sorted_indices_to_remove[..., : self.min_tokens_to_keep] will work fine even if that slice runs off the end of the array.

Maybe we keep your implementation, but clip the k arg to topk so that it's never larger than the array size, to avoid that potential error? Then we should be fully equivalent to the original code while keeping the performance improvements.

@casinca (Contributor, Author) commented Nov 19, 2025

Ah yes, that's a good find👍

Updated to take this edge case into account:
k for topk is now the min of self.min_tokens_to_keep and the vocab size (inferred from probs.shape[-1]).
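
In other words, a minimal sketch of the added guard (names follow the snippet above):

# never ask topk for more elements than the vocab dimension holds
k = min(self.min_tokens_to_keep, probs.shape[-1])
topk_probs = probs.topk(k, dim=-1).values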

@Rocketknight1 force-pushed the optimized-min-p-sampling branch from 0b38ebb to 77afe2c on November 19, 2025 16:04
@Rocketknight1 (Member) left a review comment


Yes, code looks good to me now. Thanks for a nice PR!

@Rocketknight1 enabled auto-merge (squash) on November 19, 2025 16:08
@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1 force-pushed the optimized-min-p-sampling branch from 77afe2c to bbeb2a3 on November 19, 2025 16:34
@Rocketknight1 merged commit 4391cfd into huggingface:main on Nov 19, 2025
23 checks passed
