83 changes: 75 additions & 8 deletions src/llm_chat.ts
@@ -53,6 +53,8 @@ export class LLMChatPipeline {
private fapplyPenalty: tvmjs.PackedFunc;
private fapplyLogitBias: tvmjs.PackedFunc;
private fsoftmaxWithTemperature: tvmjs.PackedFunc;
private fsampleWithTopP: tvmjs.PackedFunc;
private fargsortProbs: tvmjs.PackedFunc;

// Functions related to PagedKVCache
private fclearKVCaches: tvmjs.PackedFunc;
@@ -142,6 +144,10 @@ export class LLMChatPipeline {
private curRoundGrammarInitTotalTime = 0;
// Total time of getting next bitmask and accepting token in seconds
private curRoundGrammarPerTokenTotalTime = 0;
// Instance variables for supporting sampling on WebGPU
private sampleIndices: Int32Array;
private sampleIndicesDevice: tvmjs.Tensor;
private topPDevice: tvmjs.Tensor;

constructor(
tvm: tvmjs.Instance,
@@ -213,6 +219,12 @@ export class LLMChatPipeline {
this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
this.vm.getFunction("softmax_with_temperature"),
);
this.fsampleWithTopP = this.tvm.detachFromCurrentScope(
this.vm.getFunction("sample_with_top_p"),
);
this.fargsortProbs = this.tvm.detachFromCurrentScope(
this.vm.getFunction("argsort_probs"),
);
try {
this.image_embed = this.tvm.detachFromCurrentScope(
this.vm.getFunction("image_embed"),
@@ -310,6 +322,25 @@ export class LLMChatPipeline {

this.filledKVCacheLength = 0;
this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence

// Initialize WebGPU sampling related device tensors
const numSamples = 1;
const numProbs = 1;

this.sampleIndices = new Int32Array(numSamples);
for (let i = 0; i < numSamples; i++) {
this.sampleIndices[i] = i;
}
this.sampleIndicesDevice = this.tvm.detachFromCurrentScope(
this.tvm
.empty([numSamples], "int32", this.device)
.copyFrom(this.sampleIndices),
);

this.topPDevice = this.tvm.detachFromCurrentScope(
this.tvm.empty([numProbs], "float32", this.device),
);

tvm.endScope();
}

@@ -1271,11 +1302,13 @@ export class LLMChatPipeline {
// If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
const sampleBegin = performance.now();
let sampledToken: number;
if (logprobs) {
let sampledTokensDevice: tvmjs.Tensor;
if (logprobs && _hasValue(top_p)) {
Member commented:

Could you remind me why we add a _hasValue(top_p) here? If a user wants logprobs but does not provide a top_p, it would go to the else branch, and thus the tokenLogprobArray would not be populated.

Let's set top_p to 1.0 -- the default value -- at the start, when we are pre-processing the sampling parameters. Then we can remove this condition change.
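A minimal sketch of that suggestion, assuming a hypothetical normalizeSamplingParams helper run while the sampling parameters are pre-processed (the interface and field names below are illustrative, not from this PR):

```ts
// Sketch only: normalize sampling parameters up front so top_p always has a value,
// with 1.0 meaning "no nucleus truncation". With this in place, the branch below
// can go back to checking only `logprobs`.
interface SamplingParams {
  temperature?: number;
  top_p?: number;
  logprobs?: boolean;
  top_logprobs?: number;
}

function normalizeSamplingParams(params: SamplingParams): Required<SamplingParams> {
  return {
    temperature: params.temperature ?? 1.0,
    top_p: params.top_p ?? 1.0, // reviewer-suggested default: 1.0 keeps the full distribution
    logprobs: params.logprobs ?? false,
    top_logprobs: params.top_logprobs ?? 0,
  };
}
```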

// Inplace transform logitsOnCPU to a distribution
temperature = Math.max(1e-6, temperature); // to prevent division by zero

const numSeqs = 1;
const numProbs = 1;

const temperatures = new Float32Array([temperature]);

@@ -1284,18 +1317,52 @@
.empty([numSeqs], "float32", this.device)
.copyFrom(temperatures);

const probs = this.fsoftmaxWithTemperature(
logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
let probs = this.fsoftmaxWithTemperature(
logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]),
temperaturesDevice,
);
this.updateLogitsOnCPU(probs);
probs = probs.view([numProbs, this.fullVocabSize]);

const argsortResults = this.fargsortProbs(probs);
const sortedProbsDevice = argsortResults.get(0);
const sortedIndicesDevice = argsortResults.get(1);

const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device);

const topPHost = new Float32Array(numProbs).fill(-1);
const topPValue = Math.max(top_p, 1e-5);
this.sampleIndices.forEach((row) => {
topPHost[row] = topPValue;
});
this.topPDevice.copyFrom(topPHost);

sampledTokensDevice = this.tvm.detachFromCurrentScope(
this.fsampleWithTopP(
sortedProbsDevice,
sortedIndicesDevice,
uniformSamplesDevice,
this.sampleIndicesDevice,
this.topPDevice,
),
);
const sampledTokensHost = this.tvm.detachFromCurrentScope(
this.tvm
.empty([numSeqs], "int32", this.tvm.cpu())
.copyFrom(sampledTokensDevice),
);
if (top_logprobs! > 0) {
this.updateLogitsOnCPU(probs);
}
this.tvm.endScope();
await this.device.sync();

sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p);
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);
sampledToken = sampledTokensHost.toArray()[0];

if (top_logprobs! > 0) {
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);
}
} else {
Member commented:

Quick question: why can't we use the GPU sample kernel when logprobs is False?

@akaashrp (Contributor, Author) replied on Nov 10, 2025:

IIRC, the flow for the logprobs-False case involves invoking _attach_multinomial_sampling_func / parallel_sampling_from_prob, which contains i8s that are not yet supported by WGSL / WebGPU. I experimented with enabling some experimental flags at the beginning of the relevant kernels, but I wasn't able to get these to work. One thing I haven't tried yet, though, is replacing the int8s with some other supported datatype in line 131 here: https://github.com/apache/tvm/blob/26db8bfd7e527198f43f3cc379f404c7513a82ef/python/tvm/relax/backend/gpu_generic/sampling.py#L131C1-L132C1.

@CharlieFRuan (Member) replied on Nov 11, 2025:

I see. Ideally we could modify those kernels in TVM so they don't use i8s when the backend is WebGPU.

Let's leave a TODO at the start of this else { and somewhere in this PR's description.

I suppose the else branch is the more canonical codepath, since local deployment rarely uses logprobs.

But this PR is great!
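A possible wording for that TODO, placed at the start of the else branch below; the comment text is a sketch based on this thread, not part of the PR:

```ts
// TODO(WebGPU sampling): this branch does not use the new GPU sampling kernels because
// the kernels the non-logprobs flow would invoke (_attach_multinomial_sampling_func /
// parallel_sampling_from_prob in TVM) rely on int8 buffers, which WGSL / WebGPU does
// not yet support. Route this case through the GPU sampler once those kernels avoid
// int8 on the WebGPU backend.
```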

// temperature being 0 is allowed here, equivalent to argmax
this.tvm.beginScope();