diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 1fccfdd17f138..7fbca32016274 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -13,9 +13,14 @@ endif() set(TARGET_SRCS server.cpp - utils.hpp server-http.cpp server-http.h + server-task.cpp + server-task.h + server-queue.cpp + server-queue.h + server-common.cpp + server-common.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/utils.hpp b/tools/server/server-common.cpp similarity index 67% rename from tools/server/utils.hpp rename to tools/server/server-common.cpp index bf21726051e55..18328f3afbdd5 100644 --- a/tools/server/utils.hpp +++ b/tools/server/server-common.cpp @@ -1,502 +1,752 @@ -#pragma once - #include "common.h" #include "log.h" #include "llama.h" -#include "arg.h" // common_remote_get_content -#include "base64.hpp" #include "mtmd.h" #include "mtmd-helper.h" #include "chat.h" +#include "arg.h" // for common_remote_get_content; TODO: use download.h only +#include "base64.hpp" -#define JSON_ASSERT GGML_ASSERT -#include +#include "server-common.h" #include #include -#include -#include -#include -#include - -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - -using json = nlohmann::ordered_json; - -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) - -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -using raw_buffer = std::vector; - -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { - LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); - return default_value; - } - } else { - return default_value; + +json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + case ERROR_TYPE_EXCEED_CONTEXT_SIZE: + type_str = "exceed_context_size_error"; + code = 400; + break; } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; } -const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); +// +// random string / id +// -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; +std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} - server_grammar_trigger(const json & in) { - value.type = (common_grammar_trigger_type) in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token) in.at("token").get(); - } - } + std::random_device rd; + std::mt19937 generator(rd()); - json to_json() const { - json out { - {"type", (int) value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int) value.token; - } - return out; + std::string result(32, ' '); + + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; } -}; + + return result; +} + +std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); +} + +std::string gen_tool_call_id() { + return random_string(); +} // -// tokenizer and input processing utils +// lora utils // -static bool json_is_array_of_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number_integer()) { +bool lora_all_alora(const std::vector & loras) { + bool found_alora = false; + for (const auto & lora : loras) { + if (lora.scale != 0) { + if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { return false; } + found_alora = true; } - return true; } - return false; + return found_alora; } -// is array having BOTH numbers & strings? -static bool json_is_array_of_mixed_numbers_strings(const json & data) { - bool seen_string = false; - bool seen_number = false; - if (data.is_array()) { - for (const auto & e : data) { - seen_string |= e.is_string(); - seen_number |= e.is_number_integer(); - if (seen_number && seen_string) { - return true; - } +bool lora_should_clear_cache( + const std::vector & current, + const std::vector & next) { + + // This should always be called after determining that the two sets are + // _not_ equal. This assert is therefore some slightly wasted work and + // should be safe to remove as long as this method is called correctly. + GGML_ASSERT(!are_lora_equal(current, next)); + + return ( + !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || + !lora_all_alora(next)); +} + +std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); } } - return false; + + return lora; } -// does array have any individual integers/tokens? -static bool json_is_array_and_contains_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (e.is_number_integer()) { - return true; - } - } +bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { return false; } - return false; + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; } -// get value by path(key1 / key2) -static json json_get_nested_values(const std::vector & paths, const json & js) { - json result = json::object(); - - for (const std::string & path : paths) { - json current = js; - const auto keys = string_split(path, /*separator*/ '/'); - bool valid_path = true; - for (const std::string & k : keys) { - if (valid_path && current.is_object() && current.contains(k)) { - current = current[k]; - } else { - valid_path = false; - } - } - if (valid_path) { - result[path] = current; +std::vector lora_get_enabled_ids(const std::vector & loras) { + std::vector enabled_ids; + for (size_t i = 0; i < loras.size(); ++i) { + if (loras[i].scale > 0) { + enabled_ids.push_back(i); } } - return result; + return enabled_ids; } -/** - * this handles 2 cases: - * - only string, example: "string" - * - mixed string and tokens, example: [12, 34, "string", 56, 78] - */ -static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - llama_tokens prompt_tokens; +// +// base64 utils (TODO: use the base64::decode from base64.hpp) +// - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; - llama_tokens p; - if (first) { - p = common_tokenize(vocab, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(vocab, s, false, parse_special); - } +static inline bool is_base64(uint8_t c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } +static inline raw_buffer base64_decode(const std::string & encoded_string) { + int i = 0; + int j = 0; + int in_ = 0; - prompt_tokens.push_back(p.template get()); + int in_len = encoded_string.size(); + + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + + raw_buffer ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) { + ret.push_back(char_array_3[i]); } + + i = 0; } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); } - return prompt_tokens; -} + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } -// return the last index of character that can form a valid string -// if the last character is potentially cut in half, return the index before the cut -// if validate_utf8(text) == text.size(), then the whole text is valid utf8 -static size_t validate_utf8(const std::string& text) { - size_t len = text.size(); - if (len == 0) return 0; + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } - // Check the last few bytes to see if a multi-byte character is cut off - for (size_t i = 1; i <= 4 && i <= len; ++i) { - unsigned char c = text[len - i]; - // Check for start of a multi-byte sequence from the end - if ((c & 0xE0) == 0xC0) { - // 2-byte character start: 110xxxxx - // Needs at least 2 bytes - if (i < 2) return len - i; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character start: 1110xxxx - // Needs at least 3 bytes - if (i < 3) return len - i; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character start: 11110xxx - // Needs at least 4 bytes - if (i < 4) return len - i; + char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); } } - // If no cut-off multi-byte character is found, return full length - return len; + return ret; } // -// template utils +// server_tokens implementation // -// format infill task -static llama_tokens format_infill( - const llama_vocab * vocab, - const json & input_prefix, - const json & input_suffix, - const json & input_extra, - const int n_batch, - const int n_predict, - const int n_ctx, - const bool spm_infill, - const llama_tokens & tokens_prompt - ) { - // TODO: optimize this block by reducing memory allocations and movement +server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } +} - // use FIM repo-level pattern: - // ref: https://arxiv.org/pdf/2409.12186 - // - // [FIM_REP]myproject - // [FIM_SEP]filename0 - // extra chunk 0 - // [FIM_SEP]filename1 - // extra chunk 1 - // ... - // [FIM_SEP]filename - // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt - // - llama_tokens extra_tokens; - extra_tokens.reserve(n_ctx); +server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { +} - auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); - auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); +llama_pos server_tokens::pos_next() const { + if (!has_mtmd) { + return tokens.size(); + } - if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: make project name an input - static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + llama_pos res = tokens.size(); - extra_tokens.push_back(llama_vocab_fim_rep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { + const auto & chunk = it->second; + res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); } - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - const std::string text = json_value(chunk, "text", std::string()); - const std::string filename = json_value(chunk, "filename", std::string("tmp")); - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + return res; +} - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); +std::string server_tokens::str() const { + std::ostringstream oss; + oss << "tokens: "; + for (size_t idx = 0; idx < tokens.size(); ++idx) { + llama_token t = tokens[idx]; + oss << "idx:" << idx << " "; + if (t == LLAMA_TOKEN_NULL) { + oss << " "; } else { - // chunk separator in binary form to avoid confusing the AI - static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - - extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + oss << t << " "; } - - const auto chunk_tokens = common_tokenize(vocab, text, false, false); - extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: current filename - static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + oss << "\n"; + oss << "image idx: "; + for (const auto & it : map_idx_to_media) { + oss << it.first << ", "; } + return oss.str(); +} - // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) - const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); - const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); - - SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); - - // fill the rest of the context with extra chunks - const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); - - tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); - tokens_suffix.resize(n_suffix_take); - - tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); - tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); - - auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; - auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - - if (llama_vocab_get_add_bos(vocab)) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); +const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const { + auto it = map_idx_to_media.find(idx); + if (it != map_idx_to_media.end()) { + return it->second; } - - SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); - - // put the extra context before the FIM prefix - embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_vocab_fim_mid(vocab)); - - return embd_inp; + throw std::runtime_error("Chunk not found"); } -// -// base64 utils (TODO: move to common in the future) -// - -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; +void server_tokens::push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); +} -static inline bool is_base64(uint8_t c) { - return (isalnum(c) || (c == '+') || (c == '/')); +void server_tokens::push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); + size_t start_idx = tokens.size(); + for (size_t i = 0; i < n_tokens; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } } -static inline raw_buffer base64_decode(const std::string & encoded_string) { - int i = 0; - int j = 0; - int in_ = 0; +void server_tokens::push_back(server_tokens & tokens) { + size_t start_idx = size(); + for (size_t i = 0; i < tokens.size(); i++) { + push_back(tokens[i]); + } + if (tokens.has_mtmd) { + // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. + // We could also just check, but this will prevent silently dropping MTMD data. + GGML_ASSERT(has_mtmd); + for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { + auto * chunk = tokens.map_idx_to_media[it->first].get(); + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx + it->first] = std::move(new_chunk); + } + } +} - int in_len = encoded_string.size(); +void server_tokens::insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); +} - uint8_t char_array_4[4]; - uint8_t char_array_3[3]; +const llama_tokens & server_tokens::get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; +} - raw_buffer ret; +void server_tokens::set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; +} - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); +void server_tokens::keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + // make sure we never remove tokens in the middle of an image + // note that the case where we keep a full image at the end is allowed: + // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL + if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) { - ret.push_back(char_array_3[i]); + } + // remove all image chunks that are not used anymore + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { + size_t idx = it->first; + if (idx >= n) { + it = map_idx_to_media.erase(it); + } else { + ++it; } - - i = 0; } } + tokens.resize(n); +} - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; +std::string server_tokens::detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); } + } + return common_detokenize(ctx, text_tokens, special); +} - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } +size_t server_tokens::get_common_prefix(const server_tokens & b) const { + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { + continue; + } - for (j = 0; j < i - 1; j++) { - ret.push_back(char_array_3[j]); + return i; } + + return max_idx; } - return ret; -} + for (size_t i = 0; i < max_idx; ++i) { + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; -// -// random string / id -// + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); -static std::string random_string() { - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + GGML_ASSERT(a_chunk && b_chunk); - std::random_device rd; - std::mt19937 generator(rd()); + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - std::string result(32, ' '); + const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); + const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; + if (id_ai == id_bi && n_tok_a == n_tok_b) { + GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen + i += n_tok_a - 1; // will be +1 by the for loop + continue; + } + + return i; + } + + if (ai == bi) { + continue; + } + + return i; } - return result; + return max_idx; // all tokens are equal } -static std::string gen_chatcmplid() { - return "chatcmpl-" + random_string(); +bool server_tokens::validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + const auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); + i += n_tokens - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; } -static std::string gen_tool_call_id() { - return random_string(); +int32_t server_tokens::process_chunk( + llama_context * ctx, + mtmd_context * mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t & n_tokens_out) const { + const auto & chunk = find_chunk(idx); + const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? "image" : "audio"; + SRV_INF("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past; // unused for now + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + chunk.get(), + pos, + seq_id, + n_batch, + true, // logits last + &new_n_past); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_tokens_out = 0; + return result; + } + n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); + return 0; } // -// other common utils +// tokenizer and input processing utils // -static std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - -// TODO: reuse llama_detokenize -template -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(ctx, *begin); +bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; } - - return ret; + return false; } -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); - - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; +bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } } - - return out; + return false; } -// format server-sent event (SSE), return the formatted string to send -// note: if data is a json array, it will be sent as multiple events, one per item -static std::string format_sse(const json & data) { - std::ostringstream ss; - auto send_single = [&ss](const json & data) { - ss << "data: " << - safe_json_to_str(data) << - "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). - }; - +bool json_is_array_and_contains_numbers(const json & data) { if (data.is_array()) { - for (const auto & item : data) { - send_single(item); + for (const auto & e : data) { + if (e.is_number_integer()) { + return true; + } } - } else { - send_single(data); + return false; } - - return ss.str(); + return false; } -// -// OAI utils -// - -// used by /completions endpoint -static json oaicompat_completion_params_parse(const json & body) { - json llama_params; +json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); - if (!body.contains("prompt")) { - throw std::runtime_error("\"prompt\" is required"); + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } } + return result; +} - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } +llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; - // Handle "n" field - int n_choices = json_value(body, "n", 1); + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } + + return prompt_tokens; +} + +size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// Computes FNV-1a hash of the data +static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); +} + +server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { + mtmd::bitmaps bitmaps; + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image or audio file"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + // process prompt + std::vector inputs; + // multimodal + mtmd_input_text inp_txt = { + prompt.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + auto result = server_tokens(chunks, true); + return result; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * use tokenize_input_prompts() if the input could be an array. + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + */ +static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { + constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; + constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; + const bool has_mtmd = mctx != nullptr; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); + return server_tokens(tmp, false); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + llama_tokens tmp = json_prompt.get(); + return server_tokens(tmp, false); + } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { + // JSON object with prompt key. + if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { + if (!has_mtmd) + throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); + + // JSON object with prompt and multimodal key. + std::vector files; + for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { + files.push_back(base64_decode(entry)); + } + return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); + } else { + // Not multimodal, but contains a subobject. + llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); + return server_tokens(tmp, false); + } + } else { + throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); + } +} + +std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); + } + } else { + result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + + +// +// OAI utils +// + +// used by /completions endpoint +json oaicompat_completion_params_parse(const json & body) { + json llama_params; + + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); if (n_choices != 1) { throw std::runtime_error("Only one completion choice is allowed"); } @@ -525,19 +775,8 @@ static json oaicompat_completion_params_parse(const json & body) { return llama_params; } -struct oaicompat_parser_options { - bool use_jinja; - bool prefill_assistant; - common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates * tmpls; - bool allow_image; - bool allow_audio; - bool enable_thinking = true; -}; - // used by /chat/completions endpoint -static json oaicompat_chat_params_parse( +json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ const oaicompat_parser_options & opt, std::vector & out_files) @@ -809,7 +1048,7 @@ static json oaicompat_chat_params_parse( return llama_params; } -static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { +json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { json data = json::array(); int32_t n_tokens = 0; int i = 0; @@ -851,7 +1090,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } -static json format_response_rerank( +json format_response_rerank( const json & request, const json & ranks, bool is_tei_format, @@ -896,63 +1135,12 @@ static json format_response_rerank( return res; } -static bool is_valid_utf8(const std::string & str) { - const unsigned char* bytes = reinterpret_cast(str.data()); - const unsigned char* end = bytes + str.length(); - - while (bytes < end) { - if (*bytes <= 0x7F) { - // 1-byte sequence (0xxxxxxx) - bytes++; - } else if ((*bytes & 0xE0) == 0xC0) { - // 2-byte sequence (110xxxxx 10xxxxxx) - if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) - return false; - bytes += 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) - return false; - bytes += 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || - (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) - return false; - bytes += 4; - } else { - // Invalid UTF-8 lead byte - return false; - } - } - - return true; -} - -static json format_tokenizer_response(const json & tokens) { - return json { - {"tokens", tokens} - }; -} - -static json format_detokenized_response(const std::string & content) { - return json { - {"content", content} - }; -} -static json format_logit_bias(const std::vector & logit_bias) { - json data = json::array(); - for (const auto & lb : logit_bias) { - data.push_back(json{ - {"bias", lb.bias}, - {"token", lb.token}, - }); - } - return data; -} +// +// other utils +// -static std::vector get_token_probabilities(llama_context * ctx, int idx) { +std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); @@ -986,538 +1174,203 @@ static std::vector get_token_probabilities(llama_context * ctx return cur; } -static bool are_lora_equal( - const std::vector & l1, - const std::vector & l2) { - if (l1.size() != l2.size()) { - return false; - } - for (size_t i = 0; i < l1.size(); ++i) { - // we don't check lora.path to reduce the time complexity - if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { - return false; - } - } - return true; +std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); } -// get the ids of all enabled loras -static std::vector lora_get_enabled_ids(const std::vector & loras) { - std::vector enabled_ids; - for (size_t i = 0; i < loras.size(); ++i) { - if (loras[i].scale > 0) { - enabled_ids.push_back(i); - } +// TODO: reuse llama_detokenize +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); } - return enabled_ids; + + return ret; } -// check whether the given lora set has only aloras activated (empty => false) -static bool lora_all_alora(const std::vector & loras) { - bool found_alora = false; - for (const auto & lora : loras) { - if (lora.scale != 0) { - if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { - return false; - } - found_alora = true; - } - } - return found_alora; +std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) { + return tokens_to_str(ctx, tokens.begin(), tokens.end()); } -// if the two sets of loras are different, they require a cache clear unless the -// change is only from aloras to aloras. -static bool lora_should_clear_cache( - const std::vector & current, - const std::vector & next) { +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); - // This should always be called after determining that the two sets are - // _not_ equal. This assert is therefore some slightly wasted work and - // should be safe to remove as long as this method is called correctly. - GGML_ASSERT(!are_lora_equal(current, next)); + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } - return ( - !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || - !lora_all_alora(next)); + return out; } -// parse lora config from JSON request, returned a copy of lora_base with updated scale -static std::vector parse_lora_request( - const std::vector & lora_base, - const json & data) { - std::vector lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } +// format server-sent event (SSE), return the formatted string to send +// note: if data is a json array, it will be sent as multiple events, one per item +std::string format_sse(const json & data) { + std::ostringstream ss; + auto send_single = [&ss](const json & data) { + ss << "data: " << + safe_json_to_str(data) << + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). + }; - // set value - for (const auto & entry : data) { - int id = json_value(entry, "id", -1); - float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); + if (data.is_array()) { + for (const auto & item : data) { + send_single(item); } + } else { + send_single(data); } - return lora; + return ss.str(); } -// -// utils for interacting with libmtmd -// (may need to refactor in near future) -// +bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); -/** - * server_tokens is a helper to manage the input tokens and image for the server. - * it is made this way to simplify the logic of KV cache management. - */ -struct server_tokens { - bool has_mtmd = false; - -private: // disallow accessing these members directly, risking out-of-sync - - // map a **start** index in tokens to the image chunk - // note: the order need to be in-sync with tokens - std::map map_idx_to_media; - - // list of tokens - // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk - // otherwise, it is a normal text token - // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list - // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos - llama_tokens tokens; - - // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] - // idx 0 1 2 3 4 5 6 7 8 9 10 - // pos 0 1 2 3 4 5 5 5 7 7 7 - // map_idx_to_media will contain: {5, img0}, {8, img1} - -public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - // TODO: server_tokens should be copyable - remove this: - server_tokens(const server_tokens&) = delete; - server_tokens& operator=(const server_tokens&) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens&&) = default; - server_tokens& operator=(server_tokens&&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token& operator[](size_t index) const { return tokens[index]; } - - server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; } } - server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { - } + return true; +} - llama_pos pos_next() const { - if (!has_mtmd) { - return tokens.size(); - } +llama_tokens format_prompt_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement - llama_pos res = tokens.size(); + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto & chunk = it->second; - res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); - } + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); - return res; - } + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - // for debugging - std::string str() const { - std::ostringstream oss; - oss << "tokens: "; - for (size_t idx = 0; idx < tokens.size(); ++idx) { - llama_token t = tokens[idx]; - oss << "idx:" << idx << " "; - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } else { - oss << t << " "; - } - } - oss << "\n"; - oss << "image idx: "; - for (const auto & it : map_idx_to_media) { - oss << it.first << ", "; - } - return oss.str(); + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); - const mtmd::input_chunk_ptr & find_chunk(size_t idx) const { - auto it = map_idx_to_media.find(idx); - if (it != map_idx_to_media.end()) { - return it->second; - } - throw std::runtime_error("Chunk not found"); - } + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - void push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); - } - tokens.emplace_back(tok); - } - - // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk * chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); - size_t start_idx = tokens.size(); - for (size_t i = 0; i < n_tokens; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx] = std::move(new_chunk); - } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } else { - GGML_ABORT("Invalid chunk type"); - } - } + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - // appends server tokens, updates the media map. copies media chunks. - void push_back(server_tokens & tokens) { - size_t start_idx = size(); - for (size_t i = 0; i < tokens.size(); i++) { - push_back(tokens[i]); - } - if (tokens.has_mtmd) { - // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. - // We could also just check, but this will prevent silently dropping MTMD data. - GGML_ASSERT(has_mtmd); - for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { - auto * chunk = tokens.map_idx_to_media[it->first].get(); - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx + it->first] = std::move(new_chunk); - } + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); } - } - // for compatibility with context shift and prompt truncation - void insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); - } - - // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; - } - - // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; - } - - size_t size() const { - return tokens.size(); - } - - bool empty() const { - return tokens.empty(); - } - - void clear() { - map_idx_to_media.clear(); - tokens.clear(); + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - void keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do - } - // we throw an error if we try to remove a token in the middle of an image - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - // make sure we never remove tokens in the middle of an image - // note that the case where we keep a full image at the end is allowed: - // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL - if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk - } - } - // remove all image chunks that are not used anymore - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { - size_t idx = it->first; - if (idx >= n) { - it = map_idx_to_media.erase(it); - } else { - ++it; - } - } - } - tokens.resize(n); - } + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - std::string detokenize(const llama_context * ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto & t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } - size_t get_common_prefix(const server_tokens & b) const { - const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - - if (!has_mtmd) { - for (size_t i = 0; i < max_idx; ++i) { - if (tokens[i] == b.tokens[i]) { - continue; - } - - return i; - } - - return max_idx; - } - - for (size_t i = 0; i < max_idx; ++i) { - const llama_token ai = tokens[i]; - const llama_token bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - const auto & a_chunk = find_chunk(i); - const auto & b_chunk = b.find_chunk(i); - - GGML_ASSERT(a_chunk && b_chunk); - - const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); - const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - - const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); - const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - - if (id_ai == id_bi && n_tok_a == n_tok_b) { - GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen - i += n_tok_a - 1; // will be +1 by the for loop - continue; - } - - return i; - } - - if (ai == bi) { - continue; - } - - return i; - } + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); - return max_idx; // all tokens are equal - } + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); - // make sure all text tokens are within the vocab range - bool validate(const struct llama_context * ctx) const { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); - for (size_t i = 0; i < tokens.size(); ++i) { - const auto & t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto & chunk = find_chunk(i); - size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); - i += n_tokens - 1; // will be +1 by the for loop - } catch (const std::exception & e) { - return false; - } - } else if (t < 0 || t >= n_vocab) { - return false; - } - } - return true; - } + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const { - const auto & chunk = find_chunk(idx); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? "image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; - } -}; + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); -// Computes FNV-1a hash of the data -static std::string fnv_hash(const uint8_t * data, size_t len) { - const uint64_t fnv_prime = 0x100000001b3ULL; - uint64_t hash = 0xcbf29ce484222325ULL; + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - for (size_t i = 0; i < len; ++i) { - hash ^= data[i]; - hash *= fnv_prime; + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } - return std::to_string(hash); -} -static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { - mtmd::bitmaps bitmaps; - for (auto & file : files) { - mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); - if (!bmp.ptr) { - throw std::runtime_error("Failed to load image or audio file"); - } - // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); - bmp.set_id(hash.c_str()); - bitmaps.entries.push_back(std::move(bmp)); - } - // process prompt - std::vector inputs; - // multimodal - mtmd_input_text inp_txt = { - prompt.c_str(), - /* add_special */ true, - /* parse_special */ true, - }; - mtmd::input_chunks chunks(mtmd_input_chunks_init()); - auto bitmaps_c_ptr = bitmaps.c_ptr(); - int32_t tokenized = mtmd_tokenize(mctx, - chunks.ptr.get(), - &inp_txt, - bitmaps_c_ptr.data(), - bitmaps_c_ptr.size()); - if (tokenized != 0) { - throw std::runtime_error("Failed to tokenize prompt"); - } - auto result = server_tokens(chunks, true); - return result; -} + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * use tokenize_input_prompts() if the input could be an array. - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - */ -static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; - constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; - const bool has_mtmd = mctx != nullptr; - if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { - // string or mixed - llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); - return server_tokens(tmp, false); - } else if (json_is_array_of_numbers(json_prompt)) { - // array of tokens - llama_tokens tmp = json_prompt.get(); - return server_tokens(tmp, false); - } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { - // JSON object with prompt key. - if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { - if (!has_mtmd) - throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - // JSON object with prompt and multimodal key. - std::vector files; - for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { - files.push_back(base64_decode(entry)); - } - return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); - } else { - // Not multimodal, but contains a subobject. - llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); - return server_tokens(tmp, false); - } - } else { - throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); - } -} + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - * and multiple prompts (multi-tasks): - * - "prompt": ["string1", "string2"] - * - "prompt": ["string1", [12, 34, 56]] - * - "prompt": [[12, 34, 56], [78, 90, 12]] - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] - */ -static std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - std::vector result; - if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { - result.reserve(json_prompt.size()); - for (const auto & p : json_prompt) { - result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); - } - } else { - result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); - } - if (result.empty()) { - throw std::runtime_error("\"prompt\" must not be empty"); - } - return result; + return embd_inp; } -// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. -static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) { +server_tokens format_prompt_rerank( + const struct llama_model * model, + const struct llama_vocab * vocab, + mtmd_context * mctx, + const std::string & query, + const std::string & doc) { server_tokens result = {}; const char * rerank_prompt = llama_model_chat_template(model, "rerank"); diff --git a/tools/server/server-common.h b/tools/server/server-common.h new file mode 100644 index 0000000000000..868c50610319c --- /dev/null +++ b/tools/server/server-common.h @@ -0,0 +1,349 @@ +#pragma once + +#include "common.h" +#include "log.h" +#include "llama.h" +#include "chat.h" +#include "mtmd.h" + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" + +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +using raw_buffer = std::vector; + +template +static T json_value(const json & body, const std::string & key, const T & default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); + return default_value; + } + } else { + return default_value; + } +} + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error + ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error +}; + +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get(); + } + } + + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; + } + return out; + } +}; + +json format_error_response(const std::string & message, const enum error_type type); + +// +// random string / id +// + +std::string random_string(); +std::string gen_chatcmplid(); +std::string gen_tool_call_id(); + +// +// lora utils +// + +// check whether the given lora set has only aloras activated (empty => false) +bool lora_all_alora(const std::vector & loras); + +// if the two sets of loras are different, they require a cache clear unless the +// change is only from aloras to aloras. +bool lora_should_clear_cache( + const std::vector & current, + const std::vector & next); + +std::vector parse_lora_request( + const std::vector & lora_base, + const json & data); + +bool are_lora_equal( + const std::vector & l1, + const std::vector & l2); + +// get the ids of all enabled loras +std::vector lora_get_enabled_ids(const std::vector & loras); + +// +// server_tokens +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** index in tokens to the image chunk + // note: the order need to be in-sync with tokens + std::map map_idx_to_media; + + // list of tokens + // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk + // otherwise, it is a normal text token + // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list + // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] + // idx 0 1 2 3 4 5 6 7 8 9 10 + // pos 0 1 2 3 4 5 5 5 7 7 7 + // map_idx_to_media will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + // TODO: server_tokens should be copyable - remove this: + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd); + server_tokens(const llama_tokens & tokens, bool has_mtmd); + + // for debugging + std::string str() const; + + llama_pos pos_next() const; + const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; + + void push_back(llama_token tok); + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk); + + // appends server tokens, updates the media map. copies media chunks. + void push_back(server_tokens & tokens); + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens); + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const; + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id); + + size_t size() const { return tokens.size(); } + + bool empty() const { return tokens.empty(); } + + void clear() { + map_idx_to_media.clear(); + tokens.clear(); + } + + void keep_first(size_t n); + + std::string detokenize(const llama_context * ctx, bool special) const; + + size_t get_common_prefix(const server_tokens & b) const; + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const; + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t & n_tokens_out) const; +}; + + +// +// tokenizer and input processing utils +// + +bool json_is_array_of_numbers(const json & data); + +// is array having BOTH numbers & strings? +bool json_is_array_of_mixed_numbers_strings(const json & data); + +// does array have any individual integers/tokens? +bool json_is_array_and_contains_numbers(const json & data); + +// get value by path(key1 / key2) +json json_get_nested_values(const std::vector & paths, const json & js); + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special); + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +size_t validate_utf8(const std::string& text); + +// process mtmd prompt, return the server_tokens containing both text tokens and media chunks +server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files); + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] + */ +std::vector tokenize_input_prompts( + const llama_vocab * vocab, + mtmd_context * mctx, + const json & json_prompt, + bool add_special, + bool parse_special); + +// +// OAI utils +// + +// used by /completions endpoint +json oaicompat_completion_params_parse(const json & body); + +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + std::map chat_template_kwargs; + common_chat_templates * tmpls; + bool allow_image; + bool allow_audio; + bool enable_thinking = true; +}; + +// used by /chat/completions endpoint +json oaicompat_chat_params_parse( + json & body, /* openai api json semantics */ + const oaicompat_parser_options & opt, + std::vector & out_files); + +// TODO: move it to server-task.cpp +json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); + +// TODO: move it to server-task.cpp +json format_response_rerank( + const json & request, + const json & ranks, + bool is_tei_format, + std::vector & texts, + int top_n); + +// +// other utils +// + +std::vector get_token_probabilities(llama_context * ctx, int idx); + +std::string safe_json_to_str(const json & data); + +std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens); + +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); + +// format server-sent event (SSE), return the formatted string to send +// note: if data is a json array, it will be sent as multiple events, one per item +std::string format_sse(const json & data); + +bool is_valid_utf8(const std::string & str); + +// +// formatting output responses +// TODO: move these to server-task.cpp +// + +llama_tokens format_prompt_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt); + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. +server_tokens format_prompt_rerank( + const struct llama_model * model, + const struct llama_vocab * vocab, + mtmd_context * mctx, + const std::string & query, + const std::string & doc); diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index bebe0b49c4f6a..a82aa86b19e40 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -1,6 +1,6 @@ -#include "utils.hpp" #include "common.h" #include "server-http.h" +#include "server-common.h" #include diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp new file mode 100644 index 0000000000000..5a74fd76ac3c1 --- /dev/null +++ b/tools/server/server-queue.cpp @@ -0,0 +1,268 @@ +#include "server-task.h" +#include "server-queue.h" + +#include "log.h" + +#include + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +// +// server_queue +// + +int server_queue::post(server_task && task, bool front) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task_id; +} + +int server_queue::post(std::vector && tasks, bool front) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; +} + +void server_queue::defer(server_task && task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); +} + +int server_queue::get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; +} + +void server_queue::on_new_task(std::function callback) { + callback_new_task = std::move(callback); +} + +void server_queue::on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); +} + +void server_queue::pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); +} + +void server_queue::terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); +} + +void server_queue::start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } +} + +void server_queue::cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); +} + +// +// server_response +// + +void server_response::add_waiting_task_id(int id_task) { + RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); +} + +void server_response::add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } +} + +void server_response::remove_waiting_task_id(int id_task) { + RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); +} + +void server_response::remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } +} + +server_task_result_ptr server_response::recv(const std::unordered_set & id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + if (!running) { + RES_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here +} + +server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + RES_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here +} + +server_task_result_ptr server_response::recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); +} + +void server_response::send(server_task_result_ptr && result) { + RES_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + RES_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } +} + +void server_response::terminate() { + running = false; + condition_results.notify_all(); +} diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h new file mode 100644 index 0000000000000..47ef58425eaaf --- /dev/null +++ b/tools/server/server-queue.h @@ -0,0 +1,110 @@ +#pragma once + +#include "server-task.h" + +#include +#include +#include +#include + +struct server_queue { +private: + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + +public: + // Add a new task to the end of the queue + int post(server_task && task, bool front = false); + + // multi-task version of post() + int post(std::vector && tasks, bool front = false); + + // Add a new task, but defer until one slot is available + void defer(server_task && task); + + // Get the next id for creating a new task + int get_new_id(); + + // Register function to process a new task + void on_new_task(std::function callback); + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback); + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task(); + + // end the start_loop routine + void terminate(); + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop(); + + // for metrics + size_t queue_tasks_deferred_size() { + std::unique_lock lock(mutex_tasks); + return queue_tasks_deferred.size(); + } + +private: + void cleanup_pending_task(int id_target); +}; + +struct server_response { +private: + bool running = true; + + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + +public: + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task); + + void add_waiting_tasks(const std::vector & tasks); + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task); + + // remove multiple tasks from waiting list + void remove_waiting_task_ids(const std::unordered_set & id_tasks); + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks); + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout); + + // single-task version of recv() + server_task_result_ptr recv(int id_task); + + // Send a new result to a waiting id_task + void send(server_task_result_ptr && result); + + // terminate the waiting loop + void terminate(); +}; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp new file mode 100644 index 0000000000000..bc4436ba65ba4 --- /dev/null +++ b/tools/server/server-task.cpp @@ -0,0 +1,1192 @@ +#include "server-common.h" +#include "server-task.h" + +#include "common.h" +#include "llama.h" +#include "chat.h" +#include "sampling.h" +#include "json-schema-to-grammar.h" + +using json = nlohmann::ordered_json; + +// +// task_params +// + +json task_params::format_logit_bias(const std::vector & logit_bias) const { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +json task_params::to_json(bool only_metrics) const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + if (only_metrics) { + return json { + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } + + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); + } + + return json { + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; +} + +// +// server_task +// + +task_params server_task::params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + task_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) + task_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.antiprompt = params_base.antiprompt; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + auto stream_opt = json_value(data, "stream_options", json::object()); + params.include_usage = json_value(stream_opt, "include_usage", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.return_progress = json_value(data, "return_progress", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_syntax.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); + } else { + params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; + } + common_reasoning_format reasoning_format = params_base.reasoning_format; + if (data.contains("reasoning_format")) { + reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); + } + params.oaicompat_chat_syntax.reasoning_format = reasoning_format; + params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } else { + throw std::runtime_error("Unknown grammar trigger type"); + } + params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } else if (logit_bias != data.end() && logit_bias->is_object()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : logit_bias->items()) { + float bias; + const auto & key = el.key(); + const auto & value = el.value(); + if (value.is_number()) { + bias = value.get(); + } else if (value.is_boolean() && !value.get()) { + bias = -INFINITY; + } else { + continue; + } + + char *end; + llama_token tok = strtol(key.c_str(), &end, 10); + if (*end == 0) { + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else { + auto toks = common_tokenize(vocab, key, false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + + params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); + if (params.sampling.ignore_eos) { + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + // set reverse prompt from cli args if not set in the request + if (params.antiprompt.empty()) { + params.antiprompt = defaults.antiprompt; + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; +} + +// +// result_timings +// + +json result_timings::to_json() const { + json base = { + {"cache_n", cache_n}, + + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; +} + +// +// result_prompt_progress +// +json result_prompt_progress::to_json() const { + return json { + {"total", total}, + {"cache", cache}, + {"processed", processed}, + {"time_ms", time_ms}, + }; +} + +static inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +// +// completion_token_output +// + +json completion_token_output::to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; +} + +json completion_token_output::probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; +} + +float completion_token_output::logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); +} + +std::vector completion_token_output::str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; +} + +// +// server_task_result_cmpl_final +// +json server_task_result_cmpl_final::to_json() { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } +} + +json server_task_result_cmpl_final::to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); +} + +json server_task_result_cmpl_final::to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", msg.to_json_oaicompat()}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto & diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + } + + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + + if (include_usage) { + // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage + // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices + deltas.push_back({ + {"choices", json::array()}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + } + + if (timings.prompt_n >= 0) { + deltas.back().push_back({"timings", timings.to_json()}); + } + + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat(); + } + + return deltas; +} + +// +// server_task_result_cmpl_partial +// +json server_task_result_cmpl_partial::to_json() { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } +} + +json server_task_result_cmpl_partial::to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; +} + +json server_task_result_cmpl_partial::to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_partial::to_json_oaicompat_chat() { + bool first = n_decoded == 1; + std::time_t t = std::time(0); + json choices; + + std::vector deltas; + auto add_delta = [&](const json & delta) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + }; + // We have to send an initial update to conform to openai behavior + if (first || is_progress) { + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + } + + if (!deltas.empty()) { + auto & last_json = deltas[deltas.size() - 1]; + GGML_ASSERT(last_json.at("choices").size() >= 1); + + if (prob_output.probs.size() > 0) { + last_json.at("choices").at(0)["logprobs"] = json { + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + last_json.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + last_json.push_back({"prompt_progress", progress.to_json()}); + } + } + + return deltas; +} + +// +// server_task_result_embd +// +json server_task_result_embd::to_json() { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); +} + +json server_task_result_embd::to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; +} + +json server_task_result_embd::to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; +} + +// +// server_task_result_rerank +// +json server_task_result_rerank::to_json() { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; +} + +// +// server_task_result_error +// +json server_task_result_error::to_json() { + json res = format_error_response(err_msg, err_type); + if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { + res["n_prompt_tokens"] = n_prompt_tokens; + res["n_ctx"] = n_ctx; + } + return res; +} + +// +// server_task_result_metrics +// +json server_task_result_metrics::to_json() { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_tokens_max", n_tokens_max }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "slots", slots_data }, + }; +} + +// +// server_task_result_slot_save_load +// +json server_task_result_slot_save_load::to_json() { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; +} + +// +// server_task_result_slot_erase +// +json server_task_result_slot_erase::to_json() { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; +} + +// +// server_task_result_apply_lora +// + +json server_task_result_apply_lora::to_json() { + return json {{ "success", true }}; +} + +// +// server_prompt_cache +// +size_t server_prompt_cache::size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; +} + +size_t server_prompt_cache::n_tokens() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; +} + +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); + + if (cur_lcp_len == (int) prompt.tokens.size()) { + SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const int len = it->tokens.get_common_prefix(prompt.tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; +} + +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + float sim_best = float(lcp_best) / tokens_new.size(); + + SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + const float sim_cur = float(lcp_cur) / tokens_new.size(); + + // don't trash large prompts + if (f_keep_cur < 0.25f) { + continue; + } + + if (f_keep_best < f_keep_cur && sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; +} + +void server_prompt_cache::update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); + + // dynamically increase the token limit if it can fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; + + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens_cur) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", + limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", + (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } +} diff --git a/tools/server/server-task.h b/tools/server/server-task.h new file mode 100644 index 0000000000000..0271caae1162b --- /dev/null +++ b/tools/server/server-task.h @@ -0,0 +1,453 @@ +#pragma once + +#include "common.h" +#include "llama.h" + +#include +#include +#include + +// TODO: prevent including the whole server-common.h as we only use server_tokens +#include "server-common.h" + +using json = nlohmann::ordered_json; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + +struct task_params { + bool stream = true; + bool include_usage = false; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + bool return_progress = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + + // Embeddings + int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) + + json format_logit_bias(const std::vector & logit_bias) const; + json to_json(bool only_metrics = false) const; +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + int id_slot = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + task_params params; + server_tokens tokens; + + server_task_type type; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task() = default; + + server_task(server_task_type type) : type(type) {} + + int32_t n_tokens() const { + return tokens.size(); + } + + static task_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data); + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t cache_n = -1; + + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const; +}; + +struct result_prompt_progress { + int32_t total = 0; + int32_t cache = 0; + int32_t processed = 0; + int64_t time_ms = 0; + + json to_json() const; +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return true; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const; + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs); + + static float logarithm(float x); + + static std::vector str_to_bytes(const std::string & str); + +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + bool include_usage; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + task_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); + + json to_json_oaicompat_chat(); + + json to_json_oaicompat_chat_stream(); +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + bool is_progress = false; + completion_token_output prob_output; + result_timings timings; + result_prompt_progress progress; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); + + json to_json_oaicompat_chat(); +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override; +}; + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + // for ERROR_TYPE_EXCEED_CONTEXT_SIZE + int32_t n_prompt_tokens = 0; + int32_t n_ctx = 0; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override; +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override; +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override; +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override; +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override; +}; + +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const; + + size_t n_tokens() const; + + server_prompt * alloc(const server_prompt & prompt, size_t state_size); + + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); + + void update(); +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3750c8fdb6065..0f39def3794d6 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,25 +1,21 @@ -#include "chat.h" -#include "utils.hpp" +#include "server-common.h" #include "server-http.h" +#include "server-task.h" +#include "server-queue.h" #include "arg.h" #include "common.h" -#include "json-schema-to-grammar.h" #include "llama.h" #include "log.h" #include "sampling.h" #include "speculative.h" #include "mtmd.h" +#include "mtmd-helper.h" #include -#include -#include #include #include -#include #include -#include -#include #include #include #include @@ -37,1589 +33,39 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -enum stop_type { - STOP_TYPE_NONE, - STOP_TYPE_EOS, - STOP_TYPE_WORD, - STOP_TYPE_LIMIT, -}; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, - SERVER_TASK_TYPE_EMBEDDING, - SERVER_TASK_TYPE_RERANK, - SERVER_TASK_TYPE_INFILL, - SERVER_TASK_TYPE_CANCEL, - SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS, - SERVER_TASK_TYPE_SLOT_SAVE, - SERVER_TASK_TYPE_SLOT_RESTORE, - SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, -}; - -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, -}; - -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error - ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct slot_params { - bool stream = true; - bool include_usage = false; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - bool return_tokens = false; - bool return_progress = false; - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters - - int64_t t_max_prompt_ms = -1; // TODO: implement - int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - - std::vector lora; - - std::vector antiprompt; - std::vector response_fields; - bool timings_per_token = false; - bool post_sampling_probs = false; - - struct common_params_sampling sampling; - struct common_params_speculative speculative; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; - - // Embeddings - int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - json to_json(bool only_metrics = false) const { - std::vector samplers; - samplers.reserve(sampling.samplers.size()); - for (const auto & sampler : sampling.samplers) { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - json lora = json::array(); - for (size_t i = 0; i < this->lora.size(); ++i) { - lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); - } - - if (only_metrics) { - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? - {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } - - auto grammar_triggers = json::array(); - for (const auto & trigger : sampling.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"dry_sequence_breakers", sampling.dry_sequence_breakers}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"stop", antiprompt}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? - {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"logit_bias", format_logit_bias(sampling.logit_bias)}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"grammar", sampling.grammar}, - {"grammar_lazy", sampling.grammar_lazy}, - {"grammar_triggers", grammar_triggers}, - {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } -}; - -struct server_task { - int id = -1; // to be filled by server_queue - int index = -1; // used when there are multiple prompts (batch request) - - // used by SERVER_TASK_TYPE_CANCEL - int id_target = -1; - int id_slot = -1; - - // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - server_tokens tokens; - - server_task_type type; - - // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE - struct slot_action { - int slot_id; - std::string filename; - std::string filepath; - }; - slot_action slot_action; - - // used by SERVER_TASK_TYPE_METRICS - bool metrics_reset_bucket = false; - - // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; - - server_task() = default; - - server_task(server_task_type type) : type(type) {} - - int32_t n_tokens() const { - return tokens.size(); - } - - static slot_params params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - slot_params params; - - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - slot_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.antiprompt = params_base.antiprompt; - - // enabling this will output extra debug information in the HTTP responses from the server - params.verbose = params_base.verbosity > 9; - params.timings_per_token = json_value(data, "timings_per_token", false); - - params.stream = json_value(data, "stream", false); - auto stream_opt = json_value(data, "stream_options", json::object()); - params.include_usage = json_value(stream_opt, "include_usage", false); - params.cache_prompt = json_value(data, "cache_prompt", true); - params.return_tokens = json_value(data, "return_tokens", false); - params.return_progress = json_value(data, "return_progress", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - - params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); - params.speculative.n_min = std::max(params.speculative.n_min, 0); - params.speculative.n_max = std::max(params.speculative.n_max, 0); - - // Use OpenAI API logprobs only if n_probs wasn't provided - if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ - params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); - } - - if (data.contains("lora")) { - if (data.at("lora").is_array()) { - params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); - } else { - throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); - } - } else { - params.lora = params_base.lora_adapters; - } - - // TODO: add more sanity checks for the input parameters - - if (params.sampling.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (params.sampling.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (params.sampling.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_base < 1.0f) { - params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (params.sampling.dry_sequence_breakers.empty()) { - throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - params.sampling.grammar = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); - } catch (const std::exception & e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); - params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); - SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); - } - - { - auto it = data.find("chat_format"); - if (it != data.end()) { - params.oaicompat_chat_syntax.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); - } else { - params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; - } - common_reasoning_format reasoning_format = params_base.reasoning_format; - if (data.contains("reasoning_format")) { - reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); - } - params.oaicompat_chat_syntax.reasoning_format = reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); - } - - { - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto & t : *preserved_tokens) { - auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Preserved token: %d\n", ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - } else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value.value; - auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - params.sampling.grammar_triggers.push_back(std::move(trigger)); - } else { - SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); - } - } else { - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } else { - throw std::runtime_error("Unknown grammar trigger type"); - } - params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - { - params.sampling.logit_bias.clear(); - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } else if (logit_bias != data.end() && logit_bias->is_object()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : logit_bias->items()) { - float bias; - const auto & key = el.key(); - const auto & value = el.value(); - if (value.is_number()) { - bias = value.get(); - } else if (value.is_boolean() && !value.get()) { - bias = -INFINITY; - } else { - continue; - } - - char *end; - llama_token tok = strtol(key.c_str(), &end, 10); - if (*end == 0) { - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else { - auto toks = common_tokenize(vocab, key, false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - - params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); - if (params.sampling.ignore_eos) { - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); - } - } - - { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } - } - // set reverse prompt from cli args if not set in the request - if (params.antiprompt.empty()) { - params.antiprompt = defaults.antiprompt; - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers, false); - } else if (samplers->is_string()){ - params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); - } - } else { - params.sampling.samplers = defaults.sampling.samplers; - } - } - - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - - return params; - } - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; - -struct result_timings { - int32_t cache_n = -1; - - int32_t prompt_n = -1; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n = -1; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; - - // Optional speculative metrics - only included when > 0 - int32_t draft_n = 0; - int32_t draft_n_accepted = 0; - - json to_json() const { - json base = { - {"cache_n", cache_n}, - - {"prompt_n", prompt_n}, - {"prompt_ms", prompt_ms}, - {"prompt_per_token_ms", prompt_per_token_ms}, - {"prompt_per_second", prompt_per_second}, - - {"predicted_n", predicted_n}, - {"predicted_ms", predicted_ms}, - {"predicted_per_token_ms", predicted_per_token_ms}, - {"predicted_per_second", predicted_per_second}, - }; - - if (draft_n > 0) { - base["draft_n"] = draft_n; - base["draft_n_accepted"] = draft_n_accepted; - } - - return base; - } -}; - -struct result_prompt_progress { - int32_t total = 0; - int32_t cache = 0; - int32_t processed = 0; - int64_t time_ms = 0; - - json to_json() const { - return json { - {"total", total}, - {"cache", cache}, - {"processed", processed}, - {"time_ms", time_ms}, - }; - } -}; - -struct server_task_result { - int id = -1; - int id_slot = -1; - virtual bool is_error() { - // only used by server_task_result_error - return false; - } - virtual bool is_stop() { - // only used by server_task_result_cmpl_* - return true; - } - virtual int get_index() { - return -1; - } - virtual json to_json() = 0; - virtual ~server_task_result() = default; -}; - -// using shared_ptr for polymorphism of server_task_result -using server_task_result_ptr = std::unique_ptr; - -static inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - -struct completion_token_output { - llama_token tok; - float prob; - std::string text_to_send; - struct prob_info { - llama_token tok; - std::string txt; - float prob; - }; - std::vector probs; - - json to_json(bool post_sampling_probs) const { - json probs_for_token = json::array(); - for (const auto & p : probs) { - std::string txt(p.txt); - txt.resize(validate_utf8(txt)); - probs_for_token.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.txt)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - }); - } - return probs_for_token; - } - - static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { - json out = json::array(); - for (const auto & p : probs) { - std::string txt(p.text_to_send); - txt.resize(validate_utf8(txt)); - out.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.text_to_send)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - { - post_sampling_probs ? "top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, - }); - } - return out; - } - - static float logarithm(float x) { - // nlohmann::json converts -inf to null, so we need to prevent that - return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); - } - - static std::vector str_to_bytes(const std::string & str) { - std::vector bytes; - for (unsigned char c : str) { - bytes.push_back(c); - } - return bytes; - } -}; - -struct server_task_result_cmpl_final : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - bool stream; - bool include_usage; - result_timings timings; - std::string prompt; - - bool truncated; - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t n_tokens_cached; - bool has_new_line; - std::string stopping_word; - stop_type stop = STOP_TYPE_NONE; - - bool post_sampling_probs; - std::vector probs_output; - std::vector response_fields; - - slot_params generation_params; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; - - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return true; // in stream mode, final responses are considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - json res = json { - {"index", index}, - {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"tokens", stream ? llama_tokens {} : tokens}, - {"id_slot", id_slot}, - {"stop", true}, - {"model", oaicompat_model}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - {"generation_settings", generation_params.to_json()}, - {"prompt", prompt}, - {"has_new_line", has_new_line}, - {"truncated", truncated}, - {"stop_type", stop_type_to_str(stop)}, - {"stopping_word", stopping_word}, - {"tokens_cached", n_tokens_cached}, - {"timings", timings.to_json()}, - }; - if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return response_fields.empty() ? res : json_get_nested_values(response_fields, res); - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (!stream && probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - json res = json { - {"choices", json::array({ - json{ - {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", finish_reason}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - std::string finish_reason = "length"; - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json choice { - {"finish_reason", finish_reason}, - {"index", 0}, - {"message", msg.to_json_oaicompat()}, - }; - - if (!stream && probs_output.size() > 0) { - choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - std::time_t t = std::time(0); - - json res = json { - {"choices", json::array({choice})}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat_stream() { - std::time_t t = std::time(0); - std::string finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json deltas = json::array(); - for (const auto & diff : oaicompat_msg_diffs) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - } - - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - - if (include_usage) { - // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage - // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices - deltas.push_back({ - {"choices", json::array()}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - } - - if (timings.prompt_n >= 0) { - deltas.back().push_back({"timings", timings.to_json()}); - } - - // extra fields for debugging purposes - if (verbose && !deltas.empty()) { - deltas.front()["__verbose"] = to_json_non_oaicompat(); - } - - return deltas; - } -}; - -struct server_task_result_cmpl_partial : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - int32_t n_decoded; - int32_t n_prompt_tokens; - - bool post_sampling_probs; - bool is_progress = false; - completion_token_output prob_output; - result_timings timings; - result_prompt_progress progress; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return false; // in stream mode, partial responses are not considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - // non-OAI-compat JSON - json res = json { - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"stop", false}, - {"id_slot", id_slot}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - }; - // populate the timings object when needed (usually for the last response or with timings_per_token enabled) - if (timings.prompt_n > 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - if (!prob_output.probs.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); - } - return res; - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (prob_output.probs.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - json res = json { - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", nullptr}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - bool first = n_decoded == 1; - std::time_t t = std::time(0); - json choices; - - std::vector deltas; - auto add_delta = [&](const json & delta) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - }; - // We have to send an initial update to conform to openai behavior - if (first || is_progress) { - add_delta({ - {"role", "assistant"}, - {"content", nullptr}, - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); - } - - if (!deltas.empty()) { - auto & last_json = deltas[deltas.size() - 1]; - GGML_ASSERT(last_json.at("choices").size() >= 1); - - if (prob_output.probs.size() > 0) { - last_json.at("choices").at(0)["logprobs"] = json { - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - - if (timings.prompt_n >= 0) { - last_json.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - last_json.push_back({"prompt_progress", progress.to_json()}); - } - } - - return deltas; - } -}; - -struct server_task_result_embd : server_task_result { - int index = 0; - std::vector> embedding; - - int32_t n_tokens; - - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? to_json_oaicompat() - : to_json_non_oaicompat(); - } - - json to_json_non_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding}, - }; - } - - json to_json_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding[0]}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -struct server_task_result_rerank : server_task_result { - int index = 0; - float score = -1e6; - - int32_t n_tokens; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return json { - {"index", index}, - {"score", score}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - case ERROR_TYPE_EXCEED_CONTEXT_SIZE: - type_str = "exceed_context_size_error"; - code = 400; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -} - -struct server_task_result_error : server_task_result { - int index = 0; - error_type err_type = ERROR_TYPE_SERVER; - std::string err_msg; - - // for ERROR_TYPE_EXCEED_CONTEXT_SIZE - int32_t n_prompt_tokens = 0; - int32_t n_ctx = 0; - - virtual bool is_error() override { - return true; - } - - virtual json to_json() override { - json res = format_error_response(err_msg, err_type); - if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - res["n_prompt_tokens"] = n_prompt_tokens; - res["n_ctx"] = n_ctx; - } - return res; - } -}; - -struct server_task_result_metrics : server_task_result { - int n_idle_slots; - int n_processing_slots; - int n_tasks_deferred; - int64_t t_start; - - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - // while we can also use std::vector this requires copying the slot object which can be quite messy - // therefore, we use json to temporarily store the slot.to_json() result - json slots_data = json::array(); - - virtual json to_json() override { - return json { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", n_tasks_deferred }, - { "t_start", t_start }, - - { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, - { "t_tokens_generation_total", t_tokens_generation_total }, - { "n_tokens_predicted_total", n_tokens_predicted_total }, - { "t_prompt_processing_total", t_prompt_processing_total }, - - { "n_tokens_max", n_tokens_max }, - - { "n_prompt_tokens_processed", n_prompt_tokens_processed }, - { "t_prompt_processing", t_prompt_processing }, - { "n_tokens_predicted", n_tokens_predicted }, - { "t_tokens_generation", t_tokens_generation }, - - { "n_decode_total", n_decode_total }, - { "n_busy_slots_total", n_busy_slots_total }, - - { "slots", slots_data }, - }; - } -}; - -struct server_task_result_slot_save_load : server_task_result { - std::string filename; - bool is_save; // true = save, false = load - - size_t n_tokens; - size_t n_bytes; - double t_ms; - - virtual json to_json() override { - if (is_save) { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", n_tokens }, - { "n_written", n_bytes }, - { "timings", { - { "save_ms", t_ms } - }}, - }; - } - - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; - } -}; - -struct server_task_result_slot_erase : server_task_result { - size_t n_erased; - - virtual json to_json() override { - return json { - { "id_slot", id_slot }, - { "n_erased", n_erased }, - }; - } -}; - -struct server_task_result_apply_lora : server_task_result { - virtual json to_json() override { - return json {{ "success", true }}; - } -}; - -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; - - size_t size() const { - return data.size(); - } +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, }; -struct server_prompt { - server_tokens tokens; - - std::vector data; - - std::list checkpoints; - - size_t size() const { - size_t res = data.size(); - - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); - } - - return res; - } - - int n_tokens() const { - return tokens.size(); - } +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded }; -struct server_prompt_cache { - server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { - this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); - this->limit_tokens = limit_tokens; - } - - std::list states; - - // in bytes, 0 = no limit - size_t limit_size = 0; - - // in tokens, 0 = no limit - size_t limit_tokens = 0; - - size_t size() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.size(); - } - - return res; - } - - size_t n_tokens() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.n_tokens(); - } - - return res; - } - - server_prompt * alloc(const server_prompt & prompt, size_t state_size) { - // first check if the current state is contained fully in the cache - for (auto it = states.begin(); it != states.end(); ++it) { - const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); - - if (cur_lcp_len == (int) prompt.tokens.size()) { - SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); - return nullptr; - } - } - - // next, remove any cached prompts that are fully contained in the current prompt - for (auto it = states.begin(); it != states.end();) { - const int len = it->tokens.get_common_prefix(prompt.tokens); - - if (len == (int) it->tokens.size()) { - SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); - - it = states.erase(it); - } else { - ++it; - } - } - - std::vector state_data; - - // check if we can allocate enough memory for the new state - try { - state_data.resize(state_size); - } catch (const std::bad_alloc & e) { - SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); - - limit_size = std::max(1, 0.4*size()); - - SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); - - update(); - - return nullptr; - } - - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround - auto & cur = states.emplace_back(); - cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.data =*/ std::move(state_data), - /*.checkpoints =*/ prompt.checkpoints, - }; - - return &cur; - } - - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { - const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); - - float f_keep_best = float(lcp_best) / prompt.tokens.size(); - float sim_best = float(lcp_best) / tokens_new.size(); - - SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - auto it_best = states.end(); - - // find the most similar cached prompt, that would also preserve the most context - for (auto it = states.begin(); it != states.end(); ++it) { - const int lcp_cur = it->tokens.get_common_prefix(tokens_new); - - const float f_keep_cur = float(lcp_cur) / it->tokens.size(); - const float sim_cur = float(lcp_cur) / tokens_new.size(); - - // don't trash large prompts - if (f_keep_cur < 0.25f) { - continue; - } - - if (f_keep_best < f_keep_cur && sim_best < sim_cur) { - f_keep_best = f_keep_cur; - sim_best = sim_cur; - - it_best = it; - } - } - - if (it_best != states.end()) { - SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); - - return false; - } - - it_best->data.clear(); - it_best->data.shrink_to_fit(); - - prompt = std::move(*it_best); - - states.erase(it_best); - } - - return true; +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; } +} - void update() { - if (limit_size > 0) { - // always keep at least one state, regardless of the limits - while (states.size() > 1 && size() > limit_size) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - // average size per token - const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); - - // dynamically increase the token limit if it can fit in the memory limit - const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; - - if (limit_tokens > 0) { - while (states.size() > 1 && n_tokens() > limit_tokens_cur) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", - limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", - states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); - - for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", - (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); - } +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; } -}; +} struct server_slot { int id; @@ -2022,303 +468,6 @@ struct server_metrics { } }; -struct server_queue { - int id = 0; - bool running; - - // queues - std::deque queue_tasks; - std::deque queue_tasks_deferred; - - std::mutex mutex_tasks; - std::condition_variable condition_tasks; - - // callback functions - std::function callback_new_task; - std::function callback_update_slots; - - // Add a new task to the end of the queue - int post(server_task && task, bool front = false) { - std::unique_lock lock(mutex_tasks); - GGML_ASSERT(task.id != -1); - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - const int task_id = task.id; - QUE_DBG("new task, id = %d, front = %d\n", task_id, front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - condition_tasks.notify_one(); - return task_id; - } - - // multi-task version of post() - int post(std::vector && tasks, bool front = false) { - std::unique_lock lock(mutex_tasks); - for (auto & task : tasks) { - if (task.id == -1) { - task.id = id++; - } - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - } - condition_tasks.notify_one(); - return 0; - } - - // Add a new task, but defer until one slot is available - void defer(server_task && task) { - std::unique_lock lock(mutex_tasks); - QUE_DBG("defer task, id = %d\n", task.id); - queue_tasks_deferred.push_back(std::move(task)); - condition_tasks.notify_one(); - } - - // Get the next id for creating a new task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - return new_id; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = std::move(callback); - } - - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) { - callback_update_slots = std::move(callback); - } - - // Call when the state of one slot is changed, it will move one task from deferred to main queue - void pop_deferred_task() { - std::unique_lock lock(mutex_tasks); - if (!queue_tasks_deferred.empty()) { - queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); - queue_tasks_deferred.pop_front(); - } - condition_tasks.notify_one(); - } - - // end the start_loop routine - void terminate() { - std::unique_lock lock(mutex_tasks); - running = false; - condition_tasks.notify_all(); - } - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. maybe copy data into slot) - * - Check if multitask is finished - * - Update all slots - */ - void start_loop() { - running = true; - - while (true) { - QUE_DBG("%s", "processing new tasks\n"); - - while (true) { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - server_task task = std::move(queue_tasks.front()); - queue_tasks.pop_front(); - lock.unlock(); - - QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(std::move(task)); - } - - // all tasks in the current loop is processed, slots data is now ready - QUE_DBG("%s", "update slots\n"); - - callback_update_slots(); - - QUE_DBG("%s", "waiting for new tasks\n"); - { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } - } - -private: - void cleanup_pending_task(int id_target) { - // no need lock because this is called exclusively by post() - auto rm_func = [id_target](const server_task & task) { - return task.id == id_target; - }; - queue_tasks.erase( - std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), - queue_tasks.end()); - queue_tasks_deferred.erase( - std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), - queue_tasks_deferred.end()); - } -}; - -struct server_response { - bool running = true; - - // for keeping track of all tasks waiting for the result - std::unordered_set waiting_task_ids; - - // the main result queue (using ptr for polymorphism) - std::vector queue_results; - - std::mutex mutex_results; - std::condition_variable condition_results; - - // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(id_task); - } - - void add_waiting_tasks(const std::vector & tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & task : tasks) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); - waiting_task_ids.insert(task.id); - } - } - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); - // make sure to clean up all pending results - queue_results.erase( - std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { - return res->id == id_task; - }), - queue_results.end()); - } - - void remove_waiting_task_ids(const std::unordered_set & id_tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & id_task : id_tasks) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - waiting_task_ids.erase(id_task); - } - } - - // This function blocks the thread until there is a response for one of the id_tasks - server_task_result_ptr recv(const std::unordered_set & id_tasks) { - while (true) { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - return !queue_results.empty(); - }); - - for (size_t i = 0; i < queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // same as recv(), but have timeout in seconds - // if timeout is reached, nullptr is returned - server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { - while (true) { - std::unique_lock lock(mutex_results); - - for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - - std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - if (cr_res == std::cv_status::timeout) { - return nullptr; - } - } - - // should never reach here - } - - // single-task version of recv() - server_task_result_ptr recv(int id_task) { - std::unordered_set id_tasks = {id_task}; - return recv(id_tasks); - } - - // Send a new result to a waiting id_task - void send(server_task_result_ptr && result) { - SRV_DBG("sending result for task id = %d\n", result->id); - - std::unique_lock lock(mutex_results); - for (const auto & id_task : waiting_task_ids) { - if (result->id == id_task) { - SRV_DBG("task id = %d pushed to result queue\n", result->id); - - queue_results.emplace_back(std::move(result)); - condition_results.notify_all(); - return; - } - } - } - - // terminate the waiting loop - void terminate() { - running = false; - condition_results.notify_all(); - } -}; - struct server_context { common_params params_base; @@ -3323,7 +1472,7 @@ struct server_context { res->slots_data = std::move(slots_data); res->n_idle_slots = n_idle_slots; res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); res->t_start = metrics.t_start; res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; @@ -4645,7 +2794,7 @@ struct server_routes { json default_generation_settings_for_props; { - slot_params params; + task_params params; params.sampling = ctx_server.params_base.sampling; @@ -4784,7 +2933,7 @@ struct server_routes { std::string prompt = json_value(data, "prompt", std::string()); std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_infill( + data["prompt"] = format_prompt_infill( ctx_server.vocab, data.at("input_prefix"), data.at("input_suffix"), @@ -4939,8 +3088,7 @@ struct server_routes { } } - const json data = format_tokenizer_response(tokens_response); - res->ok(data); + res->ok(json{{"tokens", std::move(tokens_response)}}); return res; }; @@ -4951,11 +3099,10 @@ struct server_routes { std::string content; if (body.count("tokens") != 0) { const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + content = tokens_to_str(ctx_server.ctx, tokens); } - const json data = format_detokenized_response(content); - res->ok(data); + res->ok(json{{"content", std::move(content)}}); return res; }; @@ -5009,7 +3156,7 @@ struct server_routes { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; @@ -5460,10 +3607,10 @@ struct server_routes { } }; -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; +static std::function shutdown_handler; +static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; -inline void signal_handler(int signal) { +static inline void signal_handler(int signal) { if (is_terminating.test_and_set()) { // in case it hangs, we can force terminate the server by hitting Ctrl+C twice // this is for better developer experience, we can remove when the server is stable enough @@ -5632,6 +3779,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.terminate(); }; + // TODO: refactor in common/console #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler;