Conversation

@akaashrp (Contributor) commented Jun 7, 2025

  1. Replace CPU function calls for the following tasks with GPU kernel invocations (see the sketches after this list):
  • Apply logit bias
  • Apply penalties to logits
  • Compute softmax with temperature (sampling itself will be moved to GPU in a future PR)
  2. Fix a bug where the repetition penalty was not being applied from the generation config:
  • Added repetition_penalty to the CompletionCreateParamsBase and ChatCompletionRequestBase interfaces
  • Updated the definition in GenerationConfig and added a reference in engine.ts
  3. Add a field to the CompletionCreateParamsBase and ChatCompletionRequestBase interfaces to enable logging the time taken by individual steps
  4. Add sanity checks for the individual steps in sampleTokenFromLogits
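
To make item 1 concrete, here is a minimal CPU-side sketch of the three steps that move to GPU kernels. The function names, signatures, and penalty formulas below are illustrative (standard OpenAI-style semantics), not the actual web-llm kernels; a plain reference like this is also the sort of thing the sanity checks in item 4 could compare GPU output against.

```typescript
// Illustrative CPU reference for the three steps moved to GPU kernels.
// Names and shapes are assumptions for this sketch, not web-llm's API.

// Apply per-token logit bias (OpenAI-style: tokenId -> additive bias).
function applyLogitBias(
  logits: Float32Array,
  logitBias: Record<number, number>,
): void {
  for (const [tokenId, bias] of Object.entries(logitBias)) {
    logits[Number(tokenId)] += bias;
  }
}

// Apply frequency/presence/repetition penalties to previously seen tokens.
function applyPenalties(
  logits: Float32Array,
  tokenCounts: Map<number, number>, // tokenId -> occurrences so far
  frequencyPenalty: number,
  presencePenalty: number,
  repetitionPenalty: number,
): void {
  for (const [tokenId, count] of tokenCounts) {
    logits[tokenId] -= count * frequencyPenalty + presencePenalty;
    // Common repetition-penalty rule: shrink positive logits, push
    // negative logits further down.
    logits[tokenId] =
      logits[tokenId] > 0
        ? logits[tokenId] / repetitionPenalty
        : logits[tokenId] * repetitionPenalty;
  }
}

// Numerically stable softmax with temperature.
function softmaxWithTemperature(
  logits: Float32Array,
  temperature: number,
): Float32Array {
  const probs = new Float32Array(logits.length);
  let maxLogit = -Infinity;
  for (const x of logits) maxLogit = Math.max(maxLogit, x);
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    probs[i] = Math.exp((logits[i] - maxLogit) / temperature);
    sum += probs[i];
  }
  for (let i = 0; i < probs.length; i++) probs[i] /= sum;
  return probs;
}
```

And a short usage sketch for the repetition-penalty fix in item 2, assuming the field is exposed as repetition_penalty on the request per the interface change above (the model id is just a placeholder; substitute any id from web-llm's prebuilt model list):

```typescript
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Placeholder model id; use any id from web-llm's prebuilt list.
const engine = await CreateMLCEngine("Llama-3.1-8B-Instruct-q4f32_1-MLC");

const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Write a haiku about GPUs." }],
  repetition_penalty: 1.1, // values > 1 discourage repeating generated tokens
});
console.log(reply.choices[0].message.content);
```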

Performance Comparison: compared performance for "canonical" flows, averaged across 20 runs:

  • No logit_bias
  • No logitProcessor
  • Applied penalties
  • With and without logprobs
  1. Before PR performance (without logprobs): ~0.064s per output token (~15.63 decode tokens/s)
  2. After PR performance (without logprobs): ~0.066s per output token (~15.15 decode tokens/s)
  3. Before PR performance (with logprobs): ~0.052s per output token (~19.23 decode tokens/s)
  4. After PR performance (with logprobs): ~0.048s per output token (~20.83 decode tokens/s)
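
As a consistency check, the tokens/s figures are just the reciprocal of the per-token latency: 1 / 0.064 s ≈ 15.6 tokens/s, and 1 / 0.048 s ≈ 20.8 tokens/s.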

Additional Notes:

  • Need to profile performance of sampleTopPFromLogits vs sampleTopPFromProb on CPU to determine why performance with logprobs is better
  • Application of logit_bias is much faster on GPU than CPU
  • There are additional overheads outside of the sampleTokenFromLogits function that make the performance improvement less pronounced: the total time spent in sampleTokenFromLogits is ~0.0117s before the PR and ~0.0076s after
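
As a rough illustration of the per-step timing that the new logging field enables, here is a minimal sketch. The names timed and logStepTimes are hypothetical, not web-llm's actual API; note also that GPU kernel launches are asynchronous, so timing them this way measures only CPU-side dispatch unless the device is synchronized first.

```typescript
// Hedged sketch of per-step timing; `timed` and `logStepTimes` are
// hypothetical names, not part of web-llm's API.
function timed<T>(label: string, step: () => T, log: boolean): T {
  if (!log) return step();
  const start = performance.now();
  const result = step();
  console.log(`${label}: ${(performance.now() - start).toFixed(3)} ms`);
  return result;
}

// Example: wrap one stage of a sampleTokenFromLogits-style pipeline,
// reusing softmaxWithTemperature from the sketch above.
const logStepTimes = true;
const logits = Float32Array.from({ length: 32000 }, () => Math.random());
const probs = timed(
  "softmaxWithTemperature",
  () => softmaxWithTemperature(logits, 1.0),
  logStepTimes,
);
```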

@CharlieFRuan mentioned this pull request Jul 15, 2025
@CharlieFRuan (Member) left a comment
Overall looks good! Left some comments. As discussed offline, it'd be good to test each possible codepath and see if the behavior is expected, and measure performance for the canonical codepath. Thanks!

@CharlieFRuan (Member) left a comment

This is great, thank you so much! Only two minor nits. Afterwards let's merge!

@CharlieFRuan (Member) left a comment

LGTM, great work! Let's follow up with on-GPU sampling!

@CharlieFRuan CharlieFRuan merged commit e4b4dc2 into mlc-ai:main Sep 13, 2025
1 check passed
atebites-hub pushed a commit to atebites-hub/web-llm that referenced this pull request Oct 4, 2025