examples/server/README.md (10 changes: 6 additions & 4 deletions)

@@ -524,10 +524,12 @@ Takes a prefix and a suffix and returns the predicted completion as stream.

- `input_prefix`: Set the prefix of the code to infill.
- `input_suffix`: Set the suffix of the code to infill.
- `prompt`: Added after the `FIM_MID` token
- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874
- `input_extra`: Additional context inserted before the FIM prefix.
- `prompt`: Added after the `FIM_MID` token

It also accepts all the options of `/completion`.
`input_extra` is an array of `{"filename": string, "text": string}` objects.

The endpoint also accepts all the options of `/completion`.

If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:

@@ -545,7 +547,7 @@
If the tokens are missing, then the extra context is simply prefixed at the start:

```txt
[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
[input_extra]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```
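
As a minimal illustration, a request body using these fields might look like the snippet below; the filenames and code fragments are made up, and any of the `/completion` options can be added alongside them:

```json
{
  "input_prefix": "def helper():\n    ",
  "input_suffix": "\n\nprint(helper())\n",
  "input_extra": [
    { "filename": "math_utils.py", "text": "def square(x):\n    return x * x\n" }
  ],
  "prompt": ""
}
```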

### **GET** `/props`: Get server global properties.

examples/server/server.cpp (73 changes: 27 additions & 46 deletions)

@@ -136,10 +136,6 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<std::string> antiprompt;

json input_prefix;
json input_suffix;
json extra_context;
};

struct server_slot {
@@ -169,6 +165,10 @@

json prompt; // can be either a string, array of strings or array of token ids

json input_prefix;
json input_suffix;
json input_extra;

// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;
@@ -908,12 +908,12 @@ struct server_context {
}

// infill
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
slot.input_prefix = json_value(data, "input_prefix", json());
slot.input_suffix = json_value(data, "input_suffix", json());
slot.input_extra = json_value(data, "input_extra", json());

SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
for (const auto & chunk : slot.params.extra_context) {
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
@@ -930,7 +930,7 @@
}

// get prompt
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
{
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -1954,6 +1954,8 @@ struct server_context {
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
// TODO: optimize this block by reducing memory allocations and movement

// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
@@ -1964,10 +1966,11 @@
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
auto tokens_prompt = tokenize(slot.prompt, false, false);

slot.extra_tokens.clear();
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
@@ -1977,7 +1980,7 @@
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}

for (const auto & chunk : slot.params.extra_context) {
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");
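
For readers skimming the diff, here is a minimal standalone sketch of the repo-level layout described in the comment block above. It is not the server's code: it works on plain strings and uses literal placeholder markers instead of the model's FIM tokens.

```cpp
#include <string>
#include <vector>

struct extra_chunk {
    std::string filename;
    std::string text;
};

// Illustrative only: assemble the repo-level FIM layout as one string,
// with bracketed placeholders standing in for the special tokens.
std::string build_repo_fim(const std::string & repo,
                           const std::vector<extra_chunk> & chunks,
                           const std::string & filename,
                           const std::string & prefix,
                           const std::string & suffix,
                           const std::string & prompt) {
    std::string out = "[FIM_REP]" + repo + "\n";
    for (const auto & chunk : chunks) {
        out += "[FIM_SEP]" + chunk.filename + "\n" + chunk.text;
    }
    out += "[FIM_SEP]" + filename + "\n";
    out += "[FIM_PRE]" + prefix + "[FIM_SUF]" + suffix + "[FIM_MID]" + prompt;
    return out;
}
```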
@@ -2008,20 +2011,21 @@
}

// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);

// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());

prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
suffix_tokens.resize(n_suffix_take);
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);

prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));

auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;

if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
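
To make the 3:1 budgeting above concrete, here is a small worked example with assumed values (n_batch = 2048, n_ctx = 8192, n_predict = 128 are illustrative, not defaults taken from the code):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // assumed, illustrative values
    const int n_batch   = 2048;
    const int n_ctx     = 8192;
    const int n_predict = 128;

    const int n_suffix = 1000;  // tokens available in the suffix
    const int n_prefix = 4000;  // tokens available in the prefix
    const int n_extra  = 10000; // tokens available in the extra chunks

    // same arithmetic as above: roughly 3:1 prefix:suffix split of n_batch
    const int n_suffix_take = std::min(n_suffix, n_batch/4);         // min(1000, 512)  = 512
    const int n_prefix_take = std::min(n_prefix, 3*(n_batch/4) - 3); // min(4000, 1533) = 1533
    const int n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict),
                                       n_extra);                     // min(5888, 10000) = 5888

    printf("take: prefix=%d suffix=%d extra=%d\n", n_prefix_take, n_suffix_take, n_extra_take);
    return 0;
}
```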
@@ -2136,40 +2140,17 @@

while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) {
if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
break;
}

if (llama_token_is_control(model, prompt_tokens[head_p]) &&
prompt_tokens[head_p] != llama_token_fim_rep(model) &&
prompt_tokens[head_p] != llama_token_fim_sep(model)) {
break;
}

size_t n_match = 0;

while (head_c + n_match < slot.cache_tokens.size() &&
head_p + n_match < prompt_tokens.size() &&
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
break;
}

if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
break;
}

n_match++;
}

if (n_match >= (size_t) params.n_cache_reuse) {
SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
//}
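
As a minimal sketch of the matching step in this loop (plain int tokens, none of the server's slot bookkeeping): starting from a cache position and a prompt position, count how long the two sequences stay identical; the cached KV entries for such a run are reused when the run is at least `n_cache_reuse` tokens long.

```cpp
#include <cstddef>
#include <vector>

// Simplified illustration: count consecutive identical tokens starting at
// head_c in the cached sequence and head_p in the new prompt.
static size_t count_matching_run(const std::vector<int> & cache_tokens,  size_t head_c,
                                 const std::vector<int> & prompt_tokens, size_t head_p) {
    size_t n_match = 0;
    while (head_c + n_match < cache_tokens.size() &&
           head_p + n_match < prompt_tokens.size() &&
           cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
        n_match++;
    }
    return n_match;
}
```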