This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit b6372bb

feat-add-llamacpp-params (#221)
1 parent 1a6fb1a commit b6372bb

File tree: 3 files changed, +82 -3 lines changed

  src/chat_completion_request.h   +38 -0
  src/llama_engine.cc             +29 -2
  src/llama_server_context.cc     +15 -1

src/chat_completion_request.h

Lines changed: 38 additions & 0 deletions
@@ -1,5 +1,6 @@
 #pragma once
 #include "json/value.h"
+#include "sampling.h"
 
 namespace llama::inferences {
 struct ChatCompletionRequest {
@@ -12,10 +13,29 @@ struct ChatCompletionRequest {
   Json::Value stop = Json::Value(Json::arrayValue);
   Json::Value messages = Json::Value(Json::arrayValue);
   std::string model_id;
+
+  int seed = -1;
+  float dynatemp_range = 0.0f;
+  float dynatemp_exponent = 1.0f;
+  int top_k = 40;
+  float min_p = 0.05f;
+  float tfs_z = 1.0f;
+  float typ_p = 1.0f;
+  int repeat_last_n = 64;
+  float penalty_repeat = 1.0f;
+  bool mirostat = false;
+  float mirostat_tau = 5.0f;
+  float mirostat_eta = 0.1f;
+  bool penalize_nl = false;
+  bool ignore_eos = false;
+  int n_probs = 0;
+  int min_keep = 0;
+  std::string grammar;
 };
 
 inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
   ChatCompletionRequest completion;
+  gpt_sampler_params default_params;
   if (jsonBody) {
     completion.stream = (*jsonBody).get("stream", false).asBool();
     completion.max_tokens = (*jsonBody).get("max_tokens", 500).asInt();
@@ -28,6 +48,24 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
     completion.messages = (*jsonBody)["messages"];
     completion.stop = (*jsonBody)["stop"];
     completion.model_id = (*jsonBody).get("model", {}).asString();
+
+    completion.seed = (*jsonBody).get("seed", -1).asInt();
+    completion.dynatemp_range = (*jsonBody).get("dynatemp_range", 0.0f).asFloat();
+    completion.dynatemp_exponent = (*jsonBody).get("dynatemp_exponent", 0.0f).asFloat();
+    completion.top_k = (*jsonBody).get("top_k", 40).asInt();
+    completion.min_p = (*jsonBody).get("min_p", 0.05f).asFloat();
+    completion.tfs_z = (*jsonBody).get("tfs_z", 1.0f).asFloat();
+    completion.typ_p = (*jsonBody).get("typ_p", 1.0f).asFloat();
+    completion.repeat_last_n = (*jsonBody).get("repeat_last_n", 64).asInt();
+    completion.penalty_repeat = (*jsonBody).get("repeat_penalty", 1.1f).asFloat();
+    completion.mirostat = (*jsonBody).get("mirostat", false).asBool();
+    completion.mirostat_tau = (*jsonBody).get("mirostat_tau", 5.0f).asFloat();
+    completion.mirostat_eta = (*jsonBody).get("mirostat_eta", 0.1f).asFloat();
+    completion.penalize_nl = (*jsonBody).get("penalize_nl", true).asBool();
+    completion.ignore_eos = (*jsonBody).get("ignore_eos", false).asBool();
+    completion.n_probs = (*jsonBody).get("n_probs", 0).asInt();
+    completion.min_keep = (*jsonBody).get("min_keep", 0).asInt();
+    completion.grammar = (*jsonBody).get("grammar", "").asString();
   }
   return completion;
 }
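
As a usage sketch (not part of the commit), the snippet below builds a request body that exercises the new sampling fields and runs it through the fromJson() helper shown above. It assumes this repo's include paths (chat_completion_request.h pulls in llama.cpp's sampling.h), and the model id and values are placeholders. Note the key mapping: the JSON key repeat_penalty is stored as penalty_repeat, and typ_p is forwarded to llama.cpp as typical_p.

// Illustrative only: populate the extended ChatCompletionRequest from JSON.
#include <iostream>
#include <memory>
#include "chat_completion_request.h"  // requires this repo's build setup

int main() {
  auto body = std::make_shared<Json::Value>();
  (*body)["model"] = "llama3";           // placeholder model id
  (*body)["max_tokens"] = 128;
  (*body)["top_k"] = 20;                 // new sampling knobs from this commit
  (*body)["min_p"] = 0.1;
  (*body)["repeat_last_n"] = 128;
  (*body)["repeat_penalty"] = 1.2;       // parsed into completion.penalty_repeat
  (*body)["n_probs"] = 5;                // ask for per-token probabilities
  (*body)["grammar"] = "root ::= [0-9]+";

  auto completion = llama::inferences::fromJson(body);
  std::cout << "top_k=" << completion.top_k
            << " penalty_repeat=" << completion.penalty_repeat
            << " n_probs=" << completion.n_probs << "\n";
  return 0;
}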

src/llama_engine.cc

Lines changed: 29 additions & 2 deletions
@@ -480,6 +480,10 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
   if (!params.use_mmap) {
     LOG_DEBUG << "Disabled mmap";
   }
+  params.n_predict = json_body->get("n_predict", -1).asInt();
+  params.prompt = json_body->get("prompt", "").asString();
+  params.conversation = json_body->get("conversation", false).asBool();
+  params.special = json_body->get("special", false).asBool();
 
   server_map_[model_id].caching_enabled =
       json_body->get("caching_enabled", true).asBool();
@@ -599,6 +603,24 @@ void LlamaEngine::HandleInferenceImpl(
   data["temperature"] = completion.temperature;
   data["frequency_penalty"] = completion.frequency_penalty;
   data["presence_penalty"] = completion.presence_penalty;
+  data["seed"] = completion.seed;
+  data["dynatemp_range"] = completion.dynatemp_range;
+  data["dynatemp_exponent"] = completion.dynatemp_exponent;
+  data["top_k"] = completion.top_k;
+  data["min_p"] = completion.min_p;
+  data["tfs_z"] = completion.tfs_z;
+  data["typical_p"] = completion.typ_p;
+  data["repeat_last_n"] = completion.repeat_last_n;
+  data["repeat_penalty"] = completion.penalty_repeat;
+  data["mirostat"] = completion.mirostat;
+  data["mirostat_tau"] = completion.mirostat_tau;
+  data["mirostat_eta"] = completion.mirostat_eta;
+  data["penalize_nl"] = completion.penalize_nl;
+  data["ignore_eos"] = completion.ignore_eos;
+  data["n_probs"] = completion.n_probs;
+  data["min_keep"] = completion.min_keep;
+  data["grammar"] = completion.grammar;
+  int n_probs = completion.n_probs;
   const Json::Value& messages = completion.messages;
 
   if (!si.grammar_file_content.empty()) {
@@ -717,12 +739,17 @@ void LlamaEngine::HandleInferenceImpl(
   auto state = CreateInferenceState(si.ctx);
 
   // Queued task
-  si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id]() {
+  si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id, n_probs]() {
     state->task_id = state->llama.RequestCompletion(data, false, false, -1);
     while (state->llama.model_loaded_external) {
       TaskResult result = state->llama.NextResult(state->task_id);
       if (!result.error) {
-        std::string to_send = result.result_json["content"];
+        std::string to_send;
+        if (n_probs > 0) {
+          to_send = result.result_json["completion_probabilities"].dump();
+        } else {
+          to_send = result.result_json["content"];
+        }
         // trim the leading space if it is the first token
         if (std::exchange(state->is_first_token, false)) {
           llama_utils::ltrim(to_send);
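
The queued-task change above switches what each streamed chunk carries: when the request sets n_probs > 0, the engine forwards the serialized completion_probabilities array from the llama.cpp server result instead of the plain content string. Below is a self-contained sketch of that selection logic, using nlohmann::json and a made-up result payload (the real one comes from TaskResult::result_json).

// Standalone sketch of the chunk-selection logic added in HandleInferenceImpl.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

std::string chunk_payload(const json& result_json, int n_probs) {
  if (n_probs > 0) {
    // serialize the per-token probability objects instead of the text content
    return result_json["completion_probabilities"].dump();
  }
  return result_json["content"].get<std::string>();
}

int main() {
  json result = {
      {"content", "Hello"},
      {"completion_probabilities", json::array()}  // filled by the server context
  };
  std::cout << chunk_payload(result, 0) << "\n";  // -> Hello
  std::cout << chunk_payload(result, 5) << "\n";  // -> []
}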

src/llama_server_context.cc

Lines changed: 15 additions & 1 deletion
@@ -459,6 +459,15 @@ bool LlamaServerContext::LaunchSlotWithData(LlamaClientSlot*& slot, json data) {
   slot->params.seed = json_value(data, "seed", default_params.seed);
   slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
   slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+  slot->sparams.min_keep =
+      json_value(data, "min_keep", default_sparams.min_keep);
+  slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
+  slot->sparams.dynatemp_range =
+      json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
+  slot->sparams.dynatemp_exponent =
+      json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
+  slot->sparams.ignore_eos =
+      json_value(data, "ignore_eos", default_sparams.ignore_eos);
 
   // infill
   if (data.count("input_prefix") != 0) {
@@ -970,8 +979,13 @@ void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) {
         slot.generated_token_probs.begin(),
         slot.generated_token_probs.begin() + slot.sent_token_probs_index);
   }
-  res.result_json["completion_probabilities"] =
+  if (!slot.params.stream) {
+    res.result_json["completion_probabilities"] =
         probs_vector_to_json(ctx, probs);
+  } else {
+    res.result_json["completion_probabilities"] = std::move(json());
+  }
 }
 
 if (slot.oaicompat) {
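
The new sparams assignments rely on the server's json_value helper, which falls back to the compiled-in default_sparams value when a key is missing from the request. The minimal stand-in below has the same shape and is shown only to illustrate that fallback; it is not the repo's actual implementation.

// Minimal stand-in for json_value(data, key, default) as used above.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

template <typename T>
T json_value(const json& body, const std::string& key, const T& default_value) {
  // use the client-supplied value when present and non-null, else the default
  return body.contains(key) && !body.at(key).is_null()
             ? body.at(key).get<T>()
             : default_value;
}

int main() {
  json data = {{"dynatemp_range", 0.5}, {"ignore_eos", true}};
  float dynatemp_range = json_value(data, "dynatemp_range", 0.0f);  // 0.5
  int min_keep = json_value(data, "min_keep", 0);                   // default 0
  bool ignore_eos = json_value(data, "ignore_eos", false);          // true
  std::cout << dynatemp_range << " " << min_keep << " " << ignore_eos << "\n";
}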
