Commit e4b4dc2

[Kernels] Replace CPU Function Calls with GPU Kernel Invocations (#697)
1. Replace CPU function calls for the following tasks with GPU kernel invocations:
   - Apply logit bias
   - Apply penalties to logits
   - Compute softmax with temperature (sampling will be replaced in a future PR)
2. Fix a bug where the repetition penalty was not used from the generation config:
   - Added repetition_penalty to the CompletionCreateParamsBase and ChatCompletionRequestBase interfaces
   - Updated its definition in GenerationConfig and added a reference in engine.ts
3. Add a field to the CompletionCreateParamsBase and ChatCompletionRequestBase interfaces to enable logging of the time taken for individual steps
4. Add sanity checks for individual steps in sampleTokenFromLogits

Performance comparison: "canonical" flows averaged across 20 runs (no logit_bias, no logitProcessor, penalties applied, with and without logprobs):

1. Before PR, without logprobs: ~0.064 s per output token (~15.63 decode tokens/s)
2. After PR, without logprobs: ~0.066 s per output token (~15.15 decode tokens/s)
3. Before PR, with logprobs: ~0.052 s per output token (~19.23 decode tokens/s)
4. After PR, with logprobs: ~0.048 s per output token (~20.83 decode tokens/s)

Additional notes:
- Need to profile sampleTopPFromLogits vs. sampleTopPFromProb on the CPU to determine why performance with logprobs is better
- Applying logit_bias is much faster on the GPU than on the CPU
- There are additional overheads outside of sampleTokenFromLogits that make the improvement less pronounced (the total time spent in sampleTokenFromLogits is ~0.0117 s before the PR and ~0.0076 s after the PR)
1 parent d8b25fe commit e4b4dc2
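To make the change concrete, below is a minimal CPU-side sketch of the three per-token steps that this commit moves into GPU kernels. The function name, parameters, and exact penalty formulas are illustrative assumptions (OpenAI-style frequency/presence penalties and a CTRL-style repetition penalty), not WebLLM's actual kernel code:

```typescript
// Illustrative CPU reference for the per-token steps moved onto the GPU in this
// commit. All names here are hypothetical; in WebLLM the equivalent work now runs
// as GPU kernels inside sampleTokenFromLogits.
function applyLogitStepsCPU(
  logits: Float32Array,
  logitBias: Record<number, number>, // token id -> bias, as in OpenAI's logit_bias
  tokenCounts: Map<number, number>, // how often each token was already generated
  frequencyPenalty: number,
  presencePenalty: number,
  repetitionPenalty: number,
  temperature: number, // assumed > 0; temperature == 0 is handled as greedy sampling
): Float32Array {
  // 1. Apply logit bias: add the requested offset to each biased token id.
  for (const [id, bias] of Object.entries(logitBias)) {
    logits[Number(id)] += bias;
  }
  // 2. Apply penalties to tokens that already appeared (frequency/presence as in
  //    the OpenAI API; repetition penalty in the CTRL style).
  for (const [id, count] of tokenCounts) {
    logits[id] -= frequencyPenalty * count + presencePenalty;
    logits[id] =
      logits[id] > 0
        ? logits[id] / repetitionPenalty
        : logits[id] * repetitionPenalty;
  }
  // 3. Softmax with temperature, computed in a numerically stable way.
  const probs = new Float32Array(logits.length);
  let maxScaled = -Infinity;
  for (const x of logits) {
    maxScaled = Math.max(maxScaled, x / temperature);
  }
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    probs[i] = Math.exp(logits[i] / temperature - maxScaled);
    sum += probs[i];
  }
  for (let i = 0; i < probs.length; i++) {
    probs[i] /= sum;
  }
  return probs;
}
```

On the GPU these steps run as kernel launches over the vocabulary-sized logits tensor instead of loops over data copied back to the CPU, which is why applying logit_bias in particular benefits, per the notes above.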

File tree

16 files changed (+752, -75 lines)


.gitignore

Lines changed: 0 additions & 1 deletion
@@ -324,5 +324,4 @@ node_modules
 lib
 .parcel-cache

-examples/tests
 **/.next
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# WebLLM Get Started App

This folder provides a minimum demo to show WebLLM API in a webapp setting with
collection of latency statistics for individual token sampling steps.
To try it out, you can do the following steps under this folder

```bash
npm install
npm start
```

Note if you would like to hack WebLLM core package.
You can change web-llm dependencies as `"file:../.."`, and follow the build from source
instruction in the project to build webllm locally. This option is only recommended
if you would like to hack WebLLM core package.
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
{
  "name": "get-started-latency-breakdown",
  "version": "0.1.0",
  "private": true,
  "scripts": {
    "start": "parcel src/get_started_latency_breakdown.html --port 8888",
    "build": "parcel build src/get_started_latency_breakdown.html --dist-dir lib"
  },
  "devDependencies": {
    "buffer": "^5.7.1",
    "parcel": "^2.8.3",
    "process": "^0.11.10",
    "tslib": "^2.3.1",
    "typescript": "^4.9.5",
    "url": "^0.11.3"
  },
  "dependencies": {
    "@mlc-ai/web-llm": "^0.2.79"
  }
}
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
<!doctype html>
<html>
  <script>
    webLLMGlobal = {};
  </script>
  <body>
    <h2>WebLLM Test Page</h2>
    Open console to see output
    <br />
    <br />
    <label id="init-label"> </label>

    <h3>Prompt</h3>
    <label id="prompt-label"> </label>

    <h3>Response</h3>
    <label id="generate-label"> </label>
    <br />
    <label id="stats-label"> </label>

    <script type="module" src="./get_started_latency_breakdown.ts"></script>
  </body>
</html>
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

type LatencyBreakdown = {
  logitProcessorTime: number[];
  logitBiasTime: number[];
  penaltyTime: number[];
  sampleTime: number[];
  totalTime: number[];
  grammarBitmaskTime: number[];
};
function computeStats(
  latency_breakdown: LatencyBreakdown,
): Record<string, any> {
  function _computeStats(arr: number[]) {
    if (!arr.length) return undefined;
    const sorted = [...arr].sort((a, b) => a - b);
    const sum = arr.reduce((a, b) => a + b, 0);
    const avg = sum / arr.length;
    const min = sorted[0];
    const max = sorted[sorted.length - 1];
    const p99 = sorted[Math.floor(0.99 * (sorted.length - 1))];
    return { avg, min, max, p99 };
  }

  const latencyStats: Record<string, any> = {};
  for (const key of Object.keys(latency_breakdown)) {
    const arr = (latency_breakdown as any)[key];
    if (Array.isArray(arr) && arr.length > 0) {
      latencyStats[key] = _computeStats(arr);
    }
  }
  return latencyStats;
}

async function main() {
  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };
  // Option 1: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
  const selectedModel = "Qwen3-0.6B-q0f32-MLC";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO", // specify the log level
    },
    // customize kv cache, use either context_window_size or sliding_window_size (with attention sink)
    {
      context_window_size: 2048,
      // sliding_window_size: 1024,
      // attention_sink_size: 4,
    },
  );

  const latencyBreakdown: LatencyBreakdown = {
    logitProcessorTime: [],
    logitBiasTime: [],
    penaltyTime: [],
    sampleTime: [],
    totalTime: [],
    grammarBitmaskTime: [],
  };

  const decodeTokensPerS: number[] = [];
  const completionTokens: number[] = [];
  const e2eLatencyS: number[] = [];
  const timePerOutputTokenS: number[] = [];

  const numTrials = 20;
  for (let i = 0; i < numTrials; i++) {
    console.log(`Trial ${i + 1} / ${numTrials}`);
    const reply0 = await engine.chat.completions.create({
      messages: [{ role: "user", content: "List twenty US states." }],
      // below configurations are all optional
      n: 1,
      temperature: 0,
      max_tokens: 2048,
      // 46510 and 7188 are "California", and 8421 and 51325 are "Texas" in Llama-3.1-8B-Instruct
      // So we would have a higher chance of seeing the latter two, but never the first in the answer
      // logit_bias: {
      //   "46510": -100,
      //   "7188": -100,
      //   "8421": 5,
      //   "41325": 5,
      // },
      top_p: 0.8,
      logprobs: true,
      top_logprobs: 2,
      frequency_penalty: 1.2,
      presence_penalty: 1.0,
      repetition_penalty: 1.1,
    });

    const logitProcessorTime =
      reply0.usage?.extra.latencyBreakdown?.logitProcessorTime;
    const logitBiasTime = reply0.usage?.extra.latencyBreakdown?.logitBiasTime;
    const penaltyTime = reply0.usage?.extra.latencyBreakdown?.penaltyTime;
    const sampleTime = reply0.usage?.extra.latencyBreakdown?.sampleTime;
    const totalTime = reply0.usage?.extra.latencyBreakdown?.totalTime;
    const grammarBitmaskTime =
      reply0.usage?.extra.latencyBreakdown?.grammarBitmaskTime;

    latencyBreakdown.logitProcessorTime.push(...(logitProcessorTime || []));
    latencyBreakdown.logitBiasTime.push(...(logitBiasTime || []));
    latencyBreakdown.penaltyTime.push(...(penaltyTime || []));
    latencyBreakdown.sampleTime.push(...(sampleTime || []));
    latencyBreakdown.totalTime.push(...(totalTime || []));
    latencyBreakdown.grammarBitmaskTime.push(...(grammarBitmaskTime || []));

    decodeTokensPerS.push(reply0.usage?.extra.decode_tokens_per_s || 0);
    e2eLatencyS.push(reply0.usage?.extra.e2e_latency_s || 0);
    timePerOutputTokenS.push(reply0.usage?.extra.time_per_output_token_s || 0);
    completionTokens.push(reply0.usage?.completion_tokens || 0);
  }

  const latencyStats: { [key: string]: number } =
    computeStats(latencyBreakdown);
  console.log("Latency stats: ", latencyStats);
  console.log("Decode tokens per second: ", decodeTokensPerS);
  console.log("Completion tokens: ", completionTokens);
  console.log("E2E latency (s): ", e2eLatencyS);
  console.log("Time per output token (s): ", timePerOutputTokenS);

  // To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
}

main();

src/config.ts

Lines changed: 2 additions & 1 deletion
@@ -126,7 +126,7 @@ export interface MLCEngineConfig {
  */
 export interface GenerationConfig {
   // Only used in MLC
-  repetition_penalty?: number;
+  repetition_penalty?: number | null;
   ignore_eos?: boolean;
   // Shared by MLC and OpenAI APIs
   top_p?: number | null;
@@ -143,6 +143,7 @@ export interface GenerationConfig {
   response_format?: ResponseFormat | null;
   // extra_body in ChatCompletionsRequest
   enable_thinking?: boolean | null;
+  enable_latency_breakdown?: boolean | null;
 }

 export function postInitAndCheckGenerationConfigValues(

src/engine.ts

Lines changed: 23 additions & 0 deletions
@@ -41,6 +41,7 @@ import {
   MLCEngineInterface,
   LogitProcessor,
   LogLevel,
+  LatencyBreakdown,
 } from "./types";
 import {
   compareConversationObject,
@@ -694,12 +695,18 @@ export class MLCEngine implements MLCEngineInterface {
     const decode_time = pipeline.getCurRoundDecodingTotalTime();
     const grammar_per_token_s =
       pipeline.getCurRoundGrammarPerTokenTotalTime();
+    const latencyBreakdown: LatencyBreakdown =
+      pipeline.getCurRoundLatencyBreakdown();
+
     const defaultExtra = {
       e2e_latency_s: (Date.now() - timeReceived) / 1000,
       prefill_tokens_per_s: prefill_tokens_per_s,
       decode_tokens_per_s: decode_tokens_per_s,
       time_to_first_token_s: prefill_time,
       time_per_output_token_s: decode_time / completion_tokens,
+      latencyBreakdown: request.extra_body?.enable_latency_breakdown
+        ? latencyBreakdown
+        : undefined,
     };
     const usage: CompletionUsage = {
       completion_tokens: completion_tokens,
@@ -783,6 +790,7 @@ export class MLCEngine implements MLCEngineInterface {
     const genConfig: GenerationConfig = {
       frequency_penalty: request.frequency_penalty,
       presence_penalty: request.presence_penalty,
+      repetition_penalty: request.repetition_penalty,
       max_tokens: request.max_tokens,
       stop: request.stop,
       top_p: request.top_p,
@@ -793,6 +801,7 @@ export class MLCEngine implements MLCEngineInterface {
       response_format: request.response_format,
       ignore_eos: request.ignore_eos,
       enable_thinking: request.extra_body?.enable_thinking,
+      enable_latency_breakdown: request.extra_body?.enable_latency_breakdown,
     };

     // 0.5 Block wait until this pipeline finishes all previous requests
@@ -890,12 +899,19 @@ export class MLCEngine implements MLCEngineInterface {
       "response_format" in request &&
       (request.response_format?.type === "grammar" ||
         request.response_format?.type === "json_object");
+
+    const latencyBreakdown: LatencyBreakdown =
+      selectedPipeline.getCurRoundLatencyBreakdown();
+
     const defaultExtra = {
       e2e_latency_s: (Date.now() - timeReceived) / 1000,
       prefill_tokens_per_s: prompt_tokens / prefill_time,
       decode_tokens_per_s: completion_tokens / decode_time,
       time_to_first_token_s: prefill_time,
       time_per_output_token_s: decode_time / completion_tokens,
+      latencyBreakdown: request.extra_body?.enable_latency_breakdown
+        ? latencyBreakdown
+        : undefined,
     };
     const response: ChatCompletion = {
       id: crypto.randomUUID(),
@@ -958,6 +974,7 @@ export class MLCEngine implements MLCEngineInterface {
     const genConfig: GenerationConfig = {
       frequency_penalty: request.frequency_penalty,
       presence_penalty: request.presence_penalty,
+      repetition_penalty: request.repetition_penalty,
       max_tokens: request.max_tokens,
       stop: request.stop,
       top_p: request.top_p,
@@ -1030,6 +1047,9 @@ export class MLCEngine implements MLCEngineInterface {
       decode_time += selectedPipeline.getCurRoundDecodingTotalTime();
     }

+    const latencyBreakdown: LatencyBreakdown =
+      selectedPipeline.getCurRoundLatencyBreakdown();
+
     const response: Completion = {
       id: crypto.randomUUID(),
       choices: choices,
@@ -1046,6 +1066,9 @@ export class MLCEngine implements MLCEngineInterface {
         decode_tokens_per_s: completion_tokens / decode_time,
         time_to_first_token_s: prefill_time,
         time_per_output_token_s: decode_time / completion_tokens,
+        latencyBreakdown: request.extra_body?.enable_latency_breakdown
+          ? latencyBreakdown
+          : undefined,
       },
     } as CompletionUsage,
   };
0 commit comments
