
Commit 26f7b83

fix engine race condition (#181)
* fix engine race condition
* update logic
1 parent 24853c0 commit 26f7b83

File tree: 3 files changed, +101 -62 lines

examples/server/server.cc

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ int main(int argc, char** argv) {
 
   LOG_INFO << "HTTP server listening: " << hostname << ":" << port;
   svr->new_task_queue = [] {
-    return new httplib::ThreadPool(5);
+    return new httplib::ThreadPool(64);
   };
   // run the HTTP server in a thread - see comment below
   std::thread t([&]() {
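
For context: cpp-httplib hands every accepted connection to the task queue returned by new_task_queue, so a 5-thread pool caps the server at five in-flight requests and lets long completion calls block everyone else; 64 threads gives the benchmark's concurrent users room to run. A minimal standalone sketch of the same hook (only the new_task_queue/ThreadPool API is taken from cpp-httplib; the /healthz route, host, and port are illustrative, not this repository's server code):

#include <httplib.h>

int main() {
  httplib::Server svr;

  // Each accepted connection is dispatched onto this pool; with only a few
  // threads, long-running completion requests starve new connections.
  svr.new_task_queue = [] { return new httplib::ThreadPool(64); };

  // Illustrative route only, not part of the cortex.llamacpp server.
  svr.Get("/healthz", [](const httplib::Request&, httplib::Response& res) {
    res.set_content("ok", "text/plain");
  });

  svr.listen("127.0.0.1", 3928);
  return 0;
}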

scripts/benchmark.py

Lines changed: 61 additions & 29 deletions
@@ -7,10 +7,14 @@
 
 start_time = time.time()
 SERVER_ENDPOINT = "http://localhost:3928"
-TOTAL_REQUESTS = 16
-N_PARALLEL = 4
-MAX_CTX_FOR_ONE_SEQUENCE = 512
-N_CTX = MAX_CTX_FOR_ONE_SEQUENCE*N_PARALLEL # this number related to reserve GPU memory for kv cache
+TOTAL_USERS = 40
+NUM_ROUNDS = 10
+MAX_TOKENS = 500
+N_PARALLEL = 32
+MAX_CTX_FOR_ONE_SEQUENCE = 1000
+# this number related to reserve GPU memory for kv cache
+N_CTX = MAX_CTX_FOR_ONE_SEQUENCE*N_PARALLEL
+
 
 def start_server():
     import subprocess
@@ -21,80 +25,108 @@ def start_server():
 
 def load_model():
     headers = {"Content-Type": "application/json"}
-    data = {"llama_model_path": "/mnt/nas/gguf-models/meta-llama3.1-8b-instruct-q4km.gguf", "model_alias": "meta-llama3.1-8b-instruct",
-            "model": "meta-llama3.1-8b-instruct", "ctx_len": N_CTX,"n_batch":2048, "ngl": 300, "model_type": "llm", "n_parallel": N_PARALLEL}
+    data = {"llama_model_path": "/mnt/nas/gguf-models/meta-llama3.1-8b-instruct-q4km.gguf", "model_alias": "meta-llama3.1-8b-instruct","engine": "cortex.llamacpp",
+            "model": "meta-llama3.1-8b-instruct", "ctx_len": N_CTX, "ngl": 300, "model_type": "llm", "n_parallel": N_PARALLEL}
+
     result = requests.post(SERVER_ENDPOINT+"/loadmodel",
                            headers=headers, json=data)
+    # result = requests.post(SERVER_ENDPOINT+"/inferences/server/loadmodel",
+    #                        headers=headers, json=data)
     print(result.json())
 
 
-async def send_request(session, prompt):
+async def send_request(session, prompt,sleep = 0):
+    await asyncio.sleep(sleep)
     headers = {"Content-Type": "application/json"}
-    data = {"model": "meta-llama3.1-8b-instruct",
+    data = {"model": "meta-llama3.1-8b-instruct", "max_tokens": MAX_TOKENS, "stop": ["<|eom_id|>", "<|end_of_text|>", "<|eot_id|>"],"engine": "cortex.llamacpp",
             "messages": [{"role": "user", "content": prompt},]}
     async with session.post(SERVER_ENDPOINT+"/v1/chat/completions", headers=headers, json=data) as resp:
         result = await resp.json()
         return result
 
+async def one_user(session, prompt):
+    tasks = [send_request(session, prompt,random.random()*0.2+ i ) for i in range(NUM_ROUNDS)]
+    results = await asyncio.gather(*tasks)
+    return results
+
 
 async def send_request_sequence():
     # warm up
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout = aiohttp.ClientTimeout()) as session:
         res = await send_request(session, "What is GPU?")
 
     start = time.time()
     total_token_processed = 0
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout = aiohttp.ClientTimeout()) as session:
 
         tasks = []
         prompts = ["What is GPU?", "Who won the world cup 2022?", "Tell me some dad's joke",
                    "Write a quick sort function", "What is the price of Nvidia H100?", "Who won the world series in 2020?"]
-        for number in range(TOTAL_REQUESTS):
+        for number in range(TOTAL_USERS):
            res = await send_request(session, random.choice(prompts))
            if res.get("usage"):
                total_token_processed += res["usage"]["total_tokens"]
            else:
                print(res)
-
+
     end = time.time()
     print("Finished in", end-start, "s")
     print("Total token:", total_token_processed)
-    print("Throughput when run in sequence:", total_token_processed/(end-start), "tokens/s")
+    print("Throughput when run in sequence:",
+          total_token_processed/(end-start), "tokens/s")
     print("------------------------------------------------------------------------")
 
 
 async def main():
     # warm up
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout = aiohttp.ClientTimeout()) as session:
         res = await send_request(session, "What is GPU?")
 
     start = time.time()
     total_token_processed = 0
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout = aiohttp.ClientTimeout()) as session:
 
         tasks = []
-        prompts = ["What is GPU?", "Who won the world cup 2022?", "Tell me some dad's joke",
-                   "Write a quick sort function", "What is the price of Nvidia H100?", "Who won the world series in 2020?"]
-        for number in range(TOTAL_REQUESTS):
+        prompts = [
+            "What is GPU?",
+            "Who won the world cup 2022?",
+            "Tell me so many dad's joke,",
+            "Write a quick sort function,",
+            "What is the price of Nvidia H100?",
+            "Who won the world series in 2020?",
+            "Tell me a very long story,",
+            "Who is the best football player in the world?",
+            "Tell me about compiler,",
+            "Tell me about AI,"]
+        for number in range(TOTAL_USERS):
             tasks.append(asyncio.ensure_future(
-                send_request(session, random.choice(prompts))))
-
-        results = await asyncio.gather(*tasks)
-        for res in results:
-            # print(res)
-            if res.get("usage"):
-                total_token_processed += res["usage"]["total_tokens"]
-            else:
-                print(res)
+                one_user(session, random.choice(prompts))))
+
+        list_results = await asyncio.gather(*tasks)
+        for results in list_results:
+            for res in results:
+                # print(res)
+                if res.get("usage"):
+                    total_token_processed += res["usage"]["total_tokens"]
+                else:
+                    print(res)
     end = time.time()
     print("Finished in", end-start, "s")
     print("Total token:", total_token_processed)
-    print("Throughput when run parallel:", total_token_processed/(end-start), "tokens/s")
+    print("Throughput when run parallel:",
+          total_token_processed/(end-start), "tokens/s")
     print("------------------------------------------------------------------------")
+    with open("result.log","w") as writer:
+        for results in list_results:
+            for res in results:
+                try:
+                    writer.write(res["choices"][0]["message"]["content"] + "\n\n")
+                except:
+                    continue
 # start_server()
 load_model()
 
 asyncio.run(main())
 
-asyncio.run(send_request_sequence())
+# asyncio.run(send_request_sequence())
 print("--- %s seconds ---" % (time.time() - start_time))

src/llama_server_context.cc

Lines changed: 39 additions & 32 deletions
@@ -836,6 +836,7 @@ void LlamaServerContext::SendError(int id_task, int id_multi,
   res.stop = false;
   res.error = true;
   res.result_json = {{"content", error}};
+  LOG_ERROR << "Internel error catched " << error;
   {
     std::lock_guard<std::mutex> lock(mutex_results);
     queue_results.push_back(res);
@@ -1150,44 +1151,50 @@ void LlamaServerContext::ProcessTasks() {
       l.unlock();
       break;
     }
+
     TaskServer task = queue_tasks.front();
-    queue_tasks.erase(queue_tasks.begin());
-    l.unlock();
-    switch (task.type) {
-      case TaskType::kCompletionTask: {
-        LlamaClientSlot* slot = GetSlot(json_value(task.data, "slot_id", -1));
-        if (slot == nullptr) {
-          LOG_WARN << "slot unavailable";
-          // send error result
-          SendError(task, "slot unavailable");
-          return;
-        }
 
-        if (task.data.contains("system_prompt")) {
-          ProcessSystemPromptData(task.data["system_prompt"]);
+    if (task.type == TaskType::kCancelTask) {
+      queue_tasks.erase(queue_tasks.begin());
+
+      for (auto& slot : slots) {
+        if (slot.task_id == task.target_id) {
+          slot.Release();
+          break;
         }
+      }
+      l.unlock();
+    } else if (task.type == TaskType::kCompletionTask) {
+      LlamaClientSlot* slot = GetSlot(json_value(task.data, "slot_id", -1));
+      if (slot == nullptr) {
+        l.unlock();
+        return;
+      }
+      queue_tasks.erase(queue_tasks.begin());
+      l.unlock();
+      if (slot == nullptr) {
+        LOG_WARN << "slot unavailable";
+        // send error result
+        SendError(task, "slot unavailable");
+        return;
+      }
 
-        slot->Reset();
+      if (task.data.contains("system_prompt")) {
+        ProcessSystemPromptData(task.data["system_prompt"]);
+      }
 
-        slot->infill = task.infill_mode;
-        slot->embedding = task.embedding_mode;
-        slot->task_id = task.id;
-        slot->multitask_id = task.multitask_id;
+      slot->Reset();
 
-        if (!LaunchSlotWithData(slot, task.data)) {
-          // send error result
-          SendError(task, "internal_error");
-          break;
-        }
-      } break;
-      case TaskType::kCancelTask: { // release slot linked with the task id
-        for (auto& slot : slots) {
-          if (slot.task_id == task.target_id) {
-            slot.Release();
-            break;
-          }
-        }
-      } break;
+      slot->infill = task.infill_mode;
+      slot->embedding = task.embedding_mode;
+      slot->task_id = task.id;
+      slot->multitask_id = task.multitask_id;
+
+      if (!LaunchSlotWithData(slot, task.data)) {
+        // send error result
+        SendError(task, "internal_error");
+        return;
+      }
     }
   }
 
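
The ProcessTasks rewrite above is the race-condition fix: the completion branch now peeks at the front task and looks for a slot while still holding the task-queue lock, erases the task only when a slot is actually available, and otherwise unlocks and returns so the task stays queued for a later pass; the cancel branch likewise releases its slot under the same lock. A minimal sketch of that check-then-pop ordering, using hypothetical Task/Slot stand-ins rather than the engine's real classes:

#include <deque>
#include <mutex>
#include <optional>
#include <vector>

struct Task { int id; };
struct Slot { bool busy = false; int task_id = -1; };

std::mutex queue_mutex;
std::deque<Task> task_queue;

// Returns a free slot or nullptr; stand-in for a GetSlot-style lookup.
Slot* FindFreeSlot(std::vector<Slot>& slots) {
  for (auto& s : slots) {
    if (!s.busy) return &s;
  }
  return nullptr;
}

// Pop-then-check (the old order) can drop a task when every slot is busy;
// check-then-pop keeps the task queued until a slot can take it.
std::optional<Task> TryTakeTask(std::vector<Slot>& slots, Slot*& out_slot) {
  std::unique_lock<std::mutex> lock(queue_mutex);
  if (task_queue.empty()) return std::nullopt;
  Task task = task_queue.front();   // peek; do not erase yet
  out_slot = FindFreeSlot(slots);
  if (out_slot == nullptr) {
    return std::nullopt;            // no free slot: leave the task in the queue
  }
  task_queue.pop_front();           // commit only once a slot is secured
  out_slot->busy = true;
  out_slot->task_id = task.id;
  return task;                      // lock is released here; dispatch outside it
}

int main() {
  std::vector<Slot> slots(2);
  task_queue.push_back({42});
  Slot* slot = nullptr;
  if (auto task = TryTakeTask(slots, slot)) {
    // dispatch task->id on *slot here, outside the queue lock
  }
  return 0;
}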