[Kernels] Migrate sampling to WebGPU #737
base: main
```diff
@@ -53,6 +53,8 @@ export class LLMChatPipeline {
   private fapplyPenalty: tvmjs.PackedFunc;
   private fapplyLogitBias: tvmjs.PackedFunc;
   private fsoftmaxWithTemperature: tvmjs.PackedFunc;
+  private fsampleWithTopP: tvmjs.PackedFunc;
+  private fargsortProbs: tvmjs.PackedFunc;

   // Functions related to PagedKVCache
   private fclearKVCaches: tvmjs.PackedFunc;
```
```diff
@@ -142,6 +144,10 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  // Instance variables for supporting sampling on WebGPU
+  private sampleIndices: Int32Array;
+  private sampleIndicesDevice: tvmjs.Tensor;
+  private topPDevice: tvmjs.Tensor;

   constructor(
     tvm: tvmjs.Instance,
```
```diff
@@ -213,6 +219,12 @@ export class LLMChatPipeline {
     this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
       this.vm.getFunction("softmax_with_temperature"),
     );
+    this.fsampleWithTopP = this.tvm.detachFromCurrentScope(
+      this.vm.getFunction("sample_with_top_p"),
+    );
+    this.fargsortProbs = this.tvm.detachFromCurrentScope(
+      this.vm.getFunction("argsort_probs"),
+    );
     try {
       this.image_embed = this.tvm.detachFromCurrentScope(
         this.vm.getFunction("image_embed"),
```
```diff
@@ -310,6 +322,25 @@ export class LLMChatPipeline {
     this.filledKVCacheLength = 0;
     this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence

+    // Initialize WebGPU sampling related device tensors
+    const numSamples = 1;
+    const numProbs = 1;
+
+    this.sampleIndices = new Int32Array(numSamples);
+    for (let i = 0; i < numSamples; i++) {
+      this.sampleIndices[i] = i;
+    }
+    this.sampleIndicesDevice = this.tvm.detachFromCurrentScope(
+      this.tvm
+        .empty([numSamples], "int32", this.device)
+        .copyFrom(this.sampleIndices),
+    );
+
+    this.topPDevice = this.tvm.detachFromCurrentScope(
+      this.tvm.empty([numProbs], "float32", this.device),
+    );
+
     tvm.endScope();
   }
```
```diff
@@ -1271,11 +1302,13 @@ export class LLMChatPipeline {
     // If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
     const sampleBegin = performance.now();
     let sampledToken: number;
-    if (logprobs) {
+    let sampledTokensDevice: tvmjs.Tensor;
+    if (logprobs && _hasValue(top_p)) {
       // Inplace transform logitsOnCPU to a distribution
       temperature = Math.max(1e-6, temperature); // to prevent division by zero

       const numSeqs = 1;
+      const numProbs = 1;

       const temperatures = new Float32Array([temperature]);
```
```diff
@@ -1284,18 +1317,52 @@ export class LLMChatPipeline {
         .empty([numSeqs], "float32", this.device)
         .copyFrom(temperatures);

-      const probs = this.fsoftmaxWithTemperature(
-        logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
+      let probs = this.fsoftmaxWithTemperature(
+        logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]),
         temperaturesDevice,
       );
-      this.updateLogitsOnCPU(probs);
+      probs = probs.view([numProbs, this.fullVocabSize]);
+
+      const argsortResults = this.fargsortProbs(probs);
+      const sortedProbsDevice = argsortResults.get(0);
+      const sortedIndicesDevice = argsortResults.get(1);
+
+      const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device);
+
+      const topPHost = new Float32Array(numProbs).fill(-1);
+      const topPValue = Math.max(top_p, 1e-5);
+      this.sampleIndices.forEach((row) => {
+        topPHost[row] = topPValue;
+      });
+      this.topPDevice.copyFrom(topPHost);
+
+      sampledTokensDevice = this.tvm.detachFromCurrentScope(
+        this.fsampleWithTopP(
+          sortedProbsDevice,
+          sortedIndicesDevice,
+          uniformSamplesDevice,
+          this.sampleIndicesDevice,
+          this.topPDevice,
+        ),
+      );
+      const sampledTokensHost = this.tvm.detachFromCurrentScope(
+        this.tvm
+          .empty([numSeqs], "int32", this.tvm.cpu())
+          .copyFrom(sampledTokensDevice),
+      );
+      if (top_logprobs! > 0) {
+        this.updateLogitsOnCPU(probs);
+      }
       this.tvm.endScope();
       await this.device.sync();

-      sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p);
-      this.tokenLogprobArray.push(
-        this.getTokenLogprob(sampledToken, top_logprobs!),
-      );
+      sampledToken = sampledTokensHost.toArray()[0];
+
+      if (top_logprobs! > 0) {
+        this.tokenLogprobArray.push(
+          this.getTokenLogprob(sampledToken, top_logprobs!),
+        );
+      }
     } else {
```
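As a side note (not part of the diff): the kernel pair used above, `argsort_probs` followed by `sample_with_top_p`, together performs top-p (nucleus) sampling on the GPU. A rough CPU reference of that algorithm is sketched below; it is only illustrative, and the exact renormalization and tie-breaking inside the TVM kernels may differ.

```ts
// Illustrative CPU sketch of top-p (nucleus) sampling, i.e. what the
// argsort_probs + sample_with_top_p sequence computes on the GPU.
// Not code from this PR; details such as renormalization may differ.
function sampleTopPReference(
  probs: Float32Array,
  topP: number,
  uniformSample: number, // a draw from U[0, 1), like tvm.uniform above
): number {
  // argsort_probs: token ids ordered by probability, highest first.
  const ids = Array.from(probs.keys()).sort((a, b) => probs[b] - probs[a]);

  // Keep the smallest prefix whose cumulative probability reaches topP.
  const kept: number[] = [];
  let cumulative = 0;
  for (const id of ids) {
    kept.push(id);
    cumulative += probs[id];
    if (cumulative >= topP) break;
  }

  // sample_with_top_p: invert the truncated CDF at the uniform sample.
  const target = uniformSample * cumulative;
  let acc = 0;
  for (const id of kept) {
    acc += probs[id];
    if (acc >= target) return id;
  }
  return kept[kept.length - 1];
}

// Example: with probs [0.5, 0.3, 0.2] and top_p = 0.8, only ids 0 or 1 can be returned.
console.log(sampleTopPReference(new Float32Array([0.5, 0.3, 0.2]), 0.8, 0.42));
```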
Member: Quick question: why can't we use the GPU sample kernel when we …

Contributor (Author): IIRC, the flow for the `logprobs` False case involves invoking …

Member: I see. Ideally we could modify those kernels to not use i8s if the backend is WebGPU in TVM. Let's leave a TODO at the start of this … I suppose the … But this PR is great!
```diff
       // temperature being 0 is allowed here, equivalent to argmax
       this.tvm.beginScope();
```
Could you remind me why we add a `_hasValue(top_p)` here? If a user wants `logprobs` but does not provide a `top_p`, it would go to the else branch, and thus not populate the `tokenLogprobArray`.

Let's set `top_p` to `1.0` (the default value) at the start, when we are pre-processing the sampling parameters. Then we can remove this condition change.
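A minimal sketch of that suggestion, with hypothetical names (the pipeline's real config type and pre-processing hook are not shown here): default `top_p` to `1.0` while pre-processing the sampling parameters, so the GPU sampling branch can go back to checking `logprobs` alone.

```ts
// Hypothetical sketch of defaulting top_p during sampling-parameter
// pre-processing; names below are illustrative, not the pipeline's actual API.
interface SamplingConfig {
  temperature?: number;
  top_p?: number;
  logprobs?: boolean;
}

function preprocessSamplingConfig(config: SamplingConfig): Required<SamplingConfig> {
  return {
    temperature: config.temperature ?? 1.0,
    // top_p = 1.0 keeps the whole distribution, so an unspecified top_p is
    // equivalent to "no nucleus truncation", and the branch condition
    // `logprobs && _hasValue(top_p)` can collapse back to `logprobs`.
    top_p: config.top_p ?? 1.0,
    logprobs: config.logprobs ?? false,
  };
}
```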