This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
34 changes: 24 additions & 10 deletions engine/commands/chat_cmd.cc
@@ -6,6 +6,7 @@
 #include "server_start_cmd.h"
 #include "trantor/utils/Logger.h"
 #include "utils/logging_utils.h"
+#include "utils/modellist_utils.h"
 
 namespace commands {
 namespace {
@@ -36,23 +37,36 @@ struct ChunkParser {
   }
 };
 
-ChatCmd::ChatCmd(std::string host, int port, const config::ModelConfig& mc)
-    : host_(std::move(host)), port_(port), mc_(mc) {}
+void ChatCmd::Exec(const std::string& host, int port,
+                   const std::string& model_handle, std::string msg) {
+  modellist_utils::ModelListUtils modellist_handler;
+  config::YamlHandler yaml_handler;
+  try {
+    auto model_entry = modellist_handler.GetModelInfo(model_handle);
+    yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml);
+    auto mc = yaml_handler.GetModelConfig();
+    Exec(host, port, mc, std::move(msg));
+  } catch (const std::exception& e) {
+    CLI_LOG("Fail to start model information with ID '" + model_handle +
+            "': " + e.what());
+  }
+}
 
-void ChatCmd::Exec(std::string msg) {
+void ChatCmd::Exec(const std::string& host, int port,
+                   const config::ModelConfig& mc, std::string msg) {
+  auto address = host + ":" + std::to_string(port);
   // Check if server is started
   {
-    if (!commands::IsServerAlive(host_, port_)) {
+    if (!commands::IsServerAlive(host, port)) {
       CLI_LOG("Server is not started yet, please run `"
               << commands::GetCortexBinary() << " start` to start server!");
       return;
     }
   }
 
-  auto address = host_ + ":" + std::to_string(port_);
   // Only check if llamacpp engine
-  if ((mc_.engine.find("llamacpp") != std::string::npos) &&
-      !commands::ModelStatusCmd().IsLoaded(host_, port_, mc_)) {
+  if ((mc.engine.find("llamacpp") != std::string::npos) &&
+      !commands::ModelStatusCmd().IsLoaded(host, port, mc)) {
     CLI_LOG("Model is not loaded yet!");
     return;
   }
@@ -78,12 +92,12 @@ void ChatCmd::Exec(std::string msg) {
     new_data["role"] = kUser;
     new_data["content"] = user_input;
     histories_.push_back(std::move(new_data));
-    json_data["engine"] = mc_.engine;
+    json_data["engine"] = mc.engine;
     json_data["messages"] = histories_;
-    json_data["model"] = mc_.name;
+    json_data["model"] = mc.name;
     //TODO: support non-stream
     json_data["stream"] = true;
-    json_data["stop"] = mc_.stop;
+    json_data["stop"] = mc.stop;
     auto data_str = json_data.dump();
     // std::cout << data_str << std::endl;
     cli.set_read_timeout(std::chrono::seconds(60));
9 changes: 4 additions & 5 deletions engine/commands/chat_cmd.h
@@ -7,13 +7,12 @@
 namespace commands {
 class ChatCmd {
  public:
-  ChatCmd(std::string host, int port, const config::ModelConfig& mc);
-  void Exec(std::string msg);
+  void Exec(const std::string& host, int port, const std::string& model_handle,
+            std::string msg);
+  void Exec(const std::string& host, int port, const config::ModelConfig& mc,
+            std::string msg);
 
  private:
-  std::string host_;
-  int port_;
-  const config::ModelConfig& mc_;
   std::vector<nlohmann::json> histories_;
 };
 }  // namespace commands
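
With the constructor and the host_/port_/mc_ members gone, ChatCmd keeps only the chat history as state; the connection details and the model are supplied per call, either as a handle or as a resolved config. A minimal call-site sketch (the host, port, and "tinyllama" handle are illustrative values, not taken from this PR):

    // Resolve the model config from a handle, then chat (values hypothetical).
    commands::ChatCmd chat;
    chat.Exec("127.0.0.1", 3928, "tinyllama", "Hello!");
    // Or, if a config::ModelConfig `mc` is already loaded:
    // chat.Exec("127.0.0.1", 3928, mc, "Hello!");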
56 changes: 38 additions & 18 deletions engine/commands/model_start_cmd.cc
@@ -7,43 +7,59 @@
 #include "trantor/utils/Logger.h"
 #include "utils/file_manager_utils.h"
 #include "utils/logging_utils.h"
+#include "utils/modellist_utils.h"
 
 namespace commands {
-ModelStartCmd::ModelStartCmd(std::string host, int port,
-                             const config::ModelConfig& mc)
-    : host_(std::move(host)), port_(port), mc_(mc) {}
-
-bool ModelStartCmd::Exec() {
+bool ModelStartCmd::Exec(const std::string& host, int port,
+                         const std::string& model_handle) {
+  modellist_utils::ModelListUtils modellist_handler;
+  config::YamlHandler yaml_handler;
+  try {
+    auto model_entry = modellist_handler.GetModelInfo(model_handle);
+    yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml);
+    auto mc = yaml_handler.GetModelConfig();
+    return Exec(host, port, mc);
+  } catch (const std::exception& e) {
+    CLI_LOG("Fail to start model information with ID '" + model_handle +
+            "': " + e.what());
+    return false;
+  }
+}
+
+bool ModelStartCmd::Exec(const std::string& host, int port,
+                         const config::ModelConfig& mc) {
   // Check if server is started
-  if (!commands::IsServerAlive(host_, port_)) {
+  if (!commands::IsServerAlive(host, port)) {
     CLI_LOG("Server is not started yet, please run `"
             << commands::GetCortexBinary() << " start` to start server!");
     return false;
   }
 
   // Only check for llamacpp for now
-  if ((mc_.engine.find("llamacpp") != std::string::npos) &&
-      commands::ModelStatusCmd().IsLoaded(host_, port_, mc_)) {
+  if ((mc.engine.find("llamacpp") != std::string::npos) &&
+      commands::ModelStatusCmd().IsLoaded(host, port, mc)) {
     CLI_LOG("Model has already been started!");
     return true;
   }
 
-  httplib::Client cli(host_ + ":" + std::to_string(port_));
+  httplib::Client cli(host + ":" + std::to_string(port));
 
   nlohmann::json json_data;
-  if (mc_.files.size() > 0) {
+  if (mc.files.size() > 0) {
     // TODO(sang) support multiple files
-    json_data["model_path"] = mc_.files[0];
+    json_data["model_path"] = mc.files[0];
   } else {
     LOG_WARN << "model_path is empty";
     return false;
   }
-  json_data["model"] = mc_.name;
-  json_data["system_prompt"] = mc_.system_template;
-  json_data["user_prompt"] = mc_.user_template;
-  json_data["ai_prompt"] = mc_.ai_template;
-  json_data["ctx_len"] = mc_.ctx_len;
-  json_data["stop"] = mc_.stop;
-  json_data["engine"] = mc_.engine;
+  json_data["model"] = mc.name;
+  json_data["system_prompt"] = mc.system_template;
+  json_data["user_prompt"] = mc.user_template;
+  json_data["ai_prompt"] = mc.ai_template;
+  json_data["ctx_len"] = mc.ctx_len;
+  json_data["stop"] = mc.stop;
+  json_data["engine"] = mc.engine;
 
   auto data_str = json_data.dump();
   cli.set_read_timeout(std::chrono::seconds(60));
@@ -52,13 +68,17 @@ bool ModelStartCmd::Exec() {
   if (res) {
     if (res->status == httplib::StatusCode::OK_200) {
       CLI_LOG("Model loaded!");
+      return true;
     } else {
       CTL_ERR("Model failed to load with status code: " << res->status);
+      return false;
     }
   } else {
    auto err = res.error();
     CTL_ERR("HTTP error: " << httplib::to_string(err));
+    return false;
   }
-  return true;
+  return false;
 }
 
 };  // namespace commands
9 changes: 2 additions & 7 deletions engine/commands/model_start_cmd.h
@@ -6,13 +6,8 @@ namespace commands {
 
 class ModelStartCmd {
  public:
-  explicit ModelStartCmd(std::string host, int port,
-                         const config::ModelConfig& mc);
-  bool Exec();
+  bool Exec(const std::string& host, int port, const std::string& model_handle);
 
- private:
-  std::string host_;
-  int port_;
-  const config::ModelConfig& mc_;
+  bool Exec(const std::string& host, int port, const config::ModelConfig& mc);
 };
 }  // namespace commands
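
Two things change here beyond the constructor removal. First, ModelStartCmd gains the same overload pair as ChatCmd: a handle-based Exec that resolves the model's YAML config via ModelListUtils and forwards to the config-based Exec. Second, the HTTP response handling now returns explicitly in every branch; the old code fell through to a single `return true;`, so a failed load or an HTTP error still reported success. A stripped-down sketch of the resolve-then-forward shape (the helper name is hypothetical, standing in for the ModelListUtils + YamlHandler steps in the real diff):

    bool Exec(const std::string& host, int port, const std::string& handle) {
      try {
        config::ModelConfig mc = ResolveConfig(handle);  // hypothetical helper
        return Exec(host, port, mc);    // forward to the config-based overload
      } catch (const std::exception& e) {
        return false;                   // lookup failure maps to "not started"
      }
    }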
17 changes: 17 additions & 0 deletions engine/commands/model_status_cmd.cc
@@ -3,8 +3,25 @@
 #include "httplib.h"
 #include "nlohmann/json.hpp"
 #include "utils/logging_utils.h"
+#include "utils/modellist_utils.h"
 
 namespace commands {
+bool ModelStatusCmd::IsLoaded(const std::string& host, int port,
+                              const std::string& model_handle) {
+  modellist_utils::ModelListUtils modellist_handler;
+  config::YamlHandler yaml_handler;
+  try {
+    auto model_entry = modellist_handler.GetModelInfo(model_handle);
+    yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml);
+    auto mc = yaml_handler.GetModelConfig();
+    return IsLoaded(host, port, mc);
+  } catch (const std::exception& e) {
+    CLI_LOG("Fail to get model status with ID '" + model_handle +
+            "': " + e.what());
+    return false;
+  }
+}
+
 bool ModelStatusCmd::IsLoaded(const std::string& host, int port,
                               const config::ModelConfig& mc) {
   httplib::Client cli(host + ":" + std::to_string(port));
2 changes: 2 additions & 0 deletions engine/commands/model_status_cmd.h
@@ -6,6 +6,8 @@ namespace commands {
 
 class ModelStatusCmd {
  public:
+  bool IsLoaded(const std::string& host, int port,
+                const std::string& model_handle);
   bool IsLoaded(const std::string& host, int port,
                 const config::ModelConfig& mc);
 };
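
ModelStatusCmd gets the matching handle-based overload, so callers can probe a model without loading its YAML config themselves. A usage sketch (the address and handle are illustrative assumptions):

    // Check by handle; the overload resolves the config and forwards.
    if (commands::ModelStatusCmd().IsLoaded("127.0.0.1", 3928, "tinyllama")) {
      // model is already resident in the engine
    }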
101 changes: 53 additions & 48 deletions engine/commands/run_cmd.cc
@@ -5,71 +5,76 @@
 #include "model_start_cmd.h"
 #include "model_status_cmd.h"
 #include "server_start_cmd.h"
-#include "utils/cortex_utils.h"
-#include "utils/file_manager_utils.h"
 
+#include "utils/modellist_utils.h"
 namespace commands {
 
 void RunCmd::Exec() {
+  std::optional<std::string> model_id = model_handle_;
+
+  modellist_utils::ModelListUtils modellist_handler;
+  config::YamlHandler yaml_handler;
   auto address = host_ + ":" + std::to_string(port_);
-  CmdInfo ci(model_id_);
-  std::string model_file =
-      ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch;
   // TODO should we clean all resource if something fails?
-  // Check if model existed. If not, download it
+  // Download model if it does not exist
   {
-    auto model_conf = model_service_.GetDownloadedModel(model_file + ".yaml");
-    if (!model_conf.has_value()) {
-      model_service_.DownloadModel(model_id_);
-    }
-  }
-
-  // Check if engine existed. If not, download it
-  {
-    auto required_engine = engine_service_.GetEngineInfo(ci.engine_name);
-    if (!required_engine.has_value()) {
-      throw std::runtime_error("Engine not found: " + ci.engine_name);
-    }
-    if (required_engine.value().status == EngineService::kIncompatible) {
-      throw std::runtime_error("Engine " + ci.engine_name + " is incompatible");
-    }
-    if (required_engine.value().status == EngineService::kNotInstalled) {
-      engine_service_.InstallEngine(ci.engine_name);
+    if (!modellist_handler.HasModel(model_handle_)) {
+      model_id = model_service_.DownloadModel(model_handle_);
+      if (!model_id.has_value()) {
+        CTL_ERR("Error: Could not get model_id from handle: " << model_handle_);
+        return;
+      } else {
+        CTL_INF("model_id: " << model_id.value());
+      }
     }
   }
 
-  // Start server if it is not running
-  {
-    if (!commands::IsServerAlive(host_, port_)) {
-      CLI_LOG("Starting server ...");
-      commands::ServerStartCmd ssc;
-      if (!ssc.Exec(host_, port_)) {
-        return;
+  try {
+    auto model_entry = modellist_handler.GetModelInfo(*model_id);
+    yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml);
+    auto mc = yaml_handler.GetModelConfig();
+
+    // Check if engine existed. If not, download it
+    {
+      auto required_engine = engine_service_.GetEngineInfo(mc.engine);
+      if (!required_engine.has_value()) {
+        throw std::runtime_error("Engine not found: " + mc.engine);
+      }
+      if (required_engine.value().status == EngineService::kIncompatible) {
+        throw std::runtime_error("Engine " + mc.engine + " is incompatible");
+      }
+      if (required_engine.value().status == EngineService::kNotInstalled) {
+        engine_service_.InstallEngine(mc.engine);
+      }
     }
-  }
 
-  config::YamlHandler yaml_handler;
-  yaml_handler.ModelConfigFromFile(
-      file_manager_utils::GetModelsContainerPath().string() + "/" + model_file +
-      ".yaml");
-  auto mc = yaml_handler.GetModelConfig();
+    // Start server if it is not running
+    {
+      if (!commands::IsServerAlive(host_, port_)) {
+        CLI_LOG("Starting server ...");
+        commands::ServerStartCmd ssc;
+        if (!ssc.Exec(host_, port_)) {
+          return;
+        }
+      }
+    }
 
-  // Always start model if not llamacpp
-  // If it is llamacpp, then check model status first
-  {
-    if ((mc.engine.find("llamacpp") == std::string::npos) ||
-        !commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) {
-      ModelStartCmd msc(host_, port_, mc);
-      if (!msc.Exec()) {
-        return;
+    // Always start model if not llamacpp
+    // If it is llamacpp, then check model status first
+    {
+      if ((mc.engine.find("llamacpp") == std::string::npos) ||
+          !commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) {
+        if (!ModelStartCmd().Exec(host_, port_, mc)) {
+          return;
+        }
       }
     }
-  }
 
-  // Chat
-  {
-    ChatCmd cc(host_, port_, mc);
-    cc.Exec("");
+    // Chat
+    ChatCmd().Exec(host_, port_, mc, "");
+  } catch (const std::exception& e) {
+    CLI_LOG("Fail to run model with ID '" + model_handle_ + "': " + e.what());
   }
 }
 };  // namespace commands
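
With these pieces in place, RunCmd::Exec drives the whole flow from a single model handle: download the model if missing, resolve its config from the model list, install the engine if needed, start the server, start the model, then drop into chat. A condensed, illustrative driver (host and port values are assumptions, not taken from this PR):

    // End-to-end: everything below is derived from the "tinyllama" handle.
    commands::RunCmd rc("127.0.0.1", 3928, "tinyllama");
    rc.Exec();  // HasModel/DownloadModel -> GetModelInfo -> InstallEngine
                // -> ServerStartCmd -> ModelStartCmd().Exec -> ChatCmd().Exec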
6 changes: 3 additions & 3 deletions engine/commands/run_cmd.h
@@ -6,18 +6,18 @@
 namespace commands {
 class RunCmd {
  public:
-  explicit RunCmd(std::string host, int port, std::string model_id)
+  explicit RunCmd(std::string host, int port, std::string model_handle)
       : host_{std::move(host)},
         port_{port},
-        model_id_{std::move(model_id)},
+        model_handle_{std::move(model_handle)},
         model_service_{ModelService()} {};
 
   void Exec();
 
  private:
   std::string host_;
   int port_;
-  std::string model_id_;
+  std::string model_handle_;
 
   ModelService model_service_;
   EngineService engine_service_;