diff --git a/src/llm_chat.ts b/src/llm_chat.ts
index 056cd847..f9b1de27 100644
--- a/src/llm_chat.ts
+++ b/src/llm_chat.ts
@@ -53,6 +53,8 @@ export class LLMChatPipeline {
   private fapplyPenalty: tvmjs.PackedFunc;
   private fapplyLogitBias: tvmjs.PackedFunc;
   private fsoftmaxWithTemperature: tvmjs.PackedFunc;
+  private fsampleWithTopP: tvmjs.PackedFunc;
+  private fargsortProbs: tvmjs.PackedFunc;
 
   // Functions related to PagedKVCache
   private fclearKVCaches: tvmjs.PackedFunc;
@@ -142,6 +144,10 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  // Instance variables for supporting sampling on WebGPU
+  private sampleIndices: Int32Array;
+  private sampleIndicesDevice: tvmjs.Tensor;
+  private topPDevice: tvmjs.Tensor;
 
   constructor(
     tvm: tvmjs.Instance,
@@ -213,6 +219,12 @@ export class LLMChatPipeline {
     this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
       this.vm.getFunction("softmax_with_temperature"),
     );
+    this.fsampleWithTopP = this.tvm.detachFromCurrentScope(
+      this.vm.getFunction("sample_with_top_p"),
+    );
+    this.fargsortProbs = this.tvm.detachFromCurrentScope(
+      this.vm.getFunction("argsort_probs"),
+    );
     try {
       this.image_embed = this.tvm.detachFromCurrentScope(
         this.vm.getFunction("image_embed"),
@@ -310,6 +322,25 @@ export class LLMChatPipeline {
     this.filledKVCacheLength = 0;
     this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence
 
+    // Initialize WebGPU sampling related device tensors
+    const numSamples = 1;
+    const numProbs = 1;
+
+    this.sampleIndices = new Int32Array(numSamples);
+    for (let i = 0; i < numSamples; i++) {
+      this.sampleIndices[i] = i;
+    }
+    this.sampleIndicesDevice = this.tvm.detachFromCurrentScope(
+      this.tvm
+        .empty([numSamples], "int32", this.device)
+        .copyFrom(this.sampleIndices),
+    );
+
+    this.topPDevice = this.tvm.detachFromCurrentScope(
+      this.tvm.empty([numProbs], "float32", this.device),
+    );
+
     tvm.endScope();
   }
@@ -1271,11 +1302,13 @@
     // If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
     const sampleBegin = performance.now();
     let sampledToken: number;
-    if (logprobs) {
+    let sampledTokensDevice: tvmjs.Tensor;
+    if (logprobs && _hasValue(top_p)) {
       // Inplace transform logitsOnCPU to a distribution
       temperature = Math.max(1e-6, temperature); // to prevent division by zero
       const numSeqs = 1;
+      const numProbs = 1;
 
       const temperatures = new Float32Array([temperature]);
 
      const temperaturesDevice = this.tvm
@@ -1284,18 +1317,52 @@
         .empty([numSeqs], "float32", this.device)
         .copyFrom(temperatures);
 
-      const probs = this.fsoftmaxWithTemperature(
-        logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
+      let probs = this.fsoftmaxWithTemperature(
+        logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]),
         temperaturesDevice,
       );
-      this.updateLogitsOnCPU(probs);
+      probs = probs.view([numProbs, this.fullVocabSize]);
+
+      const argsortResults = this.fargsortProbs(probs);
+      const sortedProbsDevice = argsortResults.get(0);
+      const sortedIndicesDevice = argsortResults.get(1);
+
+      const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device);
+
+      const topPHost = new Float32Array(numProbs).fill(-1);
+      const topPValue = Math.max(top_p, 1e-5);
+      this.sampleIndices.forEach((row) => {
+        topPHost[row] = topPValue;
+      });
+      this.topPDevice.copyFrom(topPHost);
+
+      sampledTokensDevice = this.tvm.detachFromCurrentScope(
+        this.fsampleWithTopP(
+          sortedProbsDevice,
+          sortedIndicesDevice,
+          uniformSamplesDevice,
+          this.sampleIndicesDevice,
+          this.topPDevice,
+        ),
+      );
+      const sampledTokensHost = this.tvm.detachFromCurrentScope(
+        this.tvm
+          .empty([numSeqs], "int32", this.tvm.cpu())
+          .copyFrom(sampledTokensDevice),
+      );
+      if (top_logprobs! > 0) {
+        this.updateLogitsOnCPU(probs);
+      }
       this.tvm.endScope();
       await this.device.sync();
 
-      sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p);
-      this.tokenLogprobArray.push(
-        this.getTokenLogprob(sampledToken, top_logprobs!),
-      );
+      sampledToken = sampledTokensHost.toArray()[0];
+
+      if (top_logprobs! > 0) {
+        this.tokenLogprobArray.push(
+          this.getTokenLogprob(sampledToken, top_logprobs!),
+        );
+      }
     } else {
       // temperature being 0 is allowed here, equivalent to argmax
       this.tvm.beginScope();
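
Note on the device-side sampler: `sample_with_top_p` consumes the probabilities and token indices produced by `argsort_probs` (sorted in descending probability order), one uniform draw per sequence from `tvm.uniform`, and the per-row top-p values staged in `topPDevice`. A CPU-side sketch of the semantics this kernel is expected to implement is below; the helper and its signature are illustrative only, not part of the TVM module:

// Illustrative CPU reference for what the GPU-side `sample_with_top_p`
// kernel computes; this helper is hypothetical and not part of llm_chat.ts.
function sampleWithTopPReference(
  sortedProbs: Float32Array, // probabilities sorted in descending order
  sortedIndices: Int32Array, // original token ids, aligned with sortedProbs
  uniformSample: number, // one draw from U(0, 1), cf. this.tvm.uniform
  topP: number, // nucleus mass, already clamped to >= 1e-5 above
): number {
  // Find the nucleus: the smallest prefix whose cumulative mass reaches topP.
  let cutoff = sortedProbs.length;
  let mass = 0;
  for (let i = 0; i < sortedProbs.length; i++) {
    mass += sortedProbs[i];
    if (mass >= topP) {
      cutoff = i + 1;
      break;
    }
  }
  // Invert the CDF of the renormalized nucleus at the uniform draw.
  const target = uniformSample * mass;
  let acc = 0;
  for (let i = 0; i < cutoff; i++) {
    acc += sortedProbs[i];
    if (acc >= target) {
      return sortedIndices[i];
    }
  }
  return sortedIndices[cutoff - 1]; // guard against floating-point round-off
}

Keeping the argsort, the cutoff search, and the CDF inversion on the GPU means only the single sampled token id (an `int32` per sequence) crosses the device-to-host boundary, rather than the full `fullVocabSize`-length distribution that the old `this.tvm.sampleTopPFromProb` path copied back via `updateLogitsOnCPU`; that copy is now performed only when `top_logprobs` is requested.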