From 7a7de2a27e8bc3541eaedef1ecc64339309299ed Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 15 Nov 2025 23:26:31 +0100
Subject: [PATCH 1/3] fully working model management

---
 tools/server/server-http.cpp |  85 ++++++++++++++++++++
 tools/server/server-http.h   |  76 ++++++++++++++++++
 tools/server/server.cpp      | 149 ++++++++++++++++++++++++++++++++++-
 tools/server/utils.hpp       | 130 ++++++++++++++++++++++++++++++
 4 files changed, 439 insertions(+), 1 deletion(-)

diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp
index 196ced443261a..103e4a449bb89 100644
--- a/tools/server/server-http.cpp
+++ b/tools/server/server-http.cpp
@@ -383,3 +383,88 @@ void server_http_context::post(const std::string & path, server_http_context::ha
         });
 }
+
+//
+// server_http_client
+//
+
+server_http_client::server_http_client(
+        const std::string & method,
+        const std::string & host,
+        int port,
+        const std::string & path,
+        const std::map<std::string, std::string> & headers,
+        const std::string & body,
+        const std::function<bool()> should_stop) {
+    // shared between reader and writer threads
+    auto cli  = std::make_shared<httplib::Client>(host, port);
+    auto pipe = std::make_shared<pipe_t<msg_t>>();
+
+    // setup Client
+    cli->set_connection_timeout(0, 200000); // 200 milliseconds
+    this->status = 500; // to be overwritten upon response
+    this->cleanup = [pipe]() {
+        pipe->close_read();
+        pipe->close_write();
+    };
+
+    // wire up the receive end of the pipe
+    this->next = [pipe, should_stop](std::string & out) -> bool {
+        msg_t msg;
+        bool has_next = pipe->read(msg, should_stop);
+        if (!msg.data.empty()) {
+            out = std::move(msg.data);
+        }
+        return has_next;
+    };
+
+    // wire up the HTTP client
+    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
+    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
+        msg_t msg;
+        msg.status = response.status;
+        for (const auto & [key, value] : response.headers) {
+            msg.headers[key] = value;
+        }
+        pipe->write(std::move(msg)); // send headers first
+        return true;
+    };
+    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
+        return pipe->write({{}, 0, std::string(data, data_length)}); // send data chunks
+    };
+
+    // prepare the request to destination server
+    httplib::Request req;
+    {
+        req.method = method;
+        req.path   = path;
+        for (const auto & [key, value] : headers) {
+            req.set_header(key, value);
+        }
+        req.body             = body;
+        req.response_handler = response_handler;
+        req.content_receiver = content_receiver;
+    }
+
+    // start the proxy thread
+    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
+    this->thread = std::thread([cli, pipe, req]() {
+        auto result = cli->send(std::move(req));
+        if (result.error() != httplib::Error::Success) {
+            auto err_str = httplib::to_string(result.error());
+            SRV_ERR("http client error: %s\n", err_str.c_str());
+            pipe->write({{}, 500, ""}); // header
+            pipe->write({{}, 0, "proxy error: " + err_str}); // body
+        }
+        pipe->close_write(); // signal EOF to reader
+        SRV_DBG("%s", "client request thread ended\n");
+    });
+    this->thread.detach();
+
+    // wait for the first chunk (headers)
+    msg_t header;
+    pipe->read(header, should_stop);
+    SRV_DBG("%s", "received response headers\n");
+    this->status  = header.status;
+    this->headers = header.headers;
+}
diff --git a/tools/server/server-http.h b/tools/server/server-http.h
index dc6ca92fd8751..977b30d3c70b7 100644
--- a/tools/server/server-http.h
+++ b/tools/server/server-http.h
@@ -75,3 +75,79 @@ struct server_http_context {
     // for debugging
     std::string listening_address;
 };
+
+
+
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
+
+struct server_http_client : server_http_res {
+    std::function<void()> cleanup = nullptr;
+public:
+    server_http_client(const std::string & method,
+                       const std::string & host,
+                       int port,
+                       const std::string & path,
+                       const std::map<std::string, std::string> & headers,
+                       const std::string & body,
+                       const std::function<bool()> should_stop);
+    ~server_http_client() {
+        if (cleanup) {
+            cleanup();
+        }
+    }
+private:
+    std::thread thread;
+    struct msg_t {
+        std::map<std::string, std::string> headers;
+        int status = 0;
+        std::string data;
+    };
+    // simple implementation of a pipe
+    template<typename T>
+    struct pipe_t {
+        std::mutex mutex;
+        std::condition_variable cv;
+        std::queue<T> queue;
+        std::atomic<bool> writer_closed{false};
+        std::atomic<bool> reader_closed{false};
+        void close_write() {
+            writer_closed.store(true);
+            cv.notify_all();
+        }
+        void close_read() {
+            reader_closed.store(true);
+            cv.notify_all();
+        }
+        bool read(T & output, const std::function<bool()> & should_stop) {
+            std::unique_lock<std::mutex> lk(mutex);
+            constexpr auto poll_interval = std::chrono::milliseconds(500);
+            while (true) {
+                if (!queue.empty()) {
+                    output = std::move(queue.front());
+                    queue.pop();
+                    return true;
+                }
+                if (writer_closed.load()) {
+                    return false; // clean EOF
+                }
+                if (should_stop()) {
+                    close_read(); // signal broken pipe to writer
+                    return false; // cancelled / reader no longer alive
+                }
+                cv.wait_for(lk, poll_interval);
+            }
+        }
+        bool write(T && data) {
+            std::lock_guard<std::mutex> lk(mutex);
+            if (reader_closed.load()) {
+                return false; // broken pipe
+            }
+            queue.push(std::move(data));
+            cv.notify_one();
+            return true;
+        }
+    };
+};
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 1c9e9a58d7daf..186b1df312c1d 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5108,6 +5108,106 @@ struct server_routes {
         return res;
     };
+
+    //
+    // router server
+    //
+    char ** envp;
+    std::map<std::string, server_spawn_instance> map_model_to_port;
+    void maybe_load_it_why_not(std::string & custom_model) {
+        // HACKYYYY, but for demo purpose; we load the model if it's in the cached list
+        if (map_model_to_port.find(custom_model) != map_model_to_port.end()) {
+            return; // already loaded, do nothing
+        }
+        auto models = common_list_cached_models();
+        for (const auto & model : models) {
+            auto m = model.to_string();
+            if (m == custom_model) {
+                server_router_create_instance(envp, map_model_to_port, m);
+                std::this_thread::sleep_for(std::chrono::seconds(5)); // hacky wait for the process to be ready
+                return; // nice
+            }
+        }
+    }
+    std::string get_one_if_has_only_one(std::string & custom_model) {
+        // HACKYYYY, but for demo purpose; we get the only model if there's only one
+        if (map_model_to_port.size() == 1) {
+            return map_model_to_port.begin()->first;
+        }
+        return custom_model;
+    }
+    server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
+        std::string method = "GET";
+        std::string model = req.get_param("model");
+        maybe_load_it_why_not(model);
+        model = get_one_if_has_only_one(model);
+        return handle_proxy(req, method, model);
+    };
+    server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
+        std::string method = "POST";
+        json body = json::parse(req.body);
+        std::string model = json_value(body, "model", std::string());
+        maybe_load_it_why_not(model);
+        model = get_one_if_has_only_one(model);
+        return handle_proxy(req, method, model);
+    };
+    server_http_res_ptr handle_proxy(const server_http_req & req, std::string & method, std::string model) {
+        if (map_model_to_port.find(model) == map_model_to_port.end()) {
+            auto res = std::make_unique<server_res_generator>(ctx_server);
+            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
+            return server_http_res_ptr(std::move(res));
+        }
+        server_http_res_ptr res(new server_http_client(
+            method, params.hostname, map_model_to_port[model].port,
+            req.path, req.headers, req.body, req.should_stop
+        ));
+        return res;
+    }
+    server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        json body = json::parse(req.body);
+        std::string model = json_value(body, "model", std::string());
+        int status = server_router_create_instance(envp, map_model_to_port, model);
+        if (status != 0) {
+            res->error(format_error_response("fail to start the process", ERROR_TYPE_SERVER));
+            return res;
+        }
+        res->ok({{"success", true}});
+        return res;
+    };
+    server_http_context::handler_t get_router_models = [this](const server_http_req &) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        json models_json = json::array();
+        auto models = common_list_cached_models();
+        for (const auto & model : models) {
+            auto model_name = model.to_string();
+            bool loaded = map_model_to_port.find(model.to_string()) != map_model_to_port.end(); // TODO: thread safety
+            models_json.push_back(json {
+                {"model", model_name},
+                {"name", model_name},
+                {"id", model_name},
+                // TODO: other fields...
+                {"status", {
+                    {"value", loaded ? "loaded" : "unloaded"}
+                }},
+            });
+        }
+        res->ok({{"data", models_json}});
+        return res;
+    };
+    server_http_context::handler_t post_router_models_unload = [this](const server_http_req & req) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        json body = json::parse(req.body);
+        std::string model = json_value(body, "model", std::string());
+        model = get_one_if_has_only_one(model);
+        if (map_model_to_port.find(model) == map_model_to_port.end()) {
+            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        server_router_kill_single(map_model_to_port, model);
+        res->ok({{"success", true}});
+        return res;
+    };
+
 private:
     std::unique_ptr<server_res_generator> handle_completions_impl(
         server_task_type type,
@@ -5501,7 +5601,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
     };
 }
 
-int main(int argc, char ** argv) {
+int main(int argc, char ** argv, char ** envp) {
     // own arguments required by this example
     common_params params;
 
@@ -5549,6 +5649,34 @@ int main(int argc, char ** argv) {
     // register API routes
     server_routes routes(params, ctx_server, ctx_http);
 
+    // hacky, replace handlers with proxy handlers if this is a router server
+    bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
+    if (is_router_server) {
+        routes.envp = envp;
+        routes.get_props = routes.proxy_get;
+        routes.post_props = routes.proxy_post;
+        // routes.get_models = routes.proxy_get;
+        routes.post_completions = routes.proxy_post;
+        routes.post_completions_oai = routes.proxy_post;
+        routes.post_chat_completions = routes.proxy_post;
+        routes.post_infill = routes.proxy_post;
+        routes.post_embeddings = routes.proxy_post;
+        routes.post_embeddings_oai = routes.proxy_post;
+        routes.post_rerank = routes.proxy_post;
+        routes.post_tokenize = routes.proxy_post;
+        routes.post_detokenize = routes.proxy_post;
+        routes.post_apply_template = routes.proxy_post;
+        routes.get_lora_adapters = routes.proxy_get;
+        routes.post_lora_adapters = routes.proxy_post;
+        routes.get_slots = routes.proxy_get;
+        routes.post_slots = routes.proxy_post;
+
+        // custom routes for router
+        routes.get_models = routes.get_router_models;
+        ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
+        ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
+    }
+
     ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
     ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
     ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
@@ -5594,6 +5722,8 @@ int main(int argc, char ** argv) {
         llama_backend_free();
     };
 
+if (!is_router_server) { // HACKY
+
     // start the HTTP server before loading the model to be able to serve /health requests
     if (!ctx_http.start()) {
         clean_up();
@@ -5631,6 +5761,8 @@ int main(int argc, char ** argv) {
         ctx_server.queue_tasks.terminate();
     };
 
+} // end of !is_router_server
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;
     sigint_action.sa_handler = signal_handler;
@@ -5645,6 +5777,8 @@ int main(int argc, char ** argv) {
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
 
+if (!is_router_server) { // HACKY
+
     LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
     LOG_INF("%s: starting the main loop...\n", __func__);
     // this call blocks the main thread until queue_tasks.terminate() is called
@@ -5655,6 +5789,19 @@ int main(int argc, char ** argv) {
         ctx_http.thread.join();
     }
     llama_memory_breakdown_print(ctx_server.ctx);
+} else {
+    shutdown_handler = [&](int) {
+        ctx_http.stop();
+    };
+    if (!ctx_http.start()) {
+        LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
+        return 1;
+    }
+    ctx_http.is_ready.store(true);
+    ctx_http.thread.join(); // keep the main thread alive
+    // kill_all_instances(routes.map_model_to_port); // why this also kill the main instance?
+    LOG_INF("%s: server stopped\n", __func__);
+} // end of !is_router_server
 
     return 0;
 }
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index bf21726051e55..f8f4c267e0a42 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1555,3 +1555,133 @@ static server_tokens format_rerank(const struct llama_model * model, const struc
 
     return result;
 }
+
+
+
+//
+// router server utils
+//
+
+#include "utils.hpp"
+#include "download.h"
+
+#include <filesystem>
+#include <spawn.h>
+#include <sys/wait.h>
+#include <signal.h> // for kill()
+
+#if defined(__APPLE__) && defined(__MACH__)
+// macOS: use _NSGetExecutablePath to get the executable path
+#include <mach-o/dyld.h>
+#include <limits.h>
+#endif
+
+inline std::filesystem::path server_router_get_server_exec_path() {
+#if defined(_MSC_VER)
+    wchar_t path[FILENAME_MAX] = { 0 };
+    GetModuleFileNameW(nullptr, path, FILENAME_MAX);
+    return std::filesystem::path(path);
+#elif defined(__APPLE__) && defined(__MACH__)
+    char small_path[PATH_MAX];
+    uint32_t size = sizeof(small_path);
+
+    if (_NSGetExecutablePath(small_path, &size) == 0) {
+        // resolve any symlinks to get absolute path
+        try {
+            return std::filesystem::canonical(std::filesystem::path(small_path));
+        } catch (...) {
+            return std::filesystem::path(small_path);
+        }
+    } else {
+        // buffer was too small, allocate required size and call again
+        std::vector<char> buf(size);
+        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
+            try {
+                return std::filesystem::canonical(std::filesystem::path(buf.data()));
+            } catch (...) {
+                return std::filesystem::path(buf.data());
+            }
+        }
+        return std::filesystem::path(std::string(buf.data(), (size > 0) ? size : 0));
+    }
+#else
+    char path[FILENAME_MAX];
+    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+    return std::filesystem::path(std::string(path, (count > 0) ? count : 0));
+#endif
+}
+
+struct server_spawn_instance {
+    pid_t pid = 0;
+    int port = 0;
+    std::thread th;
+};
+
+inline int server_router_create_instance(char ** envp, std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model) {
+    server_spawn_instance inst;
+    inst.port = rand() % 10000 + 20000; // random port between 20000 and 29999
+
+    if (mapping.find(hf_model) != mapping.end()) {
+        throw std::runtime_error("model already loaded");
+    }
+
+    pid_t pid = 0;
+    {
+        // Prepare arguments (pass original or custom ones) using mutable storage for argv
+        std::string path = server_router_get_server_exec_path().string();
+
+        SRV_INF("spawning instance %s with hf=%s on port %d\n", path.c_str(), hf_model.c_str(), inst.port);
+        std::vector<std::string> arg_strs;
+        arg_strs.push_back(path);
+        arg_strs.push_back("-hf");
+        arg_strs.push_back(hf_model);
+        arg_strs.push_back("--port");
+        arg_strs.push_back(std::to_string(inst.port));
+
+        std::vector<char *> child_argv;
+        child_argv.reserve(arg_strs.size() + 1);
+        for (auto &s : arg_strs) {
+            child_argv.push_back(const_cast<char *>(s.c_str()));
+        }
+        child_argv.push_back(nullptr);
+
+        if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), envp) != 0) {
+            perror("posix_spawn");
+            exit(1); // for testing only
+        } else {
+            inst.pid = pid;
+            SRV_INF("spawned instance with pid %d\n", pid);
+        }
+    }
+
+    inst.th = std::thread([hf_model, pid, &mapping]() {
+        int status = 0;
+        waitpid(pid, &status, 0);
+        SRV_INF("instance with pid %d exited with status %d\n", pid, status);
+        mapping.erase(hf_model); // TODO: thread safety
+    });
+    if (inst.th.joinable()) {
+        inst.th.detach();
+    }
+
+    mapping[hf_model] = std::move(inst); // TODO: thread safety
+    return 0;
+}
+
+inline void kill_all_instances(std::map<std::string, server_spawn_instance> & mapping) {
+    for (auto & [hf_model, inst] : mapping) {
+        LOG_INF("killing instance with hf=%s on port %d (pid %d)\n", hf_model.c_str(), inst.port, inst.pid);
+        kill(inst.pid, SIGINT);
+    }
+    mapping.clear();
+}
+
+inline void server_router_kill_single(std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model) {
+    auto it = mapping.find(hf_model);
+    if (it != mapping.end()) {
+        auto & inst = it->second;
+        LOG_INF("killing instance with hf=%s on port %d (pid %d)\n", hf_model.c_str(), inst.port, inst.pid);
+        kill(inst.pid, SIGINT);
+        mapping.erase(it);
+    }
+}

From 2a200683b0e9ed7c28dbe893f35536a1d402eecb Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 16 Nov 2025 11:58:11 +0100
Subject: [PATCH 2/3] improve maybe_load_it_why_not

---
 tools/server/server.cpp | 65 ++++++++++++++++++++++++++++++-----------
 tools/server/utils.hpp  | 22 ++++++++++++--
 2 files changed, 68 insertions(+), 19 deletions(-)

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 186b1df312c1d..a0f0191253415 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5118,28 +5118,32 @@ struct server_routes {
         if (map_model_to_port.find(custom_model) != map_model_to_port.end()) {
             return; // already loaded, do nothing
         }
+        // TODO: maybe unload least recently used model if too many models are loaded?
+        auto wait_until_loaded = [this, custom_model]() {
+            while (true) {
+                bool load_failed = map_model_to_port.find(custom_model) == map_model_to_port.end(); // model is deleted
+                bool is_loaded = !load_failed && map_model_to_port[custom_model].status == "loaded";
+                if (is_loaded || load_failed) {
+                    return;
+                }
+                std::this_thread::sleep_for(std::chrono::milliseconds(500));
+            }
+        };
         auto models = common_list_cached_models();
         for (const auto & model : models) {
             auto m = model.to_string();
             if (m == custom_model) {
-                server_router_create_instance(envp, map_model_to_port, m);
-                std::this_thread::sleep_for(std::chrono::seconds(5)); // hacky wait for the process to be ready
-                return; // nice
+                server_router_create_instance(envp, map_model_to_port, m, params.port);
+                wait_until_loaded();
+                SRV_INF("model %s loaded on-demand\n", custom_model.c_str());
+                return;
             }
         }
     }
-    std::string get_one_if_has_only_one(std::string & custom_model) {
-        // HACKYYYY, but for demo purpose; we get the only model if there's only one
-        if (map_model_to_port.size() == 1) {
-            return map_model_to_port.begin()->first;
-        }
-        return custom_model;
-    }
     server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
         std::string method = "GET";
         std::string model = req.get_param("model");
         maybe_load_it_why_not(model);
-        model = get_one_if_has_only_one(model);
         return handle_proxy(req, method, model);
     };
     server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
@@ -5147,7 +5151,6 @@ struct server_routes {
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
         maybe_load_it_why_not(model);
-        model = get_one_if_has_only_one(model);
         return handle_proxy(req, method, model);
     };
     server_http_res_ptr handle_proxy(const server_http_req & req, std::string & method, std::string model) {
@@ -5166,7 +5169,7 @@ struct server_routes {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
-        int status = server_router_create_instance(envp, map_model_to_port, model);
+        int status = server_router_create_instance(envp, map_model_to_port, model, params.port);
         if (status != 0) {
             res->error(format_error_response("fail to start the process", ERROR_TYPE_SERVER));
             return res;
@@ -5174,20 +5177,33 @@ struct server_routes {
         res->ok({{"success", true}});
         return res;
     };
+    server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        json body = json::parse(req.body);
+        std::string model = json_value(body, "model", std::string());
+        std::string value = json_value(body, "value", std::string());
+        if (map_model_to_port.find(model) == map_model_to_port.end()) {
+            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        map_model_to_port[model].status = value;
+        res->ok({{"success", true}});
+        return res;
+    };
     server_http_context::handler_t get_router_models = [this](const server_http_req &) {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json models_json = json::array();
         auto models = common_list_cached_models();
         for (const auto & model : models) {
             auto model_name = model.to_string();
-            bool loaded = map_model_to_port.find(model.to_string()) != map_model_to_port.end(); // TODO: thread safety
+            bool found = map_model_to_port.find(model.to_string()) != map_model_to_port.end(); // TODO: thread safety
             models_json.push_back(json {
                 {"model", model_name},
                 {"name", model_name},
                 {"id", model_name},
                 // TODO: other fields...
                 {"status", {
-                    {"value", loaded ? "loaded" : "unloaded"}
+                    {"value", found ? map_model_to_port[model_name].status : "unloaded"}
                 }},
             });
         }
@@ -5198,7 +5214,6 @@ struct server_routes {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
-        model = get_one_if_has_only_one(model);
         if (map_model_to_port.find(model) == map_model_to_port.end()) {
             res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
             return res;
@@ -5673,8 +5688,9 @@ int main(int argc, char ** argv, char ** envp) {
 
         // custom routes for router
         routes.get_models = routes.get_router_models;
-        ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
+        ctx_http.post("/models/load",   ex_wrapper(routes.post_router_models_load));
         ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
+        ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status));
     }
 
     ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
@@ -5779,6 +5795,21 @@ if (!is_router_server) { // HACKY
 
 if (!is_router_server) { // HACKY
 
+    // notify to main router if needed
+    char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
+    if (router_port != nullptr) {
+        SRV_INF("%s: notifying to main router on port %s\n", __func__, router_port);
+        server_http_client notify_router(
+            "POST", params.hostname, std::atoi(router_port),
+            "/models/status",
+            { {"Content-Type", "application/json"} },
+            json {{ "model", params.model_alias }, { "value", "loaded" }}.dump(),
+            []() { return false; }
+        );
+        std::string dummy;
+        notify_router.next(dummy); // ignore the response
+    }
+
     LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
     LOG_INF("%s: starting the main loop...\n", __func__);
     // this call blocks the main thread until queue_tasks.terminate() is called
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index f8f4c267e0a42..8fb616c2ceaca 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1615,9 +1615,10 @@ struct server_spawn_instance {
     pid_t pid = 0;
     int port = 0;
     std::thread th;
+    std::string status = "loading"; // "loading", "loaded"
 };
 
-inline int server_router_create_instance(char ** envp, std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model) {
+inline int server_router_create_instance(char ** envp, std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model, int router_port) {
     server_spawn_instance inst;
     inst.port = rand() % 10000 + 20000; // random port between 20000 and 29999
 
@@ -1635,6 +1636,8 @@ inline int server_router_create_instance(char ** envp, std::map<std::string, ser
         arg_strs.push_back("-hf");
         arg_strs.push_back(hf_model);
+        arg_strs.push_back("--alias");
+        arg_strs.push_back(hf_model);
         arg_strs.push_back("--port");
         arg_strs.push_back(std::to_string(inst.port));
 
@@ -1660,7 +1663,22 @@ inline int server_router_create_instance(char ** envp, std::map<std::string, ser
         child_argv.push_back(nullptr);
 
-        if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), envp) != 0) {
+        // clone envp while adding LLAMA_SERVER_ROUTER_PORT
+        std::vector<std::string> child_envs;
+        std::vector<char *> child_envp;
+        {
+            for (char ** e = envp; *e != nullptr; ++e) {
+                child_envs.emplace_back(*e);
+            }
+            child_envs.emplace_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(router_port));
+            child_envp.reserve(child_envs.size() + 1);
+            for (auto & s : child_envs) {
+                child_envp.push_back(const_cast<char *>(s.c_str()));
+            }
+            child_envp.push_back(nullptr);
+        }
+
+        if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), child_envp.data()) != 0) {
             perror("posix_spawn");
             exit(1); // for testing only
         } else {

From bb123af07db165b8ac18a11c8a212594a4ba7cc6 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 17 Nov 2025 11:55:07 +0100
Subject: [PATCH 3/3] thread-safe

---
 tools/server/server.cpp |  91 ++++++++---------
 tools/server/utils.hpp  | 220 +++++++++++++++++++++++++++-------------
 2 files changed, 188 insertions(+), 123 deletions(-)
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a0f0191253415..f1bb23c736e77 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5111,69 +5111,54 @@ struct server_routes {
     //
     // router server
     //
-    char ** envp;
-    std::map<std::string, server_spawn_instance> map_model_to_port;
-    void maybe_load_it_why_not(std::string & custom_model) {
-        // HACKYYYY, but for demo purpose; we load the model if it's in the cached list
-        if (map_model_to_port.find(custom_model) != map_model_to_port.end()) {
-            return; // already loaded, do nothing
-        }
-        // TODO: maybe unload least recently used model if too many models are loaded?
-        auto wait_until_loaded = [this, custom_model]() {
-            while (true) {
-                bool load_failed = map_model_to_port.find(custom_model) == map_model_to_port.end(); // model is deleted
-                bool is_loaded = !load_failed && map_model_to_port[custom_model].status == "loaded";
-                if (is_loaded || load_failed) {
-                    return;
-                }
-                std::this_thread::sleep_for(std::chrono::milliseconds(500));
-            }
-        };
-        auto models = common_list_cached_models();
-        for (const auto & model : models) {
-            auto m = model.to_string();
-            if (m == custom_model) {
-                server_router_create_instance(envp, map_model_to_port, m, params.port);
-                wait_until_loaded();
-                SRV_INF("model %s loaded on-demand\n", custom_model.c_str());
-                return;
-            }
-        }
-    }
+    server_instances instances;
     server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
         std::string method = "GET";
         std::string model = req.get_param("model");
-        maybe_load_it_why_not(model);
+        if (req.path == "/props" && model.empty()) {
+            return handle_default_props(req);
+        }
+        instances.ensure_model_loaded(model);
         return handle_proxy(req, method, model);
     };
     server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
         std::string method = "POST";
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
-        maybe_load_it_why_not(model);
+        instances.ensure_model_loaded(model);
         return handle_proxy(req, method, model);
     };
     server_http_res_ptr handle_proxy(const server_http_req & req, std::string & method, std::string model) {
-        if (map_model_to_port.find(model) == map_model_to_port.end()) {
+        auto meta = instances.get_meta(model);
+        if (!meta.has_value()) {
             auto res = std::make_unique<server_res_generator>(ctx_server);
-            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
-            return server_http_res_ptr(std::move(res));
+            res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE));
+            return res;
         }
         server_http_res_ptr res(new server_http_client(
-            method, params.hostname, map_model_to_port[model].port,
+            method, params.hostname, meta->port,
             req.path, req.headers, req.body, req.should_stop
         ));
         return res;
     }
+    server_http_res_ptr handle_default_props(const server_http_req &) {
+        auto res = std::make_unique<server_res_generator>(ctx_server);
+        // this is a dummy response to make sure webui doesn't break
+        res->ok({
+            {"model_alias", "llama-server"},
+            {"model_path", "none"},
+            {"default_generation_settings", {
+                {"params", json{}},
+                {"n_ctx", 0},
+            }},
+        });
+        return res;
+    }
     server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
-        int status = server_router_create_instance(envp, map_model_to_port, model, params.port);
-        if (status != 0) {
-            res->error(format_error_response("fail to start the process", ERROR_TYPE_SERVER));
-            return res;
-        }
+        instances.create(model);
         res->ok({{"success", true}});
         return res;
     };
@@ -5182,11 +5167,12 @@ struct server_routes {
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
         std::string value = json_value(body, "value", std::string());
-        if (map_model_to_port.find(model) == map_model_to_port.end()) {
-            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
+        if (!instances.get_meta(model).has_value()) {
+            auto res = std::make_unique<server_res_generator>(ctx_server);
+            res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE));
             return res;
         }
-        map_model_to_port[model].status = value;
+        instances.update_status(model, value);
         res->ok({{"success", true}});
         return res;
     };
@@ -5196,14 +5182,15 @@ struct server_routes {
         auto models = common_list_cached_models();
         for (const auto & model : models) {
             auto model_name = model.to_string();
-            bool found = map_model_to_port.find(model.to_string()) != map_model_to_port.end(); // TODO: thread safety
+            auto meta = instances.get_meta(model_name);
+            bool found = meta.has_value();
             models_json.push_back(json {
                 {"model", model_name},
                 {"name", model_name},
                 {"id", model_name},
                 // TODO: other fields...
                 {"status", {
-                    {"value", found ? map_model_to_port[model_name].status : "unloaded"}
+                    {"value", found ? meta->status : "unloaded"}
                 }},
             });
         }
@@ -5214,11 +5201,12 @@ struct server_routes {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
         std::string model = json_value(body, "model", std::string());
-        if (map_model_to_port.find(model) == map_model_to_port.end()) {
-            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
+        if (!instances.get_meta(model).has_value()) {
+            auto res = std::make_unique<server_res_generator>(ctx_server);
+            res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE));
             return res;
         }
-        server_router_kill_single(map_model_to_port, model);
+        instances.kill_single(model);
         res->ok({{"success", true}});
         return res;
     };
@@ -5667,10 +5655,13 @@ int main(int argc, char ** argv, char ** envp) {
     // hacky, replace handlers with proxy handlers if this is a router server
     bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
     if (is_router_server) {
-        routes.envp = envp;
+        // setup server instances manager
+        routes.instances.envp = envp;
+        routes.instances.router_port = params.port;
+
+        // proxy handlers
         routes.get_props = routes.proxy_get;
         routes.post_props = routes.proxy_post;
-        // routes.get_models = routes.proxy_get;
         routes.post_completions = routes.proxy_post;
         routes.post_completions_oai = routes.proxy_post;
         routes.post_chat_completions = routes.proxy_post;
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index 8fb616c2ceaca..cd4d98c1829f6 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1567,6 +1567,8 @@ static server_tokens format_rerank(const struct llama_model * model, const struc
 
 #include <filesystem>
 #include <spawn.h>
+#include <condition_variable>
+#include <mutex>
 #include <sys/wait.h>
 #include <signal.h> // for kill()
 
@@ -1576,7 +1578,7 @@ static server_tokens format_rerank(const struct llama_model * model, const struc
 #include <mach-o/dyld.h>
 #include <limits.h>
 #endif
 
-inline std::filesystem::path server_router_get_server_exec_path() {
+static std::filesystem::path get_server_exec_path() {
 #if defined(_MSC_VER)
     wchar_t path[FILENAME_MAX] = { 0 };
     GetModuleFileNameW(nullptr, path, FILENAME_MAX);
@@ -1611,95 +1613,167 @@ inline std::filesystem::path server_router_get_server_exec_path() {
 #endif
 }
 
-struct server_spawn_instance {
-    pid_t pid = 0;
-    int port = 0;
-    std::thread th;
-    std::string status = "loading"; // "loading", "loaded"
-};
+struct server_instances {
+    struct instance_metadata_t {
+        // this struct is copyable
+        int port = 0;
+        std::string status = "loading"; // "loading", "loaded"
+    };
+    struct instance_t {
+        pid_t pid = 0;
+        std::thread th;
+        instance_metadata_t meta;
+    };
 
-inline int server_router_create_instance(char ** envp, std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model, int router_port) {
-    server_spawn_instance inst;
-    inst.port = rand() % 10000 + 20000; // random port between 20000 and 29999
+private:
+    std::mutex mutex;
+    std::condition_variable cv;
+    std::map<std::string, instance_t> mapping;
+
+    void remove(const std::string & model_name) {
+        std::unique_lock<std::mutex> lock(mutex);
+        mapping.erase(model_name);
+        cv.notify_all();
+    }
 
-    if (mapping.find(hf_model) != mapping.end()) {
-        throw std::runtime_error("model already loaded");
+    void insert(const std::string & model_name, instance_t && inst) {
+        std::unique_lock<std::mutex> lock(mutex);
+        if (mapping.find(model_name) != mapping.end()) {
+            throw std::runtime_error("instance with name=" + model_name + " already exists");
+        }
+        mapping[model_name] = std::move(inst);
+        cv.notify_all();
     }
 
-    pid_t pid = 0;
-    {
-        // Prepare arguments (pass original or custom ones) using mutable storage for argv
-        std::string path = server_router_get_server_exec_path().string();
+public:
+    char ** envp;
+    int router_port;
 
-        SRV_INF("spawning instance %s with hf=%s on port %d\n", path.c_str(), hf_model.c_str(), inst.port);
-        std::vector<std::string> arg_strs;
-        arg_strs.push_back(path);
-        arg_strs.push_back("-hf");
-        arg_strs.push_back(hf_model);
-        arg_strs.push_back("--alias");
-        arg_strs.push_back(hf_model);
-        arg_strs.push_back("--port");
-        arg_strs.push_back(std::to_string(inst.port));
+    std::optional<instance_metadata_t> get_meta(const std::string & model_name) {
+        std::unique_lock<std::mutex> lock(mutex);
+        auto it = mapping.find(model_name);
+        if (it != mapping.end()) {
+            return it->second.meta;
+        }
+        return std::nullopt;
+    }
 
-        std::vector<char *> child_argv;
-        child_argv.reserve(arg_strs.size() + 1);
-        for (auto &s : arg_strs) {
-            child_argv.push_back(const_cast<char *>(s.c_str()));
+    void update_status(const std::string & model_name, const std::string & status) {
+        std::unique_lock<std::mutex> lock(mutex);
+        auto it = mapping.find(model_name);
+        if (it != mapping.end()) {
+            it->second.meta.status = status;
+            cv.notify_all();
         }
-        child_argv.push_back(nullptr);
+    }
+
+    void wait_until_loaded(const std::string & model_name) {
+        std::unique_lock<std::mutex> lock(mutex);
+        cv.wait(lock, [&]() {
+            auto it = mapping.find(model_name);
+            // either being deleted (load failed), or loaded
+            return it == mapping.end() || it->second.meta.status == "loaded";
+        });
+    }
+
+    void ensure_model_loaded(std::string & model_name) {
+        if (get_meta(model_name).has_value()) {
+            return; // already loaded, do nothing
+        }
+        // TODO: maybe unload least recently used model if too many models are loaded?
+        auto models = common_list_cached_models();
+        for (const auto & m : models) {
+            auto name = m.to_string();
+            if (name == model_name) {
+                create(name);
+                wait_until_loaded(name);
+                SRV_INF("model %s loaded on-demand\n", name.c_str());
+                return;
+            }
+        }
+    }
+
+    void create(const std::string & model_name) {
+        instance_t inst;
+        // TODO: use a better port allocation strategy
+        inst.meta.port = rand() % 10000 + 20000; // random port between 20000 and 29999
 
-        // clone envp while adding LLAMA_SERVER_ROUTER_PORT
-        std::vector<std::string> child_envs;
-        std::vector<char *> child_envp;
+        if (get_meta(model_name).has_value()) {
+            throw std::runtime_error("instance with model_name " + model_name + " already exists");
+        }
+
+        pid_t pid = 0;
         {
-            for (char ** e = envp; *e != nullptr; ++e) {
-                child_envs.emplace_back(*e);
+            // Prepare arguments (pass original or custom ones) using mutable storage for argv
+            std::string path = get_server_exec_path().string();
+
+            SRV_INF("spawning instance %s with name=%s on port %d\n", path.c_str(), model_name.c_str(), inst.meta.port);
+            std::vector<std::string> arg_strs;
+            arg_strs.push_back(path);
+            arg_strs.push_back("-hf");
+            arg_strs.push_back(model_name);
+            arg_strs.push_back("--alias");
+            arg_strs.push_back(model_name);
+            arg_strs.push_back("--port");
+            arg_strs.push_back(std::to_string(inst.meta.port));
+
+            std::vector<char *> child_argv;
+            child_argv.reserve(arg_strs.size() + 1);
+            for (auto &s : arg_strs) {
+                child_argv.push_back(const_cast<char *>(s.c_str()));
             }
-            child_envs.emplace_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(router_port));
-            child_envp.reserve(child_envs.size() + 1);
-            for (auto & s : child_envs) {
-                child_envp.push_back(const_cast<char *>(s.c_str()));
+            child_argv.push_back(nullptr);
+
+            // clone envp while adding LLAMA_SERVER_ROUTER_PORT
+            std::vector<std::string> child_envs;
+            std::vector<char *> child_envp;
+            {
+                for (char ** e = envp; *e != nullptr; ++e) {
+                    child_envs.emplace_back(*e);
+                }
+                child_envs.emplace_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(router_port));
+                child_envp.reserve(child_envs.size() + 1);
+                for (auto & s : child_envs) {
+                    child_envp.push_back(const_cast<char *>(s.c_str()));
+                }
+                child_envp.push_back(nullptr);
+            }
+
+            if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), child_envp.data()) != 0) {
+                perror("posix_spawn");
+                exit(1); // for testing only
+            } else {
+                inst.pid = pid;
+                SRV_INF("spawned instance with pid %d\n", pid);
             }
-            child_envp.push_back(nullptr);
         }
 
-        if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), child_envp.data()) != 0) {
-            perror("posix_spawn");
-            exit(1); // for testing only
-        } else {
-            inst.pid = pid;
-            SRV_INF("spawned instance with pid %d\n", pid);
+        inst.th = std::thread([this, model_name, pid]() {
+            int status = 0;
+            waitpid(pid, &status, 0);
+            SRV_INF("instance with pid %d exited with status %d\n", pid, status);
+            this->remove(model_name);
+        });
+        if (inst.th.joinable()) {
+            inst.th.detach();
         }
-    }
 
-    inst.th = std::thread([hf_model, pid, &mapping]() {
-        int status = 0;
-        waitpid(pid, &status, 0);
-        SRV_INF("instance with pid %d exited with status %d\n", pid, status);
-        mapping.erase(hf_model); // TODO: thread safety
-    });
-    if (inst.th.joinable()) {
-        inst.th.detach();
+        insert(model_name, std::move(inst));
     }
 
-    mapping[hf_model] = std::move(inst); // TODO: thread safety
-    return 0;
-}
-
-inline void kill_all_instances(std::map<std::string, server_spawn_instance> & mapping) {
-    for (auto & [hf_model, inst] : mapping) {
-        LOG_INF("killing instance with hf=%s on port %d (pid %d)\n", hf_model.c_str(), inst.port, inst.pid);
-        kill(inst.pid, SIGINT);
+    void kill_all() {
+        for (auto & inst : mapping) {
+            kill_single(inst.first);
+        }
     }
-    mapping.clear();
-}
 
-inline void server_router_kill_single(std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model) {
-    auto it = mapping.find(hf_model);
-    if (it != mapping.end()) {
-        auto & inst = it->second;
-        LOG_INF("killing instance with hf=%s on port %d (pid %d)\n", hf_model.c_str(), inst.port, inst.pid);
-        kill(inst.pid, SIGINT);
-        mapping.erase(it);
+    void kill_single(const std::string & model_name) {
+        auto it = mapping.find(model_name);
+        if (it != mapping.end()) {
+            auto & inst = it->second;
+            LOG_INF("killing instance with name=%s on port %d (pid %d)\n", model_name.c_str(), inst.meta.port, inst.pid);
+            kill(inst.pid, SIGINT);
+            remove(model_name);
+        }
     }
-}
+};
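
A quick way to exercise the new router endpoints end-to-end is a small standalone client pointed at a router instance (a llama-server started without a model). The sketch below is not part of the series; it uses the cpp-httplib header already vendored in the repo, and the router address, port 8080, and the model name are placeholder assumptions. The model must already be present in the local cache (see common_list_cached_models()).

// Minimal sketch of a router client; host/port and model name are assumptions.
#include "httplib.h"

#include <chrono>
#include <cstdio>
#include <string>
#include <thread>

int main() {
    httplib::Client cli("localhost", 8080); // assumed router address
    const std::string model = "ggml-org/gemma-3-1b-it-GGUF"; // placeholder cached model name

    // ask the router to spawn a child llama-server for this model
    auto load = cli.Post("/models/load", "{\"model\": \"" + model + "\"}", "application/json");
    if (!load || load->status != 200) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // poll GET /models until the child has reported "loaded" back via POST /models/status
    for (int i = 0; i < 120; i++) {
        auto list = cli.Get("/models");
        if (list && list->body.find("\"loaded\"") != std::string::npos) {
            break; // crude substring check, good enough for a smoke test
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(500));
    }

    // the "model" field in the body is what handle_proxy() uses to pick the child instance
    auto chat = cli.Post("/v1/chat/completions",
            "{\"model\": \"" + model + "\", \"messages\": [{\"role\": \"user\", \"content\": \"hello\"}]}",
            "application/json");
    if (chat) {
        printf("%s\n", chat->body.c_str());
    }

    // tear the child process down again
    cli.Post("/models/unload", "{\"model\": \"" + model + "\"}", "application/json");
    return 0;
}

Note that the explicit /models/load plus polling is optional for cached models: the proxy handlers call ensure_model_loaded(), which blocks in wait_until_loaded() until the child notifies the router, so the chat request alone triggers the same flow.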