import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}
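// Per-stage sampling latencies reported by the engine in usage.extra.latencyBreakdown;
// each field accumulates the measurements from every trial below.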
type LatencyBreakdown = {
  logitProcessorTime: number[];
  logitBiasTime: number[];
  penaltyTime: number[];
  sampleTime: number[];
  totalTime: number[];
  grammarBitmaskTime: number[];
};
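// Summarize each latency array with its average, min, max, and nearest-rank p99.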
function computeStats(
  latency_breakdown: LatencyBreakdown,
): Record<string, any> {
  function _computeStats(arr: number[]) {
    if (!arr.length) return undefined;
    const sorted = [...arr].sort((a, b) => a - b);
    const sum = arr.reduce((a, b) => a + b, 0);
    const avg = sum / arr.length;
    const min = sorted[0];
    const max = sorted[sorted.length - 1];
    const p99 = sorted[Math.floor(0.99 * (sorted.length - 1))];
    return { avg, min, max, p99 };
  }

  const latencyStats: Record<string, any> = {};
  for (const key of Object.keys(latency_breakdown)) {
    const arr = (latency_breakdown as any)[key];
    if (Array.isArray(arr) && arr.length > 0) {
      latencyStats[key] = _computeStats(arr);
    }
  }
  return latencyStats;
}
async function main() {
  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };
  // Option 1: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
  const selectedModel = "Qwen3-0.6B-q0f32-MLC";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO", // specify the log level
    },
    // customize kv cache, use either context_window_size or sliding_window_size (with attention sink)
    {
      context_window_size: 2048,
      // sliding_window_size: 1024,
      // attention_sink_size: 4,
    },
  );
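  // Accumulate latency-breakdown and throughput metrics across all trials.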
  const latencyBreakdown: LatencyBreakdown = {
    logitProcessorTime: [],
    logitBiasTime: [],
    penaltyTime: [],
    sampleTime: [],
    totalTime: [],
    grammarBitmaskTime: [],
  };

  const decodeTokensPerS: number[] = [];
  const completionTokens: number[] = [];
  const e2eLatencyS: number[] = [];
  const timePerOutputTokenS: number[] = [];
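  // Run repeated chat completions and collect the metrics reported in `usage` for each reply.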
  const numTrials = 20;
  for (let i = 0; i < numTrials; i++) {
    console.log(`Trial ${i + 1} / ${numTrials}`);
    const reply0 = await engine.chat.completions.create({
      messages: [{ role: "user", content: "List twenty US states." }],
      // below configurations are all optional
      n: 1,
      temperature: 0,
      max_tokens: 2048,
      // 46510 and 7188 are "California", and 8421 and 51325 are "Texas" in the
      // Llama-3.1-8B-Instruct tokenizer (the IDs differ for the Qwen3 model used here).
      // With this bias we would be more likely to see the latter two tokens, but never
      // the former two, in the answer.
      // logit_bias: {
      //   "46510": -100,
      //   "7188": -100,
      //   "8421": 5,
      //   "51325": 5,
      // },
      top_p: 0.8,
      logprobs: true,
      top_logprobs: 2,
      frequency_penalty: 1.2,
      presence_penalty: 1.0,
      repetition_penalty: 1.1,
    });
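    // Pull the per-stage latency arrays for this reply out of usage.extra.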
    const logitProcessorTime =
      reply0.usage?.extra.latencyBreakdown?.logitProcessorTime;
    const logitBiasTime = reply0.usage?.extra.latencyBreakdown?.logitBiasTime;
    const penaltyTime = reply0.usage?.extra.latencyBreakdown?.penaltyTime;
    const sampleTime = reply0.usage?.extra.latencyBreakdown?.sampleTime;
    const totalTime = reply0.usage?.extra.latencyBreakdown?.totalTime;
    const grammarBitmaskTime =
      reply0.usage?.extra.latencyBreakdown?.grammarBitmaskTime;

    latencyBreakdown.logitProcessorTime.push(...(logitProcessorTime || []));
    latencyBreakdown.logitBiasTime.push(...(logitBiasTime || []));
    latencyBreakdown.penaltyTime.push(...(penaltyTime || []));
    latencyBreakdown.sampleTime.push(...(sampleTime || []));
    latencyBreakdown.totalTime.push(...(totalTime || []));
    latencyBreakdown.grammarBitmaskTime.push(...(grammarBitmaskTime || []));

    decodeTokensPerS.push(reply0.usage?.extra.decode_tokens_per_s || 0);
    e2eLatencyS.push(reply0.usage?.extra.e2e_latency_s || 0);
    timePerOutputTokenS.push(reply0.usage?.extra.time_per_output_token_s || 0);
    completionTokens.push(reply0.usage?.completion_tokens || 0);
  }
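  // Summarize and log everything collected over the trials.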
  const latencyStats: Record<string, any> = computeStats(latencyBreakdown);
  console.log("Latency stats: ", latencyStats);
  console.log("Decode tokens per second: ", decodeTokensPerS);
  console.log("Completion tokens: ", completionTokens);
  console.log("E2E latency (s): ", e2eLatencyS);
  console.log("Time per output token (s): ", timePerOutputTokenS);

  // To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
}

main();