diff --git a/.gitignore b/.gitignore
index 8de96b40..adc767f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -324,5 +324,4 @@ node_modules
lib
.parcel-cache
-examples/tests
**/.next
\ No newline at end of file
diff --git a/examples/get-started-latency-breakdown/README.md b/examples/get-started-latency-breakdown/README.md
new file mode 100644
index 00000000..2a8f6967
--- /dev/null
+++ b/examples/get-started-latency-breakdown/README.md
@@ -0,0 +1,31 @@
+# WebLLM Get Started App with Latency Breakdown
+
+This folder provides a minimal demo that shows the WebLLM API in a webapp setting and
+collects latency statistics for individual token sampling steps.
+To try it out, run the following steps under this folder:
+
+```bash
+npm install
+npm start
+```
+
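+The latency breakdown itself is requested per call through `extra_body` and read back
+from `usage.extra.latencyBreakdown`. A minimal sketch (model id and prompt follow the
+demo in `src/get_started_latency_breakdown.ts`; run it inside an async function):
+
+```typescript
+import * as webllm from "@mlc-ai/web-llm";
+
+const engine = await webllm.CreateMLCEngine("Qwen3-0.6B-q0f32-MLC");
+const reply = await engine.chat.completions.create({
+  messages: [{ role: "user", content: "List twenty US states." }],
+  extra_body: { enable_latency_breakdown: true },
+});
+// Arrays of per-token timings in seconds: sampleTime, penaltyTime, totalTime, ...
+console.log(reply.usage?.extra.latencyBreakdown);
+```
+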
+Note: if you would like to hack on the WebLLM core package, you can change the
+web-llm dependency to `"file:../.."` and follow the build-from-source
+instructions in the project to build WebLLM locally. This option is only recommended
+if you intend to modify the WebLLM core package.
diff --git a/examples/get-started-latency-breakdown/package.json b/examples/get-started-latency-breakdown/package.json
new file mode 100644
index 00000000..0b321e9d
--- /dev/null
+++ b/examples/get-started-latency-breakdown/package.json
@@ -0,0 +1,20 @@
+{
+ "name": "get-started-latency-breakdown",
+ "version": "0.1.0",
+ "private": true,
+ "scripts": {
+ "start": "parcel src/get_started_latency_breakdown.html --port 8888",
+ "build": "parcel build src/get_started_latency_breakdown.html --dist-dir lib"
+ },
+ "devDependencies": {
+ "buffer": "^5.7.1",
+ "parcel": "^2.8.3",
+ "process": "^0.11.10",
+ "tslib": "^2.3.1",
+ "typescript": "^4.9.5",
+ "url": "^0.11.3"
+ },
+ "dependencies": {
+ "@mlc-ai/web-llm": "^0.2.79"
+ }
+}
diff --git a/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.html b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.html
new file mode 100644
index 00000000..18298616
--- /dev/null
+++ b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.html
@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./get_started_latency_breakdown.ts"></script>
+  </body>
+</html>
diff --git a/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.ts b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.ts
new file mode 100644
index 00000000..104af5da
--- /dev/null
+++ b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.ts
@@ -0,0 +1,138 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+function setLabel(id: string, text: string) {
+ const label = document.getElementById(id);
+ if (label == null) {
+ throw Error("Cannot find label " + id);
+ }
+ label.innerText = text;
+}
+
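+// Per-stage timing samples (in seconds), one entry per sampled token where the stage runs.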
+type LatencyBreakdown = {
+ logitProcessorTime: number[];
+ logitBiasTime: number[];
+ penaltyTime: number[];
+ sampleTime: number[];
+ totalTime: number[];
+ grammarBitmaskTime: number[];
+};
+function computeStats(
+ latency_breakdown: LatencyBreakdown,
+): Record<string, any> {
+ function _computeStats(arr: number[]) {
+ if (!arr.length) return undefined;
+ const sorted = [...arr].sort((a, b) => a - b);
+ const sum = arr.reduce((a, b) => a + b, 0);
+ const avg = sum / arr.length;
+ const min = sorted[0];
+ const max = sorted[sorted.length - 1];
+ const p99 = sorted[Math.floor(0.99 * (sorted.length - 1))];
+ return { avg, min, max, p99 };
+ }
+
+ const latencyStats: Record<string, any> = {};
+ for (const key of Object.keys(latency_breakdown)) {
+ const arr = (latency_breakdown as any)[key];
+ if (Array.isArray(arr) && arr.length > 0) {
+ latencyStats[key] = _computeStats(arr);
+ }
+ }
+ return latencyStats;
+}
+
+async function main() {
+ const initProgressCallback = (report: webllm.InitProgressReport) => {
+ setLabel("init-label", report.text);
+ };
+ // Option 1: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
+ const selectedModel = "Qwen3-0.6B-q0f32-MLC";
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ selectedModel,
+ {
+ initProgressCallback: initProgressCallback,
+ logLevel: "INFO", // specify the log level
+ },
+ // customize kv cache, use either context_window_size or sliding_window_size (with attention sink)
+ {
+ context_window_size: 2048,
+ // sliding_window_size: 1024,
+ // attention_sink_size: 4,
+ },
+ );
+
+ const latencyBreakdown: LatencyBreakdown = {
+ logitProcessorTime: [],
+ logitBiasTime: [],
+ penaltyTime: [],
+ sampleTime: [],
+ totalTime: [],
+ grammarBitmaskTime: [],
+ };
+
+ const decodeTokensPerS: number[] = [];
+ const completionTokens: number[] = [];
+ const e2eLatencyS: number[] = [];
+ const timePerOutputTokenS: number[] = [];
+
+ const numTrials = 20;
+ for (let i = 0; i < numTrials; i++) {
+ console.log(`Trial ${i + 1} / ${numTrials}`);
+ const reply0 = await engine.chat.completions.create({
+ messages: [{ role: "user", content: "List twenty US states." }],
+ // below configurations are all optional
+ n: 1,
+ temperature: 0,
+ max_tokens: 2048,
+ // 46510 and 7188 are "California", and 8421 and 51325 are "Texas" in Llama-3.1-8B-Instruct
+ // So we would have a higher chance of seeing the latter two, but never the first in the answer
+ // logit_bias: {
+ // "46510": -100,
+ // "7188": -100,
+ // "8421": 5,
+ // "41325": 5,
+ // },
+ top_p: 0.8,
+ logprobs: true,
+ top_logprobs: 2,
+ frequency_penalty: 1.2,
+ presence_penalty: 1.0,
+ repetition_penalty: 1.1,
+ // Ask the engine to record per-stage sampling latencies in usage.extra.latencyBreakdown
+ extra_body: { enable_latency_breakdown: true },
+ });
+
+ const logitProcessorTime =
+ reply0.usage?.extra.latencyBreakdown?.logitProcessorTime;
+ const logitBiasTime = reply0.usage?.extra.latencyBreakdown?.logitBiasTime;
+ const penaltyTime = reply0.usage?.extra.latencyBreakdown?.penaltyTime;
+ const sampleTime = reply0.usage?.extra.latencyBreakdown?.sampleTime;
+ const totalTime = reply0.usage?.extra.latencyBreakdown?.totalTime;
+ const grammarBitmaskTime =
+ reply0.usage?.extra.latencyBreakdown?.grammarBitmaskTime;
+
+ latencyBreakdown.logitProcessorTime.push(...(logitProcessorTime || []));
+ latencyBreakdown.logitBiasTime.push(...(logitBiasTime || []));
+ latencyBreakdown.penaltyTime.push(...(penaltyTime || []));
+ latencyBreakdown.sampleTime.push(...(sampleTime || []));
+ latencyBreakdown.totalTime.push(...(totalTime || []));
+ latencyBreakdown.grammarBitmaskTime.push(...(grammarBitmaskTime || []));
+
+ decodeTokensPerS.push(reply0.usage?.extra.decode_tokens_per_s || 0);
+ e2eLatencyS.push(reply0.usage?.extra.e2e_latency_s || 0);
+ timePerOutputTokenS.push(reply0.usage?.extra.time_per_output_token_s || 0);
+ completionTokens.push(reply0.usage?.completion_tokens || 0);
+ }
+
+ const latencyStats: Record<string, any> =
+ computeStats(latencyBreakdown);
+ console.log("Latency stats: ", latencyStats);
+ console.log("Decode tokens per second: ", decodeTokensPerS);
+ console.log("Completion tokens: ", completionTokens);
+ console.log("E2E latency (s): ", e2eLatencyS);
+ console.log("Time per output token (s): ", timePerOutputTokenS);
+
+ // To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
+}
+
+main();
diff --git a/src/config.ts b/src/config.ts
index dfeb1913..0ca51eba 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -126,7 +126,7 @@ export interface MLCEngineConfig {
*/
export interface GenerationConfig {
// Only used in MLC
- repetition_penalty?: number;
+ repetition_penalty?: number | null;
ignore_eos?: boolean;
// Shared by MLC and OpenAI APIs
top_p?: number | null;
@@ -143,6 +143,7 @@ export interface GenerationConfig {
response_format?: ResponseFormat | null;
// extra_body in ChatCompletionsRequest
enable_thinking?: boolean | null;
+ enable_latency_breakdown?: boolean | null;
}
export function postInitAndCheckGenerationConfigValues(
diff --git a/src/engine.ts b/src/engine.ts
index fe4c427e..6609f3e5 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -41,6 +41,7 @@ import {
MLCEngineInterface,
LogitProcessor,
LogLevel,
+ LatencyBreakdown,
} from "./types";
import {
compareConversationObject,
@@ -694,12 +695,18 @@ export class MLCEngine implements MLCEngineInterface {
const decode_time = pipeline.getCurRoundDecodingTotalTime();
const grammar_per_token_s =
pipeline.getCurRoundGrammarPerTokenTotalTime();
+ const latencyBreakdown: LatencyBreakdown =
+ pipeline.getCurRoundLatencyBreakdown();
+
const defaultExtra = {
e2e_latency_s: (Date.now() - timeReceived) / 1000,
prefill_tokens_per_s: prefill_tokens_per_s,
decode_tokens_per_s: decode_tokens_per_s,
time_to_first_token_s: prefill_time,
time_per_output_token_s: decode_time / completion_tokens,
+ latencyBreakdown: request.extra_body?.enable_latency_breakdown
+ ? latencyBreakdown
+ : undefined,
};
const usage: CompletionUsage = {
completion_tokens: completion_tokens,
@@ -783,6 +790,7 @@ export class MLCEngine implements MLCEngineInterface {
const genConfig: GenerationConfig = {
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
+ repetition_penalty: request.repetition_penalty,
max_tokens: request.max_tokens,
stop: request.stop,
top_p: request.top_p,
@@ -793,6 +801,7 @@ export class MLCEngine implements MLCEngineInterface {
response_format: request.response_format,
ignore_eos: request.ignore_eos,
enable_thinking: request.extra_body?.enable_thinking,
+ enable_latency_breakdown: request.extra_body?.enable_latency_breakdown,
};
// 0.5 Block wait until this pipeline finishes all previous requests
@@ -890,12 +899,19 @@ export class MLCEngine implements MLCEngineInterface {
"response_format" in request &&
(request.response_format?.type === "grammar" ||
request.response_format?.type === "json_object");
+
+ const latencyBreakdown: LatencyBreakdown =
+ selectedPipeline.getCurRoundLatencyBreakdown();
+
const defaultExtra = {
e2e_latency_s: (Date.now() - timeReceived) / 1000,
prefill_tokens_per_s: prompt_tokens / prefill_time,
decode_tokens_per_s: completion_tokens / decode_time,
time_to_first_token_s: prefill_time,
time_per_output_token_s: decode_time / completion_tokens,
+ latencyBreakdown: request.extra_body?.enable_latency_breakdown
+ ? latencyBreakdown
+ : undefined,
};
const response: ChatCompletion = {
id: crypto.randomUUID(),
@@ -958,6 +974,7 @@ export class MLCEngine implements MLCEngineInterface {
const genConfig: GenerationConfig = {
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
+ repetition_penalty: request.repetition_penalty,
max_tokens: request.max_tokens,
stop: request.stop,
top_p: request.top_p,
@@ -1030,6 +1047,9 @@ export class MLCEngine implements MLCEngineInterface {
decode_time += selectedPipeline.getCurRoundDecodingTotalTime();
}
+ const latencyBreakdown: LatencyBreakdown =
+ selectedPipeline.getCurRoundLatencyBreakdown();
+
const response: Completion = {
id: crypto.randomUUID(),
choices: choices,
@@ -1046,6 +1066,9 @@ export class MLCEngine implements MLCEngineInterface {
decode_tokens_per_s: completion_tokens / decode_time,
time_to_first_token_s: prefill_time,
time_per_output_token_s: decode_time / completion_tokens,
+ latencyBreakdown: request.extra_body?.enable_latency_breakdown
+ ? latencyBreakdown
+ : undefined,
},
} as CompletionUsage,
};
diff --git a/src/llm_chat.ts b/src/llm_chat.ts
index 5f8ecf00..8f7620a8 100644
--- a/src/llm_chat.ts
+++ b/src/llm_chat.ts
@@ -6,7 +6,7 @@ import log from "loglevel";
import { Tokenizer } from "@mlc-ai/web-tokenizers";
import { ChatConfig, GenerationConfig, Role } from "./config";
import { getConversation, Conversation } from "./conversation";
-import { LogitProcessor } from "./types";
+import { LogitProcessor, LatencyBreakdown } from "./types";
import {
getChunkedPrefillInputData,
getImageDataFromURL,
@@ -50,6 +50,10 @@ export class LLMChatPipeline {
private image_embed: tvmjs.PackedFunc | undefined;
private embed: tvmjs.PackedFunc;
private fapplyBitmask: tvmjs.PackedFunc;
+ private fapplyPenalty: tvmjs.PackedFunc;
+ private fapplyLogitBias: tvmjs.PackedFunc;
+ private fsoftmaxWithTemperature: tvmjs.PackedFunc;
+
// Functions related to PagedKVCache
private fclearKVCaches: tvmjs.PackedFunc;
private fKVCacheAddSequence: tvmjs.PackedFunc;
@@ -98,6 +102,16 @@ export class LLMChatPipeline {
private curRoundDecodingTotalTime = 0;
private curRoundPrefillTotalTime = 0;
+ // additional stats, reset at every prefillStep()
+ public curRoundLatencyBreakdown: LatencyBreakdown = {
+ logitProcessorTime: [],
+ logitBiasTime: [],
+ penaltyTime: [],
+ sampleTime: [],
+ totalTime: [],
+ grammarBitmaskTime: [],
+ };
+
// LogitProcessor
private logitProcessor?: LogitProcessor = undefined;
@@ -190,6 +204,15 @@ export class LLMChatPipeline {
this.fapplyBitmask = this.tvm.detachFromCurrentScope(
this.vm.getFunction("apply_bitmask_inplace"),
);
+ this.fapplyPenalty = this.tvm.detachFromCurrentScope(
+ this.vm.getFunction("apply_penalty_inplace"),
+ );
+ this.fapplyLogitBias = this.tvm.detachFromCurrentScope(
+ this.vm.getFunction("apply_logit_bias_inplace"),
+ );
+ this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
+ this.vm.getFunction("softmax_with_temperature"),
+ );
try {
this.image_embed = this.tvm.detachFromCurrentScope(
this.vm.getFunction("image_embed"),
@@ -421,6 +444,13 @@ export class LLMChatPipeline {
return this.curRoundGrammarPerTokenTotalTime;
}
+ /**
+ * @returns the breakdown of latencies for sampling each token for a single request.
+ */
+ getCurRoundLatencyBreakdown(): LatencyBreakdown {
+ return this.curRoundLatencyBreakdown;
+ }
+
/**
* @returns Runtime stats information.
*/
@@ -514,6 +544,16 @@ export class LLMChatPipeline {
this.curRoundDecodingTotalTime = 0;
this.curRoundGrammarInitTotalTime = 0;
this.curRoundGrammarPerTokenTotalTime = 0;
+
+ this.curRoundLatencyBreakdown = {
+ logitProcessorTime: [],
+ logitBiasTime: [],
+ penaltyTime: [],
+ sampleTime: [],
+ totalTime: [],
+ grammarBitmaskTime: [],
+ };
+
this.stopTriggered = false;
const conversation = this.conversation;
@@ -1036,11 +1076,15 @@ export class LLMChatPipeline {
throw new RangeError("presence_penalty", -2.0, 2.0);
}
+ const outputTokenBegin = performance.now();
+
// 0. Update logitsOnGPU with on-GPU grammar bitmasking
if (
response_format?.type === "json_object" ||
response_format?.type === "grammar"
) {
+ const grammarBitmaskBegin = performance.now();
+
this.tvm.beginScope();
if (this.grammarMatcher === undefined) {
throw Error("Expect grammar matcher to be initialized.");
@@ -1070,110 +1114,209 @@
bitMaskOnGPU,
);
this.tvm.endScope();
+
+ if (genConfig?.enable_latency_breakdown) {
+ const grammarBitmaskEnd = performance.now();
+ const grammarBitmaskTimeSpent =
+ (grammarBitmaskEnd - grammarBitmaskBegin) / 1e3;
+ this.curRoundLatencyBreakdown.grammarBitmaskTime.push(
+ grammarBitmaskTimeSpent,
+ );
+ }
}
- // 1. Move logits to CPU
- this.tvm.beginScope();
- this.updateLogitsOnCPU(logitsOnGPU);
- this.tvm.endScope();
- await this.device.sync();
+ // 1. Apply logitProcessor on CPU
+ if (this.logitProcessor !== undefined) {
+ // Move logits to CPU
+ this.tvm.beginScope();
+ this.updateLogitsOnCPU(logitsOnGPU);
+ this.tvm.endScope();
+ await this.device.sync();
- if (this.logitsOnCPU == undefined) {
- throw Error("logits should be assigned");
- }
+ const logitProcessorBegin = performance.now();
- // 2. Post process logits via logitProcessor and/or logit_bias
- if (this.logitProcessor !== undefined || _hasValue(logit_bias)) {
+ if (this.logitsOnCPU == undefined) {
+ throw Error("logits should be assigned");
+ }
let logitsOnCPUArray: Float32Array = <Float32Array>(
this.logitsOnCPU.toArray()
);
- const vocab_size = logitsOnCPUArray.length;
- if (this.logitProcessor !== undefined) {
- logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray);
- }
- if (_hasValue(logit_bias)) {
- for (const tokenID in logit_bias) {
- const curBias = logit_bias[tokenID];
- const curTokenID = parseInt(tokenID);
- if (curTokenID > vocab_size) {
- throw Error(
- "Token " +
- curTokenID +
- " in logit_bias exceeds vocab_size " +
- vocab_size,
- );
- }
- logitsOnCPUArray[curTokenID] += curBias;
- }
- }
+ logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray);
+ logitsOnGPU.copyFrom(logitsOnCPUArray);
this.logitsOnCPU.copyFrom(logitsOnCPUArray);
+
+ if (genConfig?.enable_latency_breakdown) {
+ const logitProcessorEnd = performance.now();
+ const logitProcessorTimeSpent =
+ (logitProcessorEnd - logitProcessorBegin) / 1e3;
+ this.curRoundLatencyBreakdown.logitProcessorTime.push(
+ logitProcessorTimeSpent,
+ );
+ }
}
- // 3. Apply penalties to logits
- if (_hasValue(frequency_penalty) && _hasValue(presence_penalty)) {
- // 3.1. Use frequency and presence penalty
+ // 2. Apply logit_bias on GPU
+ if (_hasValue(logit_bias)) {
+ const logitBiasBegin = performance.now();
+
+ const numTokens = Object.keys(logit_bias ?? {}).length;
+ const pos2seq_id = new Int32Array(numTokens).fill(0);
+ const tokenIds = new Int32Array(numTokens);
+ const tokenLogitBias = new Float32Array(numTokens);
+
+ const logitBiasKeys = Object.keys(logit_bias ?? {});
+ for (let index = 0; index < numTokens; index++) {
+ const tokenId = parseInt(logitBiasKeys[index]);
+ tokenIds[index] = tokenId;
+ tokenLogitBias[index] = logit_bias![tokenId];
+ }
+
this.tvm.beginScope();
- // Both `keys()` and `values()` are in insertion order.
- const appearedTokens = [...this.appearedTokensFreq.keys()];
- const appearedTokensFreqs = [...this.appearedTokensFreq.values()];
- const appeared_tokens_ndarray = this.tvm.empty(
- [1, appearedTokens.length],
- "int32",
- this.tvm.cpu(),
- );
- const appeared_tokens_freqs_ndarray = this.tvm.empty(
- [1, appearedTokensFreqs.length],
- "int32",
- this.tvm.cpu(),
- );
- appeared_tokens_ndarray.copyFrom(appearedTokens);
- appeared_tokens_freqs_ndarray.copyFrom(appearedTokensFreqs);
- this.tvm.applyPresenceAndFrequencyPenalty(
- this.logitsOnCPU,
- appeared_tokens_ndarray,
- appeared_tokens_freqs_ndarray,
- presence_penalty!,
- frequency_penalty!,
+
+ const pos2seqIdsArray = this.tvm
+ .empty([numTokens], "int32", this.device)
+ .copyFrom(pos2seq_id);
+
+ const tokenIdsArray = this.tvm
+ .empty([numTokens], "int32", this.device)
+ .copyFrom(tokenIds);
+
+ const tokenLogitBiasArray = this.tvm
+ .empty([numTokens], "float32", this.device)
+ .copyFrom(tokenLogitBias);
+
+ this.fapplyLogitBias(
+ logitsOnGPU.view([1, this.fullVocabSize]),
+ pos2seqIdsArray,
+ tokenIdsArray,
+ tokenLogitBiasArray,
);
+
this.tvm.endScope();
- } else if (repetition_penalty != 1.0) {
- // 3.2. Use repetition penalty
- this.tvm.beginScope();
+
+ if (genConfig?.enable_latency_breakdown) {
+ const logitBiasEnd = performance.now();
+ const logitBiasTimeSpent = (logitBiasEnd - logitBiasBegin) / 1e3;
+ this.curRoundLatencyBreakdown.logitBiasTime.push(logitBiasTimeSpent);
+ }
+ }
+
+ // 3. Apply penalties to logits on GPU
+ if (
+ frequency_penalty != 0.0 ||
+ presence_penalty != 0.0 ||
+ repetition_penalty != 1.0
+ ) {
const appearedTokens = [...this.appearedTokensFreq.keys()];
- const appeared_tokens_ndarray = this.tvm.empty(
- [1, appearedTokens.length],
- "int32",
- this.tvm.cpu(),
- );
- appeared_tokens_ndarray.copyFrom(appearedTokens);
- this.tvm.applyRepetitionPenalty(
- this.logitsOnCPU,
- appeared_tokens_ndarray,
- repetition_penalty,
- );
- this.tvm.endScope();
+ const appearedTokensFreqs = [...this.appearedTokensFreq.values()];
+
+ const numTokens = appearedTokens.length;
+
+ if (numTokens > 0) {
+ const penaltyBegin = performance.now();
+
+ const pos2seq_id = new Int32Array(numTokens).fill(0);
+ const tokenIds = new Int32Array(numTokens).fill(0);
+ const tokenCnt = new Int32Array(numTokens).fill(0);
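+ // Penalties are packed as [presence, frequency, repetition] for the apply_penalty_inplace kernel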
+ const penalties = new Float32Array([
+ presence_penalty,
+ frequency_penalty,
+ repetition_penalty,
+ ]);
+
+ tokenIds.set(appearedTokens);
+ tokenCnt.set(appearedTokensFreqs);
+
+ this.tvm.beginScope();
+ const seqIdsArray = this.tvm
+ .empty([1], "int32", this.device)
+ .copyFrom([0]);
+
+ const pos2seqIdsArray = this.tvm
+ .empty([numTokens], "int32", this.device)
+ .copyFrom(pos2seq_id);
+
+ const tokenIdsArray = this.tvm
+ .empty([numTokens], "int32", this.device)
+ .copyFrom(tokenIds);
+
+ const tokenCntArray = this.tvm
+ .empty([numTokens], "int32", this.device)
+ .copyFrom(tokenCnt);
+
+ const penaltiesArray = this.tvm
+ .empty([1, 3], "float32", this.device)
+ .copyFrom(penalties);
+
+ this.fapplyPenalty(
+ logitsOnGPU.view([1, this.fullVocabSize]),
+ seqIdsArray,
+ pos2seqIdsArray,
+ tokenIdsArray,
+ tokenCntArray,
+ penaltiesArray,
+ );
+
+ this.tvm.endScope();
+
+ if (genConfig?.enable_latency_breakdown) {
+ const penaltyEnd = performance.now();
+ const penaltyTimeSpent = (penaltyEnd - penaltyBegin) / 1e3;
+ this.curRoundLatencyBreakdown.penaltyTime.push(penaltyTimeSpent);
+ }
+ }
}
// 4. Sample token from logits
// If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
+ const sampleBegin = performance.now();
let sampledToken: number;
if (logprobs) {
// Inplace transform logitsOnCPU to a distribution
temperature = Math.max(1e-6, temperature); // to prevent division by zero
- this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature);
- sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p);
+
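+ // Compute softmax with temperature on GPU, then copy the probabilities back to CPU for top-p sampling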
+ const numSeqs = 1;
+
+ const temperatures = new Float32Array([temperature]);
+
+ this.tvm.beginScope();
+ const temperaturesArray = this.tvm
+ .empty([numSeqs], "float32", this.device)
+ .copyFrom(temperatures);
+
+ const probs = this.fsoftmaxWithTemperature(
+ logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
+ temperaturesArray,
+ );
+ this.updateLogitsOnCPU(probs);
+ this.tvm.endScope();
+ await this.device.sync();
+
+ sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p);
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);
} else {
// temperature being 0 is allowed here, equivalent to argmax
+ this.tvm.beginScope();
+ this.updateLogitsOnCPU(logitsOnGPU);
+ this.tvm.endScope();
+ await this.device.sync();
sampledToken = this.tvm.sampleTopPFromLogits(
- this.logitsOnCPU,
+ this.logitsOnCPU!,
temperature,
top_p,
);
}
+ if (genConfig?.enable_latency_breakdown) {
+ const sampleEnd = performance.now();
+ const sampleTimeSpent = (sampleEnd - sampleBegin) / 1e3;
+ this.curRoundLatencyBreakdown.sampleTime.push(sampleTimeSpent);
+ }
+
// 5. Update logit processor
this.logitProcessor?.processSampledToken(sampledToken);
@@ -1194,6 +1337,12 @@
}
}
+ if (genConfig?.enable_latency_breakdown) {
+ const outputTokenEnd = performance.now();
+ const outputTokenTimeSpent = (outputTokenEnd - outputTokenBegin) / 1e3;
+ this.curRoundLatencyBreakdown.totalTime.push(outputTokenTimeSpent);
+ }
+
return sampledToken;
}
diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts
index 5a5229ed..e50352d6 100644
--- a/src/openai_api_protocols/chat_completion.ts
+++ b/src/openai_api_protocols/chat_completion.ts
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-import { MLCEngineInterface } from "../types";
+import { MLCEngineInterface, LatencyBreakdown } from "../types";
import {
functionCallingModelIds,
MessagePlaceholders,
@@ -125,6 +125,13 @@ export interface ChatCompletionRequestBase {
*/
presence_penalty?: number | null;
+ /**
+ * Penalizes new tokens based on whether they appear in the prompt and the
+ * generated text so far. Values greater than 1.0 encourage the model to use new
+ * tokens, while values less than 1.0 encourage the model to repeat tokens.
+ */
+ repetition_penalty?: number | null;
+
/**
* The maximum number of [tokens](/tokenizer) that can be generated in the chat
* completion.
@@ -268,6 +275,12 @@ export interface ChatCompletionRequestBase {
* @note Currently only allowed to be used for Qwen3 models, though not explicitly checked.
*/
enable_thinking?: boolean | null;
+
+ /**
+ * If set to true, the response's `usage.extra.latencyBreakdown` will contain a
+ * breakdown of the time spent in each stage of token sampling.
+ */
+ enable_latency_breakdown?: boolean | null;
};
}
@@ -980,6 +993,12 @@ export interface CompletionUsage {
* structured output. If n > 1, it is the average over all choices.
*/
grammar_per_token_s?: number;
+
+ /**
+ * If `enable_latency_breakdown` is set to true in the request, this field will be
+ * present and contain a breakdown of the time spent in various stages of token sampling.
+ */
+ latencyBreakdown?: LatencyBreakdown;
};
}
diff --git a/src/openai_api_protocols/completion.ts b/src/openai_api_protocols/completion.ts
index fb6aa458..0534fe93 100644
--- a/src/openai_api_protocols/completion.ts
+++ b/src/openai_api_protocols/completion.ts
@@ -137,6 +137,13 @@ export interface CompletionCreateParamsBase {
*/
presence_penalty?: number | null;
+ /**
+ * Penalizes new tokens based on whether they appear in the prompt and the
+ * generated text so far. Values greater than 1.0 encourage the model to use new
+ * tokens, while values less than 1.0 encourage the model to repeat tokens.
+ */
+ repetition_penalty?: number | null;
+
/**
* If specified, our system will make a best effort to sample deterministically,
* such that repeated requests with the same `seed` and parameters should return
@@ -225,6 +232,17 @@ export interface CompletionCreateParamsBase {
* @note This field is not supported.
*/
best_of?: number | null;
+
+ /**
+ * Fields specific to WebLLM, not present in OpenAI.
+ */
+ extra_body?: {
+ /**
+ * If set to true, the response's `usage.extra.latencyBreakdown` will contain a
+ * breakdown of the time spent in each stage of token sampling.
+ */
+ enable_latency_breakdown?: boolean | null;
+ };
}
export type CompletionCreateParams =
diff --git a/src/types.ts b/src/types.ts
index 4d4522c0..ed79af57 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -251,3 +251,16 @@
SILENT: 5,
};
export type LogLevel = keyof typeof LOG_LEVELS;
+
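+/**
+ * Per-stage timing samples (in seconds) collected while sampling each token when
+ * `enable_latency_breakdown` is set in a request's `extra_body`.
+ */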
+export type LatencyBreakdown = {
+ logitProcessorTime: number[];
+ logitBiasTime: number[];
+ penaltyTime: number[];
+ sampleTime: number[];
+ totalTime: number[];
+ grammarBitmaskTime: number[];
+};
diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644
index 00000000..d8b83df9
--- /dev/null
+++ b/tests/.gitignore
@@ -0,0 +1 @@
+package-lock.json
diff --git a/tests/scripts/sanity_checks/README.md b/tests/scripts/sanity_checks/README.md
new file mode 100644
index 00000000..8b65a916
--- /dev/null
+++ b/tests/scripts/sanity_checks/README.md
@@ -0,0 +1,14 @@
+# Sanity Checks for Generated Output
+
+This folder provides simple sanity checks on the output generated using WebLLM.
+To try it out, run the following steps under this folder:
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack on the WebLLM core package, you can change the
+web-llm dependency to `"file:../../.."` and follow the build-from-source
+instructions in the project to build WebLLM locally. This option is only recommended
+if you intend to modify the WebLLM core package.
diff --git a/tests/scripts/sanity_checks/package.json b/tests/scripts/sanity_checks/package.json
new file mode 100644
index 00000000..cf86fb3d
--- /dev/null
+++ b/tests/scripts/sanity_checks/package.json
@@ -0,0 +1,20 @@
+{
+ "name": "sanity_checks",
+ "version": "0.1.0",
+ "private": true,
+ "scripts": {
+ "start": "parcel sanity_checks.html --port 8889",
+ "build": "parcel build sanity_checks.html --dist-dir lib"
+ },
+ "devDependencies": {
+ "buffer": "^5.7.1",
+ "parcel": "^2.8.3",
+ "process": "^0.11.10",
+ "tslib": "^2.3.1",
+ "typescript": "^4.9.5",
+ "url": "^0.11.3"
+ },
+ "dependencies": {
+ "@mlc-ai/web-llm": "^0.2.79"
+ }
+}
diff --git a/tests/scripts/sanity_checks/sanity_checks.html b/tests/scripts/sanity_checks/sanity_checks.html
new file mode 100644
index 00000000..2b662f71
--- /dev/null
+++ b/tests/scripts/sanity_checks/sanity_checks.html
@@ -0,0 +1,49 @@
+<!doctype html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>GPU sampleTokenFromLogits Tests</h2>
+
+
+
+