diff --git a/.gitignore b/.gitignore index 8de96b40..adc767f8 100644 --- a/.gitignore +++ b/.gitignore @@ -324,5 +324,4 @@ node_modules lib .parcel-cache -examples/tests **/.next \ No newline at end of file diff --git a/examples/get-started-latency-breakdown/README.md b/examples/get-started-latency-breakdown/README.md new file mode 100644 index 00000000..2a8f6967 --- /dev/null +++ b/examples/get-started-latency-breakdown/README.md @@ -0,0 +1,15 @@ +# WebLLM Get Started App + +This folder provides a minimum demo to show WebLLM API in a webapp setting with +collection of latency statistics for individual token sampling steps. +To try it out, you can do the following steps under this folder + +```bash +npm install +npm start +``` + +Note if you would like to hack WebLLM core package. +You can change web-llm dependencies as `"file:../.."`, and follow the build from source +instruction in the project to build webllm locally. This option is only recommended +if you would like to hack WebLLM core package. diff --git a/examples/get-started-latency-breakdown/package.json b/examples/get-started-latency-breakdown/package.json new file mode 100644 index 00000000..0b321e9d --- /dev/null +++ b/examples/get-started-latency-breakdown/package.json @@ -0,0 +1,20 @@ +{ + "name": "get-started-latency-breakdown", + "version": "0.1.0", + "private": true, + "scripts": { + "start": "parcel src/get_started_latency_breakdown.html --port 8888", + "build": "parcel build src/get_started_latency_breakdown.html --dist-dir lib" + }, + "devDependencies": { + "buffer": "^5.7.1", + "parcel": "^2.8.3", + "process": "^0.11.10", + "tslib": "^2.3.1", + "typescript": "^4.9.5", + "url": "^0.11.3" + }, + "dependencies": { + "@mlc-ai/web-llm": "^0.2.79" + } +} diff --git a/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.html b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.html new file mode 100644 index 00000000..18298616 --- /dev/null +++ b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.html @@ -0,0 +1,23 @@ + + + + +

+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+ + + + + diff --git a/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.ts b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.ts new file mode 100644 index 00000000..104af5da --- /dev/null +++ b/examples/get-started-latency-breakdown/src/get_started_latency_breakdown.ts @@ -0,0 +1,135 @@ +import * as webllm from "@mlc-ai/web-llm"; + +function setLabel(id: string, text: string) { + const label = document.getElementById(id); + if (label == null) { + throw Error("Cannot find label " + id); + } + label.innerText = text; +} + +type LatencyBreakdown = { + logitProcessorTime: number[]; + logitBiasTime: number[]; + penaltyTime: number[]; + sampleTime: number[]; + totalTime: number[]; + grammarBitmaskTime: number[]; +}; +function computeStats( + latency_breakdown: LatencyBreakdown, +): Record { + function _computeStats(arr: number[]) { + if (!arr.length) return undefined; + const sorted = [...arr].sort((a, b) => a - b); + const sum = arr.reduce((a, b) => a + b, 0); + const avg = sum / arr.length; + const min = sorted[0]; + const max = sorted[sorted.length - 1]; + const p99 = sorted[Math.floor(0.99 * (sorted.length - 1))]; + return { avg, min, max, p99 }; + } + + const latencyStats: Record = {}; + for (const key of Object.keys(latency_breakdown)) { + const arr = (latency_breakdown as any)[key]; + if (Array.isArray(arr) && arr.length > 0) { + latencyStats[key] = _computeStats(arr); + } + } + return latencyStats; +} + +async function main() { + const initProgressCallback = (report: webllm.InitProgressReport) => { + setLabel("init-label", report.text); + }; + // Option 1: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts` + const selectedModel = "Qwen3-0.6B-q0f32-MLC"; + const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( + selectedModel, + { + initProgressCallback: initProgressCallback, + logLevel: "INFO", // specify the log level + }, + // customize kv cache, use either context_window_size or sliding_window_size (with attention sink) + { + context_window_size: 2048, + // sliding_window_size: 1024, + // attention_sink_size: 4, + }, + ); + + const latencyBreakdown: LatencyBreakdown = { + logitProcessorTime: [], + logitBiasTime: [], + penaltyTime: [], + sampleTime: [], + totalTime: [], + grammarBitmaskTime: [], + }; + + const decodeTokensPerS: number[] = []; + const completionTokens: number[] = []; + const e2eLatencyS: number[] = []; + const timePerOutputTokenS: number[] = []; + + const numTrials = 20; + for (let i = 0; i < numTrials; i++) { + console.log(`Trial ${i + 1} / ${numTrials}`); + const reply0 = await engine.chat.completions.create({ + messages: [{ role: "user", content: "List twenty US states." 
}], + // below configurations are all optional + n: 1, + temperature: 0, + max_tokens: 2048, + // 46510 and 7188 are "California", and 8421 and 51325 are "Texas" in Llama-3.1-8B-Instruct + // So we would have a higher chance of seeing the latter two, but never the first in the answer + // logit_bias: { + // "46510": -100, + // "7188": -100, + // "8421": 5, + // "41325": 5, + // }, + top_p: 0.8, + logprobs: true, + top_logprobs: 2, + frequency_penalty: 1.2, + presence_penalty: 1.0, + repetition_penalty: 1.1, + }); + + const logitProcessorTime = + reply0.usage?.extra.latencyBreakdown?.logitProcessorTime; + const logitBiasTime = reply0.usage?.extra.latencyBreakdown?.logitBiasTime; + const penaltyTime = reply0.usage?.extra.latencyBreakdown?.penaltyTime; + const sampleTime = reply0.usage?.extra.latencyBreakdown?.sampleTime; + const totalTime = reply0.usage?.extra.latencyBreakdown?.totalTime; + const grammarBitmaskTime = + reply0.usage?.extra.latencyBreakdown?.grammarBitmaskTime; + + latencyBreakdown.logitProcessorTime.push(...(logitProcessorTime || [])); + latencyBreakdown.logitBiasTime.push(...(logitBiasTime || [])); + latencyBreakdown.penaltyTime.push(...(penaltyTime || [])); + latencyBreakdown.sampleTime.push(...(sampleTime || [])); + latencyBreakdown.totalTime.push(...(totalTime || [])); + latencyBreakdown.grammarBitmaskTime.push(...(grammarBitmaskTime || [])); + + decodeTokensPerS.push(reply0.usage?.extra.decode_tokens_per_s || 0); + e2eLatencyS.push(reply0.usage?.extra.e2e_latency_s || 0); + timePerOutputTokenS.push(reply0.usage?.extra.time_per_output_token_s || 0); + completionTokens.push(reply0.usage?.completion_tokens || 0); + } + + const latencyStats: { [key: string]: number } = + computeStats(latencyBreakdown); + console.log("Latency stats: ", latencyStats); + console.log("Decode tokens per second: ", decodeTokensPerS); + console.log("Completion tokens: ", completionTokens); + console.log("E2E latency (s): ", e2eLatencyS); + console.log("Time per output token (s): ", timePerOutputTokenS); + + // To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)` +} + +main(); diff --git a/src/config.ts b/src/config.ts index dfeb1913..0ca51eba 100644 --- a/src/config.ts +++ b/src/config.ts @@ -126,7 +126,7 @@ export interface MLCEngineConfig { */ export interface GenerationConfig { // Only used in MLC - repetition_penalty?: number; + repetition_penalty?: number | null; ignore_eos?: boolean; // Shared by MLC and OpenAI APIs top_p?: number | null; @@ -143,6 +143,7 @@ export interface GenerationConfig { response_format?: ResponseFormat | null; // extra_body in ChatCompletionsRequest enable_thinking?: boolean | null; + enable_latency_breakdown?: boolean | null; } export function postInitAndCheckGenerationConfigValues( diff --git a/src/engine.ts b/src/engine.ts index fe4c427e..6609f3e5 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -41,6 +41,7 @@ import { MLCEngineInterface, LogitProcessor, LogLevel, + LatencyBreakdown, } from "./types"; import { compareConversationObject, @@ -694,12 +695,18 @@ export class MLCEngine implements MLCEngineInterface { const decode_time = pipeline.getCurRoundDecodingTotalTime(); const grammar_per_token_s = pipeline.getCurRoundGrammarPerTokenTotalTime(); + const latencyBreakdown: LatencyBreakdown = + pipeline.getCurRoundLatencyBreakdown(); + const defaultExtra = { e2e_latency_s: (Date.now() - timeReceived) / 1000, prefill_tokens_per_s: prefill_tokens_per_s, decode_tokens_per_s: decode_tokens_per_s, time_to_first_token_s: 
prefill_time, time_per_output_token_s: decode_time / completion_tokens, + latencyBreakdown: request.extra_body?.enable_latency_breakdown + ? latencyBreakdown + : undefined, }; const usage: CompletionUsage = { completion_tokens: completion_tokens, @@ -783,6 +790,7 @@ export class MLCEngine implements MLCEngineInterface { const genConfig: GenerationConfig = { frequency_penalty: request.frequency_penalty, presence_penalty: request.presence_penalty, + repetition_penalty: request.repetition_penalty, max_tokens: request.max_tokens, stop: request.stop, top_p: request.top_p, @@ -793,6 +801,7 @@ export class MLCEngine implements MLCEngineInterface { response_format: request.response_format, ignore_eos: request.ignore_eos, enable_thinking: request.extra_body?.enable_thinking, + enable_latency_breakdown: request.extra_body?.enable_latency_breakdown, }; // 0.5 Block wait until this pipeline finishes all previous requests @@ -890,12 +899,19 @@ export class MLCEngine implements MLCEngineInterface { "response_format" in request && (request.response_format?.type === "grammar" || request.response_format?.type === "json_object"); + + const latencyBreakdown: LatencyBreakdown = + selectedPipeline.getCurRoundLatencyBreakdown(); + const defaultExtra = { e2e_latency_s: (Date.now() - timeReceived) / 1000, prefill_tokens_per_s: prompt_tokens / prefill_time, decode_tokens_per_s: completion_tokens / decode_time, time_to_first_token_s: prefill_time, time_per_output_token_s: decode_time / completion_tokens, + latencyBreakdown: request.extra_body?.enable_latency_breakdown + ? latencyBreakdown + : undefined, }; const response: ChatCompletion = { id: crypto.randomUUID(), @@ -958,6 +974,7 @@ export class MLCEngine implements MLCEngineInterface { const genConfig: GenerationConfig = { frequency_penalty: request.frequency_penalty, presence_penalty: request.presence_penalty, + repetition_penalty: request.repetition_penalty, max_tokens: request.max_tokens, stop: request.stop, top_p: request.top_p, @@ -1030,6 +1047,9 @@ export class MLCEngine implements MLCEngineInterface { decode_time += selectedPipeline.getCurRoundDecodingTotalTime(); } + const latencyBreakdown: LatencyBreakdown = + selectedPipeline.getCurRoundLatencyBreakdown(); + const response: Completion = { id: crypto.randomUUID(), choices: choices, @@ -1046,6 +1066,9 @@ export class MLCEngine implements MLCEngineInterface { decode_tokens_per_s: completion_tokens / decode_time, time_to_first_token_s: prefill_time, time_per_output_token_s: decode_time / completion_tokens, + latencyBreakdown: request.extra_body?.enable_latency_breakdown + ? 
latencyBreakdown + : undefined, }, } as CompletionUsage, }; diff --git a/src/llm_chat.ts b/src/llm_chat.ts index 5f8ecf00..8f7620a8 100644 --- a/src/llm_chat.ts +++ b/src/llm_chat.ts @@ -6,7 +6,7 @@ import log from "loglevel"; import { Tokenizer } from "@mlc-ai/web-tokenizers"; import { ChatConfig, GenerationConfig, Role } from "./config"; import { getConversation, Conversation } from "./conversation"; -import { LogitProcessor } from "./types"; +import { LogitProcessor, LatencyBreakdown } from "./types"; import { getChunkedPrefillInputData, getImageDataFromURL, @@ -50,6 +50,10 @@ export class LLMChatPipeline { private image_embed: tvmjs.PackedFunc | undefined; private embed: tvmjs.PackedFunc; private fapplyBitmask: tvmjs.PackedFunc; + private fapplyPenalty: tvmjs.PackedFunc; + private fapplyLogitBias: tvmjs.PackedFunc; + private fsoftmaxWithTemperature: tvmjs.PackedFunc; + // Functions related to PagedKVCache private fclearKVCaches: tvmjs.PackedFunc; private fKVCacheAddSequence: tvmjs.PackedFunc; @@ -98,6 +102,16 @@ export class LLMChatPipeline { private curRoundDecodingTotalTime = 0; private curRoundPrefillTotalTime = 0; + // additional stats, reset at every prefillStep() + public curRoundLatencyBreakdown: LatencyBreakdown = { + logitProcessorTime: [], + logitBiasTime: [], + penaltyTime: [], + sampleTime: [], + totalTime: [], + grammarBitmaskTime: [], + }; + // LogitProcessor private logitProcessor?: LogitProcessor = undefined; @@ -190,6 +204,15 @@ export class LLMChatPipeline { this.fapplyBitmask = this.tvm.detachFromCurrentScope( this.vm.getFunction("apply_bitmask_inplace"), ); + this.fapplyPenalty = this.tvm.detachFromCurrentScope( + this.vm.getFunction("apply_penalty_inplace"), + ); + this.fapplyLogitBias = this.tvm.detachFromCurrentScope( + this.vm.getFunction("apply_logit_bias_inplace"), + ); + this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope( + this.vm.getFunction("softmax_with_temperature"), + ); try { this.image_embed = this.tvm.detachFromCurrentScope( this.vm.getFunction("image_embed"), @@ -421,6 +444,13 @@ export class LLMChatPipeline { return this.curRoundGrammarPerTokenTotalTime; } + /** + * @returns the breakdown of latencies for sampling each token for a single request. + */ + getCurRoundLatencyBreakdown(): LatencyBreakdown { + return this.curRoundLatencyBreakdown; + } + /** * @returns Runtime stats information. */ @@ -514,6 +544,16 @@ export class LLMChatPipeline { this.curRoundDecodingTotalTime = 0; this.curRoundGrammarInitTotalTime = 0; this.curRoundGrammarPerTokenTotalTime = 0; + + this.curRoundLatencyBreakdown = { + logitProcessorTime: [], + logitBiasTime: [], + penaltyTime: [], + sampleTime: [], + totalTime: [], + grammarBitmaskTime: [], + }; + this.stopTriggered = false; const conversation = this.conversation; @@ -1036,11 +1076,15 @@ export class LLMChatPipeline { throw new RangeError("presence_penalty", -2.0, 2.0); } + const outputTokenBegin = performance.now(); + // 0. 
Update logitsOnGPU with on-GPU grammar bitmasking if ( response_format?.type === "json_object" || response_format?.type === "grammar" ) { + const grammarBitmaskBegin = performance.now(); + this.tvm.beginScope(); if (this.grammarMatcher === undefined) { throw Error("Expect grammar matcher to be initialized."); @@ -1070,110 +1114,207 @@ export class LLMChatPipeline { bitMaskOnGPU, ); this.tvm.endScope(); + + if (genConfig?.enable_latency_breakdown) { + const grammarBitmaskEnd = performance.now(); + const grammarBitmaskTimeSpent = + (grammarBitmaskEnd - grammarBitmaskBegin) / 1e3; + this.curRoundLatencyBreakdown.grammarBitmaskTime.push( + grammarBitmaskTimeSpent, + ); + } } - // 1. Move logits to CPU - this.tvm.beginScope(); - this.updateLogitsOnCPU(logitsOnGPU); - this.tvm.endScope(); - await this.device.sync(); + // 1. Apply logitProcessor on CPU + if (this.logitProcessor !== undefined) { + // Move logits to CPU + this.tvm.beginScope(); + this.updateLogitsOnCPU(logitsOnGPU); + this.tvm.endScope(); + await this.device.sync(); - if (this.logitsOnCPU == undefined) { - throw Error("logits should be assigned"); - } + const logitProcessorBegin = performance.now(); - // 2. Post process logits via logitProcessor and/or logit_bias - if (this.logitProcessor !== undefined || _hasValue(logit_bias)) { + if (this.logitsOnCPU == undefined) { + throw Error("logits should be assigned"); + } let logitsOnCPUArray: Float32Array = ( this.logitsOnCPU.toArray() ); - const vocab_size = logitsOnCPUArray.length; - if (this.logitProcessor !== undefined) { - logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray); - } - if (_hasValue(logit_bias)) { - for (const tokenID in logit_bias) { - const curBias = logit_bias[tokenID]; - const curTokenID = parseInt(tokenID); - if (curTokenID > vocab_size) { - throw Error( - "Token " + - curTokenID + - " in logit_bias exceeds vocab_size " + - vocab_size, - ); - } - logitsOnCPUArray[curTokenID] += curBias; - } - } + logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray); + logitsOnGPU.copyFrom(logitsOnCPUArray); this.logitsOnCPU.copyFrom(logitsOnCPUArray); + + if (genConfig?.enable_latency_breakdown) { + const logitProcessorEnd = performance.now(); + const logitProcessorTimeSpent = + (logitProcessorEnd - logitProcessorBegin) / 1e3; + this.curRoundLatencyBreakdown.logitProcessorTime.push( + logitProcessorTimeSpent, + ); + } } - // 3. Apply penalties to logits - if (_hasValue(frequency_penalty) && _hasValue(presence_penalty)) { - // 3.1. Use frequency and presence penalty + // 2. Apply logit_bias on GPU + if (_hasValue(logit_bias)) { + const logitBiasBegin = performance.now(); + + const numTokens = Object.keys(logit_bias ?? {}).length; + const pos2seq_id = new Int32Array(numTokens).fill(0); + const tokenIds = new Int32Array(numTokens); + const tokenLogitBias = new Float32Array(numTokens); + + const logitBiasKeys = Object.keys(logit_bias ?? {}); + for (let index = 0; index < numTokens; index++) { + const tokenId = parseInt(logitBiasKeys[index]); + tokenIds[index] = tokenId; + tokenLogitBias[index] = logit_bias![tokenId]; + } + this.tvm.beginScope(); - // Both `keys()` and `values()` are in insertion order. 
- const appearedTokens = [...this.appearedTokensFreq.keys()]; - const appearedTokensFreqs = [...this.appearedTokensFreq.values()]; - const appeared_tokens_ndarray = this.tvm.empty( - [1, appearedTokens.length], - "int32", - this.tvm.cpu(), - ); - const appeared_tokens_freqs_ndarray = this.tvm.empty( - [1, appearedTokensFreqs.length], - "int32", - this.tvm.cpu(), - ); - appeared_tokens_ndarray.copyFrom(appearedTokens); - appeared_tokens_freqs_ndarray.copyFrom(appearedTokensFreqs); - this.tvm.applyPresenceAndFrequencyPenalty( - this.logitsOnCPU, - appeared_tokens_ndarray, - appeared_tokens_freqs_ndarray, - presence_penalty!, - frequency_penalty!, + + const pos2seqIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(pos2seq_id); + + const tokenIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenIds); + + const tokenLogitBiasArray = this.tvm + .empty([numTokens], "float32", this.device) + .copyFrom(tokenLogitBias); + + this.fapplyLogitBias( + logitsOnGPU.view([1, this.fullVocabSize]), + pos2seqIdsArray, + tokenIdsArray, + tokenLogitBiasArray, ); + this.tvm.endScope(); - } else if (repetition_penalty != 1.0) { - // 3.2. Use repetition penalty - this.tvm.beginScope(); + + if (genConfig?.enable_latency_breakdown) { + const logitBiasEnd = performance.now(); + const logitBiasTimeSpent = (logitBiasEnd - logitBiasBegin) / 1e3; + this.curRoundLatencyBreakdown.logitBiasTime.push(logitBiasTimeSpent); + } + } + + // 3. Apply penalties to logits on GPU + if ( + frequency_penalty != 0.0 || + presence_penalty != 0.0 || + repetition_penalty != 1.0 + ) { const appearedTokens = [...this.appearedTokensFreq.keys()]; - const appeared_tokens_ndarray = this.tvm.empty( - [1, appearedTokens.length], - "int32", - this.tvm.cpu(), - ); - appeared_tokens_ndarray.copyFrom(appearedTokens); - this.tvm.applyRepetitionPenalty( - this.logitsOnCPU, - appeared_tokens_ndarray, - repetition_penalty, - ); - this.tvm.endScope(); + const appearedTokensFreqs = [...this.appearedTokensFreq.values()]; + + const numTokens = appearedTokens.length; + + if (numTokens > 0) { + const penaltyBegin = performance.now(); + + const pos2seq_id = new Int32Array(numTokens).fill(0); + const tokenIds = new Int32Array(numTokens).fill(0); + const tokenCnt = new Int32Array(numTokens).fill(0); + const penalties = new Float32Array([ + presence_penalty, + frequency_penalty, + repetition_penalty, + ]); + + tokenIds.set(appearedTokens); + tokenCnt.set(appearedTokensFreqs); + + this.tvm.beginScope(); + const seqIdsArray = this.tvm + .empty([1], "int32", this.device) + .copyFrom([0]); + + const pos2seqIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(pos2seq_id); + + const tokenIdsArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenIds); + + const tokenCntArray = this.tvm + .empty([numTokens], "int32", this.device) + .copyFrom(tokenCnt); + + const penaltiesArray = this.tvm + .empty([1, 3], "float32", this.device) + .copyFrom(penalties); + + this.fapplyPenalty( + logitsOnGPU.view([1, this.fullVocabSize]), + seqIdsArray, + pos2seqIdsArray, + tokenIdsArray, + tokenCntArray, + penaltiesArray, + ); + + this.tvm.endScope(); + + if (genConfig?.enable_latency_breakdown) { + const penaltyEnd = performance.now(); + const penaltyTimeSpent = (penaltyEnd - penaltyBegin) / 1e3; + this.curRoundLatencyBreakdown.penaltyTime.push(penaltyTimeSpent); + } + } } // 4. 
Sample token from logits // If logprobs, need the actual distribution via softmax, otherwise directly sample from logits + const sampleBegin = performance.now(); let sampledToken: number; if (logprobs) { // Inplace transform logitsOnCPU to a distribution temperature = Math.max(1e-6, temperature); // to prevent division by zero - this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature); - sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p); + + const numSeqs = 1; + + const temperatures = new Float32Array([temperature]); + + this.tvm.beginScope(); + const temperaturesArray = this.tvm + .empty([numSeqs], "float32", this.device) + .copyFrom(temperatures); + + const probs = this.fsoftmaxWithTemperature( + logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]), + temperaturesArray, + ); + this.updateLogitsOnCPU(probs); + this.tvm.endScope(); + await this.device.sync(); + + sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p); this.tokenLogprobArray.push( this.getTokenLogprob(sampledToken, top_logprobs!), ); } else { // temperature being 0 is allowed here, equivalent to argmax + this.tvm.beginScope(); + this.updateLogitsOnCPU(logitsOnGPU); + this.tvm.endScope(); + await this.device.sync(); sampledToken = this.tvm.sampleTopPFromLogits( - this.logitsOnCPU, + this.logitsOnCPU!, temperature, top_p, ); } + if (genConfig?.enable_latency_breakdown) { + const sampleEnd = performance.now(); + const sampleTimeSpent = (sampleEnd - sampleBegin) / 1e3; + this.curRoundLatencyBreakdown.sampleTime.push(sampleTimeSpent); + } + // 5. Update logit processor this.logitProcessor?.processSampledToken(sampledToken); @@ -1194,6 +1335,12 @@ export class LLMChatPipeline { } } + if (genConfig?.enable_latency_breakdown) { + const outputTokenEnd = performance.now(); + const outputTokenTimeSpent = (outputTokenEnd - outputTokenBegin) / 1e3; + this.curRoundLatencyBreakdown.totalTime.push(outputTokenTimeSpent); + } + return sampledToken; } diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts index 5a5229ed..e50352d6 100644 --- a/src/openai_api_protocols/chat_completion.ts +++ b/src/openai_api_protocols/chat_completion.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { MLCEngineInterface } from "../types"; +import { MLCEngineInterface, LatencyBreakdown } from "../types"; import { functionCallingModelIds, MessagePlaceholders, @@ -125,6 +125,13 @@ export interface ChatCompletionRequestBase { */ presence_penalty?: number | null; + /** + * Penalizes new tokens based on whether they appear in the prompt and the + * generated text so far. Values greater than 1.0 encourage the model to use new + * tokens, while values less than 1.0 encourage the model to repeat tokens. + */ + repetition_penalty?: number | null; + /** * The maximum number of [tokens](/tokenizer) that can be generated in the chat * completion. @@ -268,6 +275,12 @@ export interface ChatCompletionRequestBase { * @note Currently only allowed to be used for Qwen3 models, though not explicitly checked. */ enable_thinking?: boolean | null; + + /** + * If set to true, the response will include a breakdown of the time spent in various + * stages of token sampling. + */ + enable_latency_breakdown?: boolean | null; }; } @@ -980,6 +993,12 @@ export interface CompletionUsage { * structured output. If n > 1, it is the average over all choices. 
*/ grammar_per_token_s?: number; + + /** + * If `enable_latency_breakdown` is set to true in the request, this field will be + * present and contain a breakdown of the time spent in various stages of token sampling. + */ + latencyBreakdown?: LatencyBreakdown; }; } diff --git a/src/openai_api_protocols/completion.ts b/src/openai_api_protocols/completion.ts index fb6aa458..0534fe93 100644 --- a/src/openai_api_protocols/completion.ts +++ b/src/openai_api_protocols/completion.ts @@ -137,6 +137,13 @@ export interface CompletionCreateParamsBase { */ presence_penalty?: number | null; + /** + * Penalizes new tokens based on whether they appear in the prompt and the + * generated text so far. Values greater than 1.0 encourage the model to use new + * tokens, while values less than 1.0 encourage the model to repeat tokens. + */ + repetition_penalty?: number | null; + /** * If specified, our system will make a best effort to sample deterministically, * such that repeated requests with the same `seed` and parameters should return @@ -225,6 +232,17 @@ export interface CompletionCreateParamsBase { * @note This field is not supported. */ best_of?: number | null; + + /** + * Fields specific to WebLLM, not present in OpenAI. + */ + extra_body?: { + /** + * If set to true, the response will include a breakdown of the time spent in various + * stages of token sampling. + */ + enable_latency_breakdown?: boolean | null; + }; } export type CompletionCreateParams = diff --git a/src/types.ts b/src/types.ts index 4d4522c0..ed79af57 100644 --- a/src/types.ts +++ b/src/types.ts @@ -251,3 +251,12 @@ export const LOG_LEVELS = { SILENT: 5, }; export type LogLevel = keyof typeof LOG_LEVELS; + +export type LatencyBreakdown = { + logitProcessorTime: number[]; + logitBiasTime: number[]; + penaltyTime: number[]; + sampleTime: number[]; + totalTime: number[]; + grammarBitmaskTime: number[]; +}; diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..d8b83df9 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +package-lock.json diff --git a/tests/scripts/sanity_checks/README.md b/tests/scripts/sanity_checks/README.md new file mode 100644 index 00000000..8b65a916 --- /dev/null +++ b/tests/scripts/sanity_checks/README.md @@ -0,0 +1,14 @@ +# Sanity Checks for Generated Output + +This folder provides simple sanity checks on the output generated +using WebLLM. To try it out, you can do the following steps under this folder + +```bash +npm install +npm start +``` + +Note if you would like to hack WebLLM core package. +You can change web-llm dependencies as `"file:../.."`, and follow the build from source +instruction in the project to build webllm locally. This option is only recommended +if you would like to hack WebLLM core package. 
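For reference, the `extra_body.enable_latency_breakdown` flag and the `usage.extra.latencyBreakdown` field documented in the `chat_completion.ts` changes above can be exercised as in the minimal sketch below. This is not part of the diff; the model id and prompt are illustrative, and each timing array holds one entry (in seconds) per sampled token.

```typescript
import * as webllm from "@mlc-ai/web-llm";

async function demoLatencyBreakdown() {
  // Any prebuilt model id works; this is the one used by the example app above.
  const engine = await webllm.CreateMLCEngine("Qwen3-0.6B-q0f32-MLC");

  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Hello!" }],
    max_tokens: 32,
    // Without this flag the engine skips the per-token timers and
    // `usage.extra.latencyBreakdown` stays undefined.
    extra_body: { enable_latency_breakdown: true },
  });

  // Each field is an array with one entry (in seconds) per sampled token.
  const breakdown = reply.usage?.extra.latencyBreakdown;
  console.log("sampleTime per token (s):", breakdown?.sampleTime);
  console.log("total sampling time per token (s):", breakdown?.totalTime);
}

demoLatencyBreakdown();
```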
diff --git a/tests/scripts/sanity_checks/package.json b/tests/scripts/sanity_checks/package.json new file mode 100644 index 00000000..cf86fb3d --- /dev/null +++ b/tests/scripts/sanity_checks/package.json @@ -0,0 +1,20 @@ +{ + "name": "sanity_checks", + "version": "0.1.0", + "private": true, + "scripts": { + "start": "parcel sanity_checks.html --port 8889", + "build": "parcel build sanity_checks.html --dist-dir lib" + }, + "devDependencies": { + "buffer": "^5.7.1", + "parcel": "^2.8.3", + "process": "^0.11.10", + "tslib": "^2.3.1", + "typescript": "^4.9.5", + "url": "^0.11.3" + }, + "dependencies": { + "@mlc-ai/web-llm": "^0.2.79" + } +} diff --git a/tests/scripts/sanity_checks/sanity_checks.html b/tests/scripts/sanity_checks/sanity_checks.html new file mode 100644 index 00000000..2b662f71 --- /dev/null +++ b/tests/scripts/sanity_checks/sanity_checks.html @@ -0,0 +1,49 @@ + + + + + + GPU sampleTokenFromLogits Tests + + + +

+    <h2>GPU sampleTokenFromLogits Tests</h2>
+
+    Overall:
+    <label id="gpu-test-label">Not started.</label>
+    <br />
+    Logit Processor:
+    <label id="logit-processor-label"> </label>
+    <br />
+    Logit Bias:
+    <label id="logit-bias-label"> </label>
+    <br />
+    Penalties:
+    <label id="penalty-label"> </label>
+    <br />
+    Logprobs:
+    <label id="logprobs-label"> </label>
+ + + diff --git a/tests/scripts/sanity_checks/sanity_checks.ts b/tests/scripts/sanity_checks/sanity_checks.ts new file mode 100644 index 00000000..da842353 --- /dev/null +++ b/tests/scripts/sanity_checks/sanity_checks.ts @@ -0,0 +1,184 @@ +import * as webllm from "@mlc-ai/web-llm"; + +function setLabel(id: string, text: string) { + const label = document.getElementById(id); + if (label == null) return; + label.innerText = text; +} + +async function createEngine( + modelId: string, + appConfig: webllm.AppConfig, + logitProcessorRegistry?: Map, +) { + return await webllm.CreateMLCEngine(modelId, { + appConfig, + logLevel: "ERROR", + logitProcessorRegistry, + }); +} + +async function deleteModel(modelId: string, appConfig: webllm.AppConfig) { + await webllm.deleteModelAllInfoInCache(modelId, appConfig); +} + +async function testLogitProcessor( + modelId: string, + appConfig: webllm.AppConfig, +) { + // Set up a logit processor that sets logits[0] = 100.0, rest -100.0 + const logitProcessor = { + processLogits: (logits: Float32Array) => { + logits.fill(-100.0); + logits[0] = 100.0; + return logits; + }, + processSampledToken: () => {}, + resetState: () => {}, + }; + const logitProcessorRegistry: Map = new Map(); + logitProcessorRegistry.set(modelId, logitProcessor); + const engine: webllm.MLCEngineInterface = await createEngine( + modelId, + appConfig, + logitProcessorRegistry, + ); + + const prompt = "Test logit processor."; + const reply: webllm.ChatCompletion = await engine.chat.completions.create({ + messages: [{ role: "user", content: prompt }], + temperature: 1.0, + max_tokens: 20, + logprobs: true, + top_logprobs: 1, + }); + const logprobs = reply.choices[0]?.logprobs; + const logprobsAllZero = !!( + logprobs && + Array.isArray(logprobs.content) && + logprobs.content.every( + (lp: webllm.ChatCompletionTokenLogprob) => + lp.top_logprobs[0].logprob === 0, + ) + ); + + console.log(`[LogitProcessor] Logprobs all zero: ${logprobsAllZero}`); + setLabel("logit-processor-label", `Logprobs all zero: ${logprobsAllZero}`); + await deleteModel(modelId, appConfig); + return logprobsAllZero; +} + +async function testLogitBias(modelId: string, appConfig: webllm.AppConfig) { + // Set logit_bias to strongly favor token 0 + const prompt = "Test logit bias."; + const engine: webllm.MLCEngineInterface = await createEngine( + modelId, + appConfig, + ); + const reply = await engine.chat.completions.create({ + messages: [{ role: "user", content: prompt }], + temperature: 1.0, + max_tokens: 20, + logprobs: true, + top_logprobs: 1, + logit_bias: { "0": 100.0 }, + }); + const logprobs = reply.choices[0]?.logprobs; + const logprobsAllZero = !!( + logprobs && + Array.isArray(logprobs.content) && + logprobs.content.every( + (lp: webllm.ChatCompletionTokenLogprob) => + lp.top_logprobs[0].logprob === 0, + ) + ); + + console.log(`[LogitBias] Logprobs all zero: ${logprobsAllZero}`); + setLabel("logit-bias-label", `Logprobs all zero: ${logprobsAllZero}`); + await deleteModel(modelId, appConfig); + return logprobsAllZero; +} + +async function testPenalties(modelId: string, appConfig: webllm.AppConfig) { + const prompt = "Test presence and frequency penalties."; + const engine: webllm.MLCEngineInterface = await createEngine( + modelId, + appConfig, + ); + const reply = await engine.chat.completions.create({ + messages: [{ role: "user", content: prompt }], + temperature: 1.0, + max_tokens: 256, + presence_penalty: 2.0, + frequency_penalty: 2.0, + logit_bias: { "0": 100.0 }, + logprobs: true, + }); + const logprobs = 
reply.choices[0]?.logprobs; + const logprobsNotAllZero = !logprobs?.content?.every( + (lp: webllm.ChatCompletionTokenLogprob) => lp.logprob === 0, + ); + console.log(`[Penalties] Logprobs not all zero: ${logprobsNotAllZero}`); + setLabel("penalty-label", `Logprobs not all zero: ${logprobsNotAllZero}`); + await deleteModel(modelId, appConfig); + return logprobsNotAllZero; +} + +async function testLogprobs(modelId: string, appConfig: webllm.AppConfig) { + // Test logprobs: check that logprobs are returned and sum to ~1 after exp + const prompt = "Test logprobs."; + const engine: webllm.MLCEngineInterface = await createEngine( + modelId, + appConfig, + ); + const reply = await engine.chat.completions.create({ + messages: [{ role: "user", content: prompt }], + temperature: 1.0, + max_tokens: 20, + logprobs: true, + top_logprobs: 5, + }); + const logprobs = reply.choices[0]?.logprobs; + + let logprobsAllCloseTo1 = true; + for (const lp of logprobs?.content || []) { + const expSum = lp.top_logprobs?.reduce( + (acc: number, val: webllm.TopLogprob) => acc + Math.exp(val.logprob), + 0, + ); + logprobsAllCloseTo1 &&= Math.abs(expSum - 1.0) < 0.1; + } + console.log(`[Logprobs] Logprobs all close to 1: ${logprobsAllCloseTo1}`); + setLabel("logprobs-label", `Logprobs all close to 1: ${logprobsAllCloseTo1}`); + await deleteModel(modelId, appConfig); + return logprobsAllCloseTo1; +} + +async function main() { + const modelId = "Qwen3-0.6B-q0f32-MLC"; + const appConfig = webllm.prebuiltAppConfig; + appConfig.useIndexedDBCache = true; + setLabel("gpu-test-label", "Running tests..."); + let passed = 0, + total = 0; + + if (await testLogitProcessor(modelId, appConfig)) passed++; + total++; + if (await testLogitBias(modelId, appConfig)) passed++; + total++; + if (await testPenalties(modelId, appConfig)) passed++; + total++; + if (await testLogprobs(modelId, appConfig)) passed++; + total++; + + setLabel( + "gpu-test-label", + `GPU sampleTokenFromLogits tests: ${passed}/${total} passed.`, + ); + setLabel( + "gpu-test-label", + `Tests complete. Model deleted. ${passed}/${total} passed.`, + ); +} + +main();
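A note on the new GPU penalty path in `llm_chat.ts`: the `[presence_penalty, frequency_penalty, repetition_penalty]` array is passed to the model's `apply_penalty_inplace` function together with the appeared-token ids and counts. The kernel itself lives in the compiled model, so the sketch below is only a hedged CPU approximation of the conventional semantics (OpenAI-style presence/frequency penalties plus a CTRL-style repetition penalty), not a definition of what the kernel computes.

```typescript
// Hedged reference sketch: how presence/frequency/repetition penalties are
// conventionally combined. The authoritative behavior is the compiled
// `apply_penalty_inplace` kernel invoked via `fapplyPenalty`.
function applyPenaltiesReference(
  logits: Float32Array,
  appearedTokensFreq: Map<number, number>, // token id -> count in the output so far
  presencePenalty: number,
  frequencyPenalty: number,
  repetitionPenalty: number,
): void {
  for (const [tokenId, count] of appearedTokensFreq) {
    let v = logits[tokenId];
    // Repetition penalty shrinks positive logits and pushes negative ones further down.
    v = v > 0 ? v / repetitionPenalty : v * repetitionPenalty;
    // Frequency penalty scales with the token's count; presence penalty is a flat offset.
    v -= count * frequencyPenalty + presencePenalty;
    logits[tokenId] = v;
  }
}
```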