diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp
index 196ced443261a..103e4a449bb89 100644
--- a/tools/server/server-http.cpp
+++ b/tools/server/server-http.cpp
@@ -383,3 +383,88 @@ void server_http_context::post(const std::string & path, server_http_context::ha
     });
 }
+
+//
+// server_http_client
+//
+
+server_http_client::server_http_client(
+        const std::string & method,
+        const std::string & host,
+        int port,
+        const std::string & path,
+        const std::map<std::string, std::string> & headers,
+        const std::string & body,
+        const std::function<bool()> should_stop) {
+    // shared between reader and writer threads
+    auto cli  = std::make_shared<httplib::Client>(host, port);
+    auto pipe = std::make_shared<pipe_t<msg_t>>();
+
+    // set up the client
+    cli->set_connection_timeout(0, 200000); // 200 milliseconds
+    this->status = 500; // to be overwritten upon response
+    this->cleanup = [pipe]() {
+        pipe->close_read();
+        pipe->close_write();
+    };
+
+    // wire up the receiving end of the pipe
+    this->next = [pipe, should_stop](std::string & out) -> bool {
+        msg_t msg;
+        bool has_next = pipe->read(msg, should_stop);
+        if (!msg.data.empty()) {
+            out = std::move(msg.data);
+        }
+        return has_next;
+    };
+
+    // wire up the HTTP client
+    // note: do NOT capture the `this` pointer, as it may be destroyed before the thread ends
+    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
+        msg_t msg;
+        msg.status = response.status;
+        for (const auto & [key, value] : response.headers) {
+            msg.headers[key] = value;
+        }
+        pipe->write(std::move(msg)); // send headers first
+        return true;
+    };
+    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
+        return pipe->write({{}, 0, std::string(data, data_length)}); // send data chunks
+    };
+
+    // prepare the request to the destination server
+    httplib::Request req;
+    {
+        req.method = method;
+        req.path   = path;
+        for (const auto & [key, value] : headers) {
+            req.set_header(key, value);
+        }
+        req.body = body;
+        req.response_handler = response_handler;
+        req.content_receiver = content_receiver;
+    }
+
+    // start the proxy thread
+    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
+    this->thread = std::thread([cli, pipe, req]() {
+        auto result = cli->send(std::move(req));
+        if (result.error() != httplib::Error::Success) {
+            auto err_str = httplib::to_string(result.error());
+            SRV_ERR("http client error: %s\n", err_str.c_str());
+            pipe->write({{}, 500, ""}); // header
+            pipe->write({{}, 0, "proxy error: " + err_str}); // body
+        }
+        pipe->close_write(); // signal EOF to reader
+        SRV_DBG("%s", "client request thread ended\n");
+    });
+    this->thread.detach();
+
+    // wait for the first chunk (headers)
+    msg_t header;
+    pipe->read(header, should_stop);
+    SRV_DBG("%s", "received response headers\n");
+    this->status  = header.status;
+    this->headers = header.headers;
+}
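
Note: the constructor above blocks until the first `msg_t` (status + headers) arrives, then hands body chunks to the caller one at a time through `next`. A minimal consumer sketch, assuming only the `server_http_client` interface declared in `server-http.h` below; the host, port, and path values are invented for illustration:

```cpp
#include <cstdio>
#include <string>
// assumes the "server-http.h" added by this patch is on the include path

void dump_proxied_response() {
    server_http_client res(
        "GET", "127.0.0.1", 8081, "/v1/models",
        /*headers*/ {}, /*body*/ "",
        /*should_stop*/ []() { return false; });

    printf("status: %d\n", res.status);  // filled in from the first pipe message
    std::string chunk;
    while (res.next(chunk)) {            // each call yields one body chunk
        fwrite(chunk.data(), 1, chunk.size(), stdout);
    }
}   // ~server_http_client() runs `cleanup`, closing both pipe ends, so the
    // detached proxy thread sees a broken pipe and stops writing
```
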
diff --git a/tools/server/server-http.h b/tools/server/server-http.h
index dc6ca92fd8751..977b30d3c70b7 100644
--- a/tools/server/server-http.h
+++ b/tools/server/server-http.h
@@ -75,3 +75,79 @@ struct server_http_context {
     // for debugging
     std::string listening_address;
 };
+
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
+
+struct server_http_client : server_http_res {
+    std::function<void()> cleanup = nullptr;
+public:
+    server_http_client(const std::string & method,
+                       const std::string & host,
+                       int port,
+                       const std::string & path,
+                       const std::map<std::string, std::string> & headers,
+                       const std::string & body,
+                       const std::function<bool()> should_stop);
+    ~server_http_client() {
+        if (cleanup) {
+            cleanup();
+        }
+    }
+private:
+    std::thread thread;
+    struct msg_t {
+        std::map<std::string, std::string> headers;
+        int status = 0;
+        std::string data;
+    };
+    // simple implementation of a pipe
+    template<typename T>
+    struct pipe_t {
+        std::mutex mutex;
+        std::condition_variable cv;
+        std::queue<T> queue;
+        std::atomic<bool> writer_closed{false};
+        std::atomic<bool> reader_closed{false};
+        void close_write() {
+            writer_closed.store(true);
+            cv.notify_all();
+        }
+        void close_read() {
+            reader_closed.store(true);
+            cv.notify_all();
+        }
+        bool read(T & output, const std::function<bool()> & should_stop) {
+            std::unique_lock<std::mutex> lk(mutex);
+            constexpr auto poll_interval = std::chrono::milliseconds(500);
+            while (true) {
+                if (!queue.empty()) {
+                    output = std::move(queue.front());
+                    queue.pop();
+                    return true;
+                }
+                if (writer_closed.load()) {
+                    return false; // clean EOF
+                }
+                if (should_stop()) {
+                    close_read(); // signal broken pipe to writer
+                    return false; // cancelled / reader no longer alive
+                }
+                cv.wait_for(lk, poll_interval);
+            }
+        }
+        bool write(T && data) {
+            std::lock_guard<std::mutex> lk(mutex);
+            if (reader_closed.load()) {
+                return false; // broken pipe
+            }
+            queue.push(std::move(data));
+            cv.notify_one();
+            return true;
+        }
+    };
+};
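
Note: `pipe_t` is the entire synchronization story between the proxy thread and the HTTP response writer. The reader deliberately polls (`wait_for` with a 500 ms interval) instead of waiting indefinitely, so a vanished downstream client (`should_stop`) can cancel a blocked read even if the writer never produces another wake-up. A self-contained sketch of the same contract with a plain `int` payload; this is a standalone re-implementation for illustration, not the header above:

```cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// standalone re-implementation of the pipe_t contract for illustration
struct int_pipe {
    std::mutex mutex;
    std::condition_variable cv;
    std::queue<int> queue;
    std::atomic<bool> writer_closed{false};
    std::atomic<bool> reader_closed{false};

    void close_write() { writer_closed = true; cv.notify_all(); }

    bool write(int v) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed) return false;      // broken pipe: reader went away
        queue.push(v);
        cv.notify_one();
        return true;
    }

    bool read(int & out, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        while (true) {
            if (!queue.empty()) { out = queue.front(); queue.pop(); return true; }
            if (writer_closed) return false;  // clean EOF: queue fully drained
            if (should_stop()) { reader_closed = true; return false; } // cancelled
            cv.wait_for(lk, std::chrono::milliseconds(500)); // poll for cancellation
        }
    }
};

int main() {
    int_pipe p;
    std::thread producer([&] {
        for (int i = 0; i < 3; i++) p.write(i);
        p.close_write();                      // EOF after the last chunk
    });
    int v;
    while (p.read(v, [] { return false; })) printf("got %d\n", v);
    producer.join();                          // prints: got 0, got 1, got 2
}
```

Closing the write end after the last chunk is what turns `read` returning `false` into a clean EOF rather than an error; the queue is always drained before EOF is reported.
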
json_value(body, "value", std::string()); + if (!instances.get_meta(model).has_value()) { + auto res = std::make_unique(ctx_server); + res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE)); + return res; + } + instances.update_status(model, value); + res->ok({{"success", true}}); + return res; + }; + server_http_context::handler_t get_router_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json models_json = json::array(); + auto models = common_list_cached_models(); + for (const auto & model : models) { + auto model_name = model.to_string(); + auto meta = instances.get_meta(model_name); + bool found = meta.has_value(); + models_json.push_back(json { + {"model", model_name}, + {"name", model_name}, + {"id", model_name}, + // TODO: other fields... + {"status", { + {"value", found ? meta->status : "unloaded"} + }}, + }); + } + res->ok({{"data", models_json}}); + return res; + }; + server_http_context::handler_t post_router_models_unload = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + json body = json::parse(req.body); + std::string model = json_value(body, "model", std::string()); + if (!instances.get_meta(model).has_value()) { + auto res = std::make_unique(ctx_server); + res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE)); + return res; + } + instances.kill_single(model); + res->ok({{"success", true}}); + return res; + }; + private: std::unique_ptr handle_completions_impl( server_task_type type, @@ -5501,7 +5604,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t }; } -int main(int argc, char ** argv) { +int main(int argc, char ** argv, char ** envp) { // own arguments required by this example common_params params; @@ -5549,6 +5652,38 @@ int main(int argc, char ** argv) { // register API routes server_routes routes(params, ctx_server, ctx_http); + // hacky, replace handlers with proxy handlers if this is a router server + bool is_router_server = params.model.path == DEFAULT_MODEL_PATH; + if (is_router_server) { + // setup server instances manager + routes.instances.envp = envp; + routes.instances.router_port = params.port; + + // proxy handlers + routes.get_props = routes.proxy_get; + routes.post_props = routes.proxy_post; + routes.post_completions = routes.proxy_post; + routes.post_completions_oai = routes.proxy_post; + routes.post_chat_completions = routes.proxy_post; + routes.post_infill = routes.proxy_post; + routes.post_embeddings = routes.proxy_post; + routes.post_embeddings_oai = routes.proxy_post; + routes.post_rerank = routes.proxy_post; + routes.post_tokenize = routes.proxy_post; + routes.post_detokenize = routes.proxy_post; + routes.post_apply_template = routes.proxy_post; + routes.get_lora_adapters = routes.proxy_get; + routes.post_lora_adapters = routes.proxy_post; + routes.get_slots = routes.proxy_get; + routes.post_slots = routes.proxy_post; + + // custom routes for router + routes.get_models = routes.get_router_models; + ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load)); + ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload)); + ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status)); + } + ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); @@ 
@@ -5594,6 +5729,8 @@ int main(int argc, char ** argv) {
         llama_backend_free();
     };
 
+if (!is_router_server) { // HACKY
+
     // start the HTTP server before loading the model to be able to serve /health requests
     if (!ctx_http.start()) {
         clean_up();
@@ -5631,6 +5768,8 @@ int main(int argc, char ** argv) {
         ctx_server.queue_tasks.terminate();
     };
 
+} // end of !is_router_server
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;
     sigint_action.sa_handler = signal_handler;
@@ -5645,6 +5784,23 @@ int main(int argc, char ** argv) {
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
 
+if (!is_router_server) { // HACKY
+
+    // notify the main router if needed
+    char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
+    if (router_port != nullptr) {
+        SRV_INF("%s: notifying main router on port %s\n", __func__, router_port);
+        server_http_client notify_router(
+            "POST", params.hostname, std::atoi(router_port),
+            "/models/status",
+            { {"Content-Type", "application/json"} },
+            json {{ "model", params.model_alias }, { "value", "loaded" }}.dump(),
+            []() { return false; }
+        );
+        std::string dummy;
+        notify_router.next(dummy); // ignore the response
+    }
+
     LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
     LOG_INF("%s: starting the main loop...\n", __func__);
     // this call blocks the main thread until queue_tasks.terminate() is called
@@ -5655,6 +5811,19 @@ int main(int argc, char ** argv) {
         ctx_http.thread.join();
     }
     llama_memory_breakdown_print(ctx_server.ctx);
+} else {
+    shutdown_handler = [&](int) {
+        ctx_http.stop();
+    };
+    if (!ctx_http.start()) {
+        LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
+        return 1;
+    }
+    ctx_http.is_ready.store(true);
+    ctx_http.thread.join(); // keep the main thread alive
+    // kill_all_instances(routes.map_model_to_port); // why does this also kill the main instance?
+    LOG_INF("%s: server stopped\n", __func__);
+} // end of !is_router_server
 
     return 0;
 }
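
Note: the `server_instances` registry defined below owns the child processes. A hedged usage sketch of its public surface, as exercised by the router handlers above; the model name and router port are invented, and this assumes the struct below is in scope:

```cpp
#include <cstdio>
#include <string>
// assumes the server_instances struct defined below

void demo_router_lifecycle(server_instances & instances, char ** envp) {
    instances.envp        = envp; // children inherit this, plus LLAMA_SERVER_ROUTER_PORT
    instances.router_port = 8080; // hypothetical router port

    std::string model = "ggml-org/some-model-GGUF"; // hypothetical cached model name
    instances.ensure_model_loaded(model); // spawns a child (only if the model is in the
                                          // local cache) and blocks until the child
                                          // POSTs /models/status with value "loaded"
    if (auto meta = instances.get_meta(model)) {
        // requests for `model` can now be proxied to 127.0.0.1:<meta->port>
        printf("serving %s on port %d (%s)\n", model.c_str(), meta->port, meta->status.c_str());
    }
    instances.kill_single(model); // SIGINT the child; its waitpid() reaper thread
                                  // then removes the mapping entry
}
```
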
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index bf21726051e55..cd4d98c1829f6 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1555,3 +1555,225 @@ static server_tokens format_rerank(const struct llama_model * model, const struc
     return result;
 }
+
+//
+// router server utils
+//
+
+#include "utils.hpp"
+#include "download.h"
+
+#include <filesystem>
+#include <optional>
+#include <spawn.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <signal.h> // for kill()
+
+#if defined(__APPLE__) && defined(__MACH__)
+// macOS: use _NSGetExecutablePath to get the executable path
+#include <mach-o/dyld.h>
+#include <limits.h>
+#endif
+
+static std::filesystem::path get_server_exec_path() {
+#if defined(_MSC_VER)
+    wchar_t path[FILENAME_MAX] = { 0 };
+    GetModuleFileNameW(nullptr, path, FILENAME_MAX);
+    return std::filesystem::path(path);
+#elif defined(__APPLE__) && defined(__MACH__)
+    char small_path[PATH_MAX];
+    uint32_t size = sizeof(small_path);
+
+    if (_NSGetExecutablePath(small_path, &size) == 0) {
+        // resolve any symlinks to get the absolute path
+        try {
+            return std::filesystem::canonical(std::filesystem::path(small_path));
+        } catch (...) {
+            return std::filesystem::path(small_path);
+        }
+    } else {
+        // buffer was too small, allocate the required size and call again
+        std::vector<char> buf(size);
+        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
+            try {
+                return std::filesystem::canonical(std::filesystem::path(buf.data()));
+            } catch (...) {
+                return std::filesystem::path(buf.data());
+            }
+        }
+        return std::filesystem::path(std::string(buf.data(), (size > 0) ? size : 0));
+    }
+#else
+    char path[FILENAME_MAX];
+    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+    return std::filesystem::path(std::string(path, (count > 0) ? count : 0));
+#endif
+}
+
+struct server_instances {
+    struct instance_metadata_t {
+        // this struct is copyable
+        int port = 0;
+        std::string status = "loading"; // "loading", "loaded"
+    };
+    struct instance_t {
+        pid_t pid = 0;
+        std::thread th;
+        instance_metadata_t meta;
+    };
+
+private:
+    std::mutex mutex;
+    std::condition_variable cv;
+    std::map<std::string, instance_t> mapping;
+
+    void remove(const std::string & model_name) {
+        std::unique_lock lock(mutex);
+        mapping.erase(model_name);
+        cv.notify_all();
+    }
+
+    void insert(const std::string & model_name, instance_t && inst) {
+        std::unique_lock lock(mutex);
+        if (mapping.find(model_name) != mapping.end()) {
+            throw std::runtime_error("instance with name=" + model_name + " already exists");
+        }
+        mapping[model_name] = std::move(inst);
+        cv.notify_all();
+    }
+
+public:
+    char ** envp    = nullptr;
+    int router_port = 0;
+
+    std::optional<instance_metadata_t> get_meta(const std::string & model_name) {
+        std::unique_lock lock(mutex);
+        auto it = mapping.find(model_name);
+        if (it != mapping.end()) {
+            return it->second.meta;
+        }
+        return std::nullopt;
+    }
+
+    void update_status(const std::string & model_name, const std::string & status) {
+        std::unique_lock lock(mutex);
+        auto it = mapping.find(model_name);
+        if (it != mapping.end()) {
+            it->second.meta.status = status;
+            cv.notify_all();
+        }
+    }
+
+    void wait_until_loaded(const std::string & model_name) {
+        std::unique_lock lock(mutex);
+        cv.wait(lock, [&]() {
+            auto it = mapping.find(model_name);
+            // either being deleted (load failed), or loaded
+            return it == mapping.end() || it->second.meta.status == "loaded";
+        });
+    }
+
+    void ensure_model_loaded(std::string & model_name) {
+        if (get_meta(model_name).has_value()) {
+            return; // already loaded, do nothing
+        }
+        // TODO: maybe unload the least recently used model if too many models are loaded?
+        auto models = common_list_cached_models();
+        for (const auto & m : models) {
+            auto name = m.to_string();
+            if (name == model_name) {
+                create(name);
+                wait_until_loaded(name);
+                SRV_INF("model %s loaded on-demand\n", name.c_str());
+                return;
+            }
+        }
+    }
+
+    void create(const std::string & model_name) {
+        instance_t inst;
+        // TODO: use a better port allocation strategy
+        inst.meta.port = rand() % 10000 + 20000; // random port between 20000 and 29999
+
+        if (get_meta(model_name).has_value()) {
+            throw std::runtime_error("instance with model_name " + model_name + " already exists");
+        }
+
+        pid_t pid = 0;
+        {
+            // prepare arguments (pass original or custom ones) using mutable storage for argv
+            std::string path = get_server_exec_path().string();
+
+            SRV_INF("spawning instance %s with name=%s on port %d\n", path.c_str(), model_name.c_str(), inst.meta.port);
+            std::vector<std::string> arg_strs;
+            arg_strs.push_back(path);
+            arg_strs.push_back("-hf");
+            arg_strs.push_back(model_name);
+            arg_strs.push_back("--alias");
+            arg_strs.push_back(model_name);
+            arg_strs.push_back("--port");
+            arg_strs.push_back(std::to_string(inst.meta.port));
+
+            std::vector<char *> child_argv;
+            child_argv.reserve(arg_strs.size() + 1);
+            for (auto & s : arg_strs) {
+                child_argv.push_back(const_cast<char *>(s.c_str()));
+            }
+            child_argv.push_back(nullptr);
+
+            // clone envp while adding LLAMA_SERVER_ROUTER_PORT
+            std::vector<std::string> child_envs;
+            std::vector<char *> child_envp;
+            {
+                for (char ** e = envp; *e != nullptr; ++e) {
+                    child_envs.emplace_back(*e);
+                }
+                child_envs.emplace_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(router_port));
+                child_envp.reserve(child_envs.size() + 1);
+                for (auto & s : child_envs) {
+                    child_envp.push_back(const_cast<char *>(s.c_str()));
+                }
+                child_envp.push_back(nullptr);
+            }
+
+            if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), child_envp.data()) != 0) {
+                perror("posix_spawn");
+                exit(1); // for testing only
+            } else {
+                inst.pid = pid;
+                SRV_INF("spawned instance with pid %d\n", pid);
+            }
+        }
+
+        // reaper thread: wait for the child to exit, then drop it from the mapping
+        inst.th = std::thread([this, model_name, pid]() {
+            int status = 0;
+            waitpid(pid, &status, 0);
+            SRV_INF("instance with pid %d exited with status %d\n", pid, status);
+            this->remove(model_name);
+        });
+        if (inst.th.joinable()) {
+            inst.th.detach();
+        }
+
+        insert(model_name, std::move(inst));
+    }
+
+    void kill_all() {
+        // copy the names first: kill_single() calls remove(), which erases
+        // from `mapping` and would invalidate the loop iterator
+        std::vector<std::string> names;
+        for (auto & inst : mapping) {
+            names.push_back(inst.first);
+        }
+        for (auto & name : names) {
+            kill_single(name);
+        }
+    }
+
+    void kill_single(const std::string & model_name) {
+        auto it = mapping.find(model_name);
+        if (it != mapping.end()) {
+            auto & inst = it->second;
+            LOG_INF("killing instance with name=%s on port %d (pid %d)\n", model_name.c_str(), inst.meta.port, inst.pid);
+            kill(inst.pid, SIGINT);
+            remove(model_name);
+        }
+    }
+};
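
Note: one conventional direction for the `// TODO: use a better port allocation strategy` above is to bind to port 0 and let the kernel hand out a free port. A hedged sketch, not part of this patch, with the usual caveat that the port is released before the child re-binds it, so a small race window remains:

```cpp
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

// possible replacement for `rand() % 10000 + 20000`:
// bind to port 0 and let the kernel pick an unused port
static int pick_free_port() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) return -1;
    sockaddr_in addr = {};
    addr.sin_family      = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    addr.sin_port        = 0; // 0 = kernel assigns an unused port
    if (bind(fd, (sockaddr *) &addr, sizeof(addr)) != 0) {
        close(fd);
        return -1;
    }
    socklen_t len = sizeof(addr);
    getsockname(fd, (sockaddr *) &addr, &len); // read back the chosen port
    int port = ntohs(addr.sin_port);
    close(fd); // released here; the spawned server re-binds it shortly after
    return port;
}
```
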