diff --git a/engine/common/assistant.h b/engine/common/assistant.h new file mode 100644 index 000000000..e49147e9e --- /dev/null +++ b/engine/common/assistant.h @@ -0,0 +1,157 @@ +#pragma once + +#include +#include "common/assistant_tool.h" +#include "common/thread_tool_resources.h" +#include "common/variant_map.h" +#include "utils/result.hpp" + +namespace OpenAi { +// Deprecated. After jan's migration, we should remove this struct +struct JanAssistant : JsonSerializable { + std::string id; + + std::string name; + + std::string object = "assistant"; + + uint32_t created_at; + + Json::Value tools; + + Json::Value model; + + std::string instructions; + + ~JanAssistant() = default; + + cpp::result ToJson() override { + try { + Json::Value json; + + json["id"] = id; + json["name"] = name; + json["object"] = object; + json["created_at"] = created_at; + + json["tools"] = tools; + json["model"] = model; + json["instructions"] = instructions; + + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Empty JSON"); + } + + JanAssistant assistant; + if (json.isMember("assistant_id")) { + assistant.id = json["assistant_id"].asString(); + } else { + assistant.id = json["id"].asString(); + } + + if (json.isMember("assistant_name")) { + assistant.name = json["assistant_name"].asString(); + } else { + assistant.name = json["name"].asString(); + } + assistant.object = "assistant"; + assistant.created_at = 0; // Jan does not have this + if (json.isMember("tools")) { + assistant.tools = json["tools"]; + } + if (json.isMember("model")) { + assistant.model = json["model"]; + } + assistant.instructions = json["instructions"].asString(); + + return assistant; + } +}; + +struct Assistant { + /** + * The identifier, which can be referenced in API endpoints. 
+ */ + std::string id; + + /** + * The object type, which is always assistant. + */ + std::string object = "assistant"; + + /** + * The Unix timestamp (in seconds) for when the assistant was created. + */ + uint64_t created_at; + + /** + * The name of the assistant. The maximum length is 256 characters. + */ + std::optional name; + + /** + * The description of the assistant. The maximum length is 512 characters. + */ + std::optional description; + + /** + * ID of the model to use. You can use the List models API to see all of + * your available models, or see our Model overview for descriptions of them. + */ + std::string model; + + /** + * The system instructions that the assistant uses. The maximum length is + * 256,000 characters. + */ + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of 128 + * tools per assistant. Tools can be of types code_interpreter, file_search, + * or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. + */ + std::optional> + tool_resources; + + /** + * Set of 16 key-value pairs that can be attached to an object. This can be + * useful for storing additional information about the object in a structured + * format. Keys can be a maximum of 64 characters long and values can be a + * maximum of 512 characters long. + */ + Cortex::VariantMap metadata; + + /** + * What sampling temperature to use, between 0 and 2. Higher values like + * 0.8 will make the output more random, while lower values like 0.2 will + * make it more focused and deterministic. 
#pragma once

#include <optional>
#include <string>

namespace OpenAi {

/**
 * Base class for tools attached to an assistant. `type` is the OpenAI
 * discriminator string: "code_interpreter", "file_search" or "function".
 */
struct AssistantTool {
  std::string type;

  // explicit: a bare std::string should never silently convert into a tool.
  explicit AssistantTool(const std::string& type) : type{type} {}

  virtual ~AssistantTool() = default;
};

struct AssistantCodeInterpreterTool : public AssistantTool {
  AssistantCodeInterpreterTool() : AssistantTool{"code_interpreter"} {}

  ~AssistantCodeInterpreterTool() = default;
};

struct AssistantFileSearchTool : public AssistantTool {
  AssistantFileSearchTool() : AssistantTool("file_search") {}

  ~AssistantFileSearchTool() = default;

  /**
   * The ranking options for the file search. If not specified,
   * the file search tool will use the auto ranker and a score_threshold of 0.
   *
   * See the file search tool documentation for more information.
   */
  struct RankingOption {
    /**
     * The ranker to use for the file search. If not specified will use the
     * auto ranker.
     */
    std::string ranker;

    /**
     * The score threshold for the file search. All values must be a
     * floating point number between 0 and 1. Defaults to 0 (per the
     * ranking-options documentation above).
     */
    float score_threshold = 0.0f;
  };

  /**
   * Overrides for the file search tool.
   */
  struct FileSearch {
    /**
     * The maximum number of results the file search tool should output.
     * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo.
     * This number should be between 1 and 50 inclusive.
     *
     * Note that the file search tool may output fewer than max_num_results
     * results. See the file search tool documentation for more information.
     *
     * 0 means "not set" (use the model-dependent default above).
     * NOTE(review): field name drops the trailing 's' of the API's
     * max_num_results — kept for compatibility with existing callers.
     */
    int max_num_result = 0;
  };
};

struct AssistantFunctionTool : public AssistantTool {
  AssistantFunctionTool() : AssistantTool("function") {}

  ~AssistantFunctionTool() = default;

  struct Function {
    /**
     * A description of what the function does, used by the model to choose
     * when and how to call the function.
     */
    std::string description;

    /**
     * The name of the function to be called. Must be a-z, A-Z, 0-9, or
     * contain underscores and dashes, with a maximum length of 64.
     */
    std::string name;

    // TODO: namh handle parameters

    /**
     * Whether to enable strict schema adherence when generating the function
     * call. If set to true, the model will follow the exact schema defined in
     * the parameters field. Only a subset of JSON Schema is supported when
     * strict is true.
     *
     * Learn more about Structured Outputs in the function calling guide.
     */
    std::optional<bool> strict;
  };
};
}  // namespace OpenAi
"utils/cortex_utils.h" +#include "utils/logging_utils.h" + +void Assistants::RetrieveAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + CTL_INF("RetrieveAssistant: " + assistant_id); + auto res = assistant_service_->RetrieveAssistant(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Parse assistant from request body + auto assistant_result = OpenAi::JanAssistant::FromJson(std::move(*json_body)); + if (assistant_result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse assistant: " + assistant_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Call assistant service to create + auto create_result = assistant_service_->CreateAssistant( + assistant_id, 
assistant_result.value()); + if (create_result.has_error()) { + Json::Value ret; + ret["message"] = create_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Convert result to JSON and send response + auto to_json_result = create_result->ToJson(); + if (to_json_result.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_result.error()); + Json::Value ret; + ret["message"] = to_json_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(to_json_result.value()); + resp->setStatusCode(k201Created); + callback(resp); +} + +void Assistants::ModifyAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Parse assistant from request body + auto assistant_result = OpenAi::JanAssistant::FromJson(std::move(*json_body)); + if (assistant_result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse assistant: " + assistant_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Call assistant service to create + auto modify_result = assistant_service_->ModifyAssistant( + assistant_id, assistant_result.value()); + if (modify_result.has_error()) { + Json::Value ret; + ret["message"] = modify_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Convert result to 
JSON and send response + auto to_json_result = modify_result->ToJson(); + if (to_json_result.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_result.error()); + Json::Value ret; + ret["message"] = to_json_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(to_json_result.value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h new file mode 100644 index 000000000..94ddd14b1 --- /dev/null +++ b/engine/controllers/assistants.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include "services/assistant_service.h" + +using namespace drogon; + +class Assistants : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", + Get); + + ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", + Options, Post); + + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", + Options, Patch); + METHOD_LIST_END + + explicit Assistants(std::shared_ptr assistant_srv) + : assistant_service_{assistant_srv} {}; + + void RetrieveAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void CreateAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + + void ModifyAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + + private: + std::shared_ptr assistant_service_; +}; diff --git a/engine/controllers/messages.cc b/engine/controllers/messages.cc index ef82b3412..27307803a 100644 --- a/engine/controllers/messages.cc +++ b/engine/controllers/messages.cc @@ -10,13 +10,13 @@ void Messages::ListMessages( const HttpRequestPtr& req, std::function&& 
callback, - const std::string& thread_id, std::optional limit, + const std::string& thread_id, std::optional limit, std::optional order, std::optional after, std::optional before, std::optional run_id) const { auto res = message_service_->ListMessages( - thread_id, limit.value_or(20), order.value_or("desc"), after.value_or(""), - before.value_or(""), run_id.value_or("")); + thread_id, std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or(""), run_id.value_or("")); Json::Value root; if (res.has_error()) { @@ -212,39 +212,88 @@ void Messages::ModifyMessage( } std::optional metadata = std::nullopt; - if (auto it = json_body->get("metadata", ""); it) { - if (it.empty()) { + if (json_body->isMember("metadata")) { + if (auto it = json_body->get("metadata", ""); it) { + if (it.empty()) { + Json::Value ret; + ret["message"] = "Metadata can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto convert_res = Cortex::ConvertJsonValueToMap(it); + if (convert_res.has_error()) { + Json::Value ret; + ret["message"] = + "Failed to convert metadata to map: " + convert_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + metadata = convert_res.value(); + } + } + + std::optional< + std::variant>>> + content = std::nullopt; + + if (json_body->get("content", "").isArray()) { + auto result = OpenAi::ParseContents(json_body->get("content", "")); + if (result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse content array: " + result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + if (result.value().empty()) { Json::Value ret; - ret["message"] = "Metadata can't be empty"; + ret["message"] = "Content array cannot be empty"; auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k400BadRequest); callback(resp); return; } - auto convert_res = Cortex::ConvertJsonValueToMap(it); - if (convert_res.has_error()) { + + content = std::move(result.value()); + } else if (json_body->get("content", "").isString()) { + auto content_str = json_body->get("content", "").asString(); + string_utils::Trim(content_str); + if (content_str.empty()) { Json::Value ret; - ret["message"] = - "Failed to convert metadata to map: " + convert_res.error(); + ret["message"] = "Content can't be empty"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k400BadRequest); callback(resp); return; } - metadata = convert_res.value(); + + content = content_str; + } else if (!json_body->get("content", "").empty()) { + Json::Value ret; + ret["message"] = "Content must be either a string or an array"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; } - if (!metadata.has_value()) { + if (!metadata.has_value() && !content.has_value()) { Json::Value ret; - ret["message"] = "Metadata is mandatory"; + ret["message"] = "Nothing to update"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k400BadRequest); callback(resp); return; } - auto res = - message_service_->ModifyMessage(thread_id, message_id, metadata.value()); + auto res = message_service_->ModifyMessage(thread_id, message_id, metadata, + std::move(content)); if (res.has_error()) { Json::Value ret; ret["message"] = "Failed to modify message: " + res.error(); diff --git a/engine/controllers/messages.h b/engine/controllers/messages.h index 340317eb8..045d8a207 100644 --- a/engine/controllers/messages.h +++ b/engine/controllers/messages.h @@ -34,7 +34,8 @@ class Messages : public drogon::HttpController { void ListMessages(const HttpRequestPtr& req, std::function&& callback, - const std::string& thread_id, std::optional 
limit, + const std::string& thread_id, + std::optional limit, std::optional order, std::optional after, std::optional before, diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc index a11c1071b..1cd3aaeef 100644 --- a/engine/controllers/threads.cc +++ b/engine/controllers/threads.cc @@ -7,12 +7,12 @@ void Threads::ListThreads( const HttpRequestPtr& req, std::function&& callback, - std::optional limit, std::optional order, + std::optional limit, std::optional order, std::optional after, std::optional before) const { CTL_INF("ListThreads"); - auto res = - thread_service_->ListThreads(limit.value_or(20), order.value_or("desc"), - after.value_or(""), before.value_or("")); + auto res = thread_service_->ListThreads( + std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or("")); if (res.has_error()) { Json::Value root; diff --git a/engine/controllers/threads.h b/engine/controllers/threads.h index 92c509525..f26e35785 100644 --- a/engine/controllers/threads.h +++ b/engine/controllers/threads.h @@ -34,7 +34,7 @@ class Threads : public drogon::HttpController { void ListThreads(const HttpRequestPtr& req, std::function&& callback, - std::optional limit, + std::optional limit, std::optional order, std::optional after, std::optional before) const; diff --git a/engine/main.cc b/engine/main.cc index 0177a2143..894e9d146 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,6 +1,7 @@ #include #include #include +#include "controllers/assistants.h" #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" @@ -14,6 +15,7 @@ #include "migrations/migration_manager.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" +#include "services/assistant_service.h" #include "services/config_service.h" #include "services/file_watcher_service.h" #include "services/message_service.h" @@ -124,6 +126,7 @@ void RunServer(std::optional port, bool 
ignore_cout) { auto thread_repo = std::make_shared( file_manager_utils::GetCortexDataPath()); + auto assistant_srv = std::make_shared(thread_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); @@ -142,6 +145,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto assistant_ctl = std::make_shared(assistant_srv); auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); @@ -153,6 +157,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(assistant_ctl); drogon::app().registerController(thread_ctl); drogon::app().registerController(message_ctl); drogon::app().registerController(engine_ctl); diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index e576a7695..388409390 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -1,4 +1,5 @@ #include "message_fs_repository.h" +#include #include #include #include "utils/result.hpp" @@ -52,7 +53,61 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, auto mutex = GrabMutex(thread_id); std::shared_lock lock(*mutex); - return ReadMessageFromFile(thread_id); + auto read_result = ReadMessageFromFile(thread_id); + if (read_result.has_error()) { + return cpp::fail(read_result.error()); + } + + std::vector messages = std::move(read_result.value()); + + if (!run_id.empty()) { + messages.erase(std::remove_if(messages.begin(), messages.end(), + [&run_id](const OpenAi::Message& msg) { + return msg.run_id != run_id; + }), + messages.end()); + } + + std::sort(messages.begin(), messages.end(), + [&order](const OpenAi::Message& a, const 
OpenAi::Message& b) { + if (order == "desc") { + return a.created_at > b.created_at; + } + return a.created_at < b.created_at; + }); + + auto start_it = messages.begin(); + auto end_it = messages.end(); + + if (!after.empty()) { + start_it = std::find_if( + messages.begin(), messages.end(), + [&after](const OpenAi::Message& msg) { return msg.id == after; }); + if (start_it != messages.end()) { + ++start_it; // Start from the message after the 'after' message + } else { + start_it = messages.begin(); + } + } + + if (!before.empty()) { + end_it = std::find_if( + messages.begin(), messages.end(), + [&before](const OpenAi::Message& msg) { return msg.id == before; }); + } + + std::vector result; + size_t distance = std::distance(start_it, end_it); + size_t limit_size = static_cast(limit); + CTL_INF("Distance: " + std::to_string(distance) + + ", limit_size: " + std::to_string(limit_size)); + result.reserve(distance < limit_size ? distance : limit_size); + + for (auto it = start_it; it != end_it && result.size() < limit_size; ++it) { + result.push_back(std::move(*it)); + } + + return result; } cpp::result MessageFsRepository::RetrieveMessage( diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc index 64dad6ea5..6b75db8e4 100644 --- a/engine/repositories/thread_fs_repository.cc +++ b/engine/repositories/thread_fs_repository.cc @@ -1,37 +1,67 @@ #include "thread_fs_repository.h" #include #include +#include "common/assistant.h" +#include "utils/result.hpp" cpp::result, std::string> ThreadFsRepository::ListThreads(uint8_t limit, const std::string& order, const std::string& after, const std::string& before) const { - CTL_INF("ListThreads: limit=" + std::to_string(limit) + ", order=" + order + - ", after=" + after + ", before=" + before); std::vector threads; try { auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + std::vector all_threads; + + // First load all valid threads for (const auto& entry : 
std::filesystem::directory_iterator(thread_container_path)) { if (!entry.is_directory()) continue; - if (!std::filesystem::exists(entry.path() / kThreadFileName)) + auto thread_file = entry.path() / kThreadFileName; + if (!std::filesystem::exists(thread_file)) continue; auto current_thread_id = entry.path().filename().string(); - CTL_INF("ListThreads: Found thread: " + current_thread_id); - std::shared_lock thread_lock(GrabThreadMutex(current_thread_id)); + // Apply pagination filters + if (!after.empty() && current_thread_id <= after) + continue; + if (!before.empty() && current_thread_id >= before) + continue; + + std::shared_lock thread_lock(GrabThreadMutex(current_thread_id)); auto thread_result = LoadThread(current_thread_id); + if (thread_result.has_value()) { - threads.push_back(std::move(thread_result.value())); + all_threads.push_back(std::move(thread_result.value())); } thread_lock.unlock(); } + // Sort threads based on order parameter using created_at + if (order == "desc") { + std::sort(all_threads.begin(), all_threads.end(), + [](const OpenAi::Thread& a, const OpenAi::Thread& b) { + return a.created_at > b.created_at; // Descending order + }); + } else { + std::sort(all_threads.begin(), all_threads.end(), + [](const OpenAi::Thread& a, const OpenAi::Thread& b) { + return a.created_at < b.created_at; // Ascending order + }); + } + + // Apply limit + size_t thread_count = + std::min(static_cast(limit), all_threads.size()); + for (size_t i = 0; i < thread_count; i++) { + threads.push_back(std::move(all_threads[i])); + } + return threads; } catch (const std::exception& e) { return cpp::fail(std::string("Failed to list threads: ") + e.what()); @@ -164,3 +194,85 @@ cpp::result ThreadFsRepository::DeleteThread( thread_mutexes_.erase(thread_id); return {}; } + +cpp::result +ThreadFsRepository::LoadAssistant(const std::string& thread_id) const { + auto path = GetThreadPath(thread_id) / kThreadFileName; + if (!std::filesystem::exists(path)) { + return 
cpp::fail("Path does not exist: " + path.string()); + } + + std::shared_lock thread_lock(GrabThreadMutex(thread_id)); + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + Json::Value assistants = root["assistants"]; + if (!assistants.isArray()) { + return cpp::fail("Assistants field is not an array"); + } + + if (assistants.empty()) { + return cpp::fail("Assistant not found in thread: " + thread_id); + } + + return OpenAi::JanAssistant::FromJson(std::move(assistants[0])); + } catch (const std::exception& e) { + return cpp::fail("Failed to load assistant: " + std::string(e.what())); + } +} + +cpp::result +ThreadFsRepository::ModifyAssistant(const std::string& thread_id, + const OpenAi::JanAssistant& assistant) { + std::unique_lock lock(GrabThreadMutex(thread_id)); + + // Load the existing thread + auto thread_result = LoadThread(thread_id); + if (!thread_result.has_value()) { + return cpp::fail("Failed to load thread: " + thread_result.error()); + } + + auto& thread = thread_result.value(); + if (thread.ToJson() + ->get("assistants", Json::Value(Json::arrayValue)) + .empty()) { + return cpp::fail("No assistants found in thread: " + thread_id); + } + + thread.assistants = {assistant}; + + auto save_result = SaveThread(thread); + if (!save_result.has_value()) { + return cpp::fail("Failed to save thread: " + save_result.error()); + } + + return assistant; +} + +cpp::result ThreadFsRepository::CreateAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant) { + std::unique_lock lock(GrabThreadMutex(thread_id)); + + // Load the existing thread + auto thread_result = LoadThread(thread_id); + if (!thread_result.has_value()) { + return cpp::fail("Failed to load thread: " + 
thread_result.error()); + } + + auto& thread = thread_result.value(); + thread.assistants = {assistant}; + + // Save the modified thread + return SaveThread(thread); +} diff --git a/engine/repositories/thread_fs_repository.h b/engine/repositories/thread_fs_repository.h index d834b8e44..b6f6032fa 100644 --- a/engine/repositories/thread_fs_repository.h +++ b/engine/repositories/thread_fs_repository.h @@ -3,11 +3,26 @@ #include #include #include +#include "common/assistant.h" #include "common/repository/thread_repository.h" #include "common/thread.h" #include "utils/logging_utils.h" -class ThreadFsRepository : public ThreadRepository { +// this interface is for backward supporting Jan +class AssistantBackwardCompatibleSupport { + public: + virtual cpp::result LoadAssistant( + const std::string& thread_id) const = 0; + + virtual cpp::result ModifyAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant) = 0; + + virtual cpp::result CreateAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant) = 0; +}; + +class ThreadFsRepository : public ThreadRepository, + public AssistantBackwardCompatibleSupport { private: constexpr static auto kThreadFileName = "thread.json"; constexpr static auto kThreadContainerFolderName = "threads"; @@ -58,5 +73,17 @@ class ThreadFsRepository : public ThreadRepository { cpp::result DeleteThread( const std::string& thread_id) override; + // for supporting Jan + cpp::result LoadAssistant( + const std::string& thread_id) const override; + + cpp::result ModifyAssistant( + const std::string& thread_id, + const OpenAi::JanAssistant& assistant) override; + + cpp::result CreateAssistant( + const std::string& thread_id, + const OpenAi::JanAssistant& assistant) override; + ~ThreadFsRepository() = default; }; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc new file mode 100644 index 000000000..e769bf23f --- /dev/null +++ b/engine/services/assistant_service.cc @@ 
-0,0 +1,28 @@ +#include "assistant_service.h" +#include "utils/logging_utils.h" + +cpp::result +AssistantService::CreateAssistant(const std::string& thread_id, + const OpenAi::JanAssistant& assistant) { + CTL_INF("CreateAssistant: " + thread_id); + auto res = thread_repository_->CreateAssistant(thread_id, assistant); + + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return assistant; +} + +cpp::result +AssistantService::RetrieveAssistant(const std::string& assistant_id) const { + CTL_INF("RetrieveAssistant: " + assistant_id); + return thread_repository_->LoadAssistant(assistant_id); +} + +cpp::result +AssistantService::ModifyAssistant(const std::string& thread_id, + const OpenAi::JanAssistant& assistant) { + CTL_INF("RetrieveAssistant: " + thread_id); + return thread_repository_->ModifyAssistant(thread_id, assistant); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h new file mode 100644 index 000000000..e7f7414d1 --- /dev/null +++ b/engine/services/assistant_service.h @@ -0,0 +1,24 @@ +#pragma once + +#include "common/assistant.h" +#include "repositories/thread_fs_repository.h" +#include "utils/result.hpp" + +class AssistantService { + public: + explicit AssistantService( + std::shared_ptr thread_repository) + : thread_repository_{thread_repository} {} + + cpp::result CreateAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant); + + cpp::result RetrieveAssistant( + const std::string& thread_id) const; + + cpp::result ModifyAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant); + + private: + std::shared_ptr thread_repository_; +}; diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc index dfad74236..ddc9e096b 100644 --- a/engine/services/message_service.cc +++ b/engine/services/message_service.cc @@ -71,7 +71,10 @@ cpp::result MessageService::RetrieveMessage( cpp::result MessageService::ModifyMessage( const std::string& 
thread_id, const std::string& message_id, - std::optional metadata) { + std::optional metadata, + std::optional>>> + content) { LOG_TRACE << "ModifyMessage for thread " << thread_id << ", message " << message_id; auto msg = RetrieveMessage(thread_id, message_id); @@ -79,7 +82,24 @@ cpp::result MessageService::ModifyMessage( return cpp::fail("Failed to retrieve message: " + msg.error()); } - msg->metadata = metadata.value(); + if (metadata.has_value()) { + msg->metadata = metadata.value(); + } + if (content.has_value()) { + std::vector> content_list{}; + + // If content is string + if (std::holds_alternative(*content)) { + auto text_content = std::make_unique(); + text_content->text.value = std::get(*content); + content_list.push_back(std::move(text_content)); + } else { + content_list = std::move( + std::get>>(*content)); + } + + msg->content = std::move(content_list); + } auto ptr = &msg.value(); auto res = message_repository_->ModifyMessage(msg.value()); diff --git a/engine/services/message_service.h b/engine/services/message_service.h index 6c4880f32..456cdb3a3 100644 --- a/engine/services/message_service.h +++ b/engine/services/message_service.h @@ -21,16 +21,19 @@ class MessageService { std::optional> messages); cpp::result, std::string> ListMessages( - const std::string& thread_id, uint8_t limit = 20, - const std::string& order = "desc", const std::string& after = "", - const std::string& before = "", const std::string& run_id = "") const; + const std::string& thread_id, uint8_t limit, const std::string& order, + const std::string& after, const std::string& before, + const std::string& run_id) const; cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const; cpp::result ModifyMessage( const std::string& thread_id, const std::string& message_id, - std::optional metadata); + std::optional metadata, + std::optional>>> + content); cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id);