Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit c693e55

Browse files
authored
feat: run command (#1045)
1 parent 05b4b2c commit c693e55

18 files changed

+282
-67
lines changed

engine/commands/chat_cmd.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ void ChatCmd::Exec(std::string msg) {
5757
}
5858
}
5959
// Some instruction for user here
60-
std::cout << "Inorder to exit, type exit()" << std::endl;
60+
std::cout << "Inorder to exit, type `exit()`" << std::endl;
6161
// Model is loaded, start to chat
6262
{
6363
while (true) {

engine/commands/cmd_info.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
#include "cmd_info.h"
#include <vector>
#include "trantor/utils/Logger.h"

namespace commands {
namespace {
// Separator between the model name and the branch in a model id,
// e.g. "tinyllama:gguf".
constexpr const char* kDelimiter = ":";

// Splits `s` on `delimiter` without modifying the input.
// Always returns at least one element (the whole string when the
// delimiter does not occur).
std::vector<std::string> Split(const std::string& s,
                               const std::string& delimiter) {
  std::vector<std::string> tokens;
  size_t start = 0;
  size_t pos;
  while ((pos = s.find(delimiter, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, pos - start));
    start = pos + delimiter.length();
  }
  tokens.push_back(s.substr(start));
  return tokens;
}
}  // namespace

CmdInfo::CmdInfo(std::string model_id) {
  Parse(std::move(model_id));
}

// Parses `model_id` ("<model>" or "<model>:<branch>") into model_name,
// branch and the engine that serves that branch. On a malformed id the
// fields are left empty and an error is logged.
void CmdInfo::Parse(std::string model_id) {
  if (model_id.find(kDelimiter) == std::string::npos) {
    // No branch given: default to the llamacpp engine on the main branch.
    engine_name = "cortex.llamacpp";
    model_name = std::move(model_id);
    branch = "main";
  } else {
    auto res = Split(model_id, kDelimiter);
    if (res.size() != 2) {
      LOG_ERROR << "Invalid model_id: " << model_id;
      return;
    }
    model_name = std::move(res[0]);
    branch = std::move(res[1]);
    // The branch name encodes the model format, which selects the engine.
    if (branch.find("onnx") != std::string::npos) {
      engine_name = "cortex.onnx";
    } else if (branch.find("tensorrt") != std::string::npos) {
      engine_name = "cortex.tensorrt-llm";
    } else if (branch.find("gguf") != std::string::npos) {
      engine_name = "cortex.llamacpp";
    } else {
      LOG_ERROR << "Not a valid branch model_name " << branch;
    }
  }
}

}  // namespace commands

engine/commands/cmd_info.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
#pragma once
#include <string>
namespace commands {
// Parsed form of a model identifier ("<model>" or "<model>:<branch>").
struct CmdInfo {
  explicit CmdInfo(std::string model_id);

  // Engine serving the model, derived from the branch name
  // (e.g. "cortex.llamacpp"). Empty when the id could not be parsed.
  std::string engine_name;
  // Model name part of the id. Empty when the id could not be parsed.
  std::string model_name;
  // Branch part of the id; "main" when the id carries no branch.
  std::string branch;

 private:
  // Fills the fields above from `model_id`; logs and leaves them empty
  // on a malformed id.
  void Parse(std::string model_id);
};
}  // namespace commands

engine/commands/engine_init_cmd.cc

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ namespace commands {
1414
EngineInitCmd::EngineInitCmd(std::string engineName, std::string version)
1515
: engineName_(std::move(engineName)), version_(std::move(version)) {}
1616

17-
void EngineInitCmd::Exec() const {
17+
bool EngineInitCmd::Exec() const {
1818
if (engineName_.empty()) {
1919
LOG_ERROR << "Engine name is required";
20-
return;
20+
return false;
2121
}
2222

2323
// Check if the architecture and OS are supported
@@ -26,15 +26,15 @@ void EngineInitCmd::Exec() const {
2626
system_info.os == system_info_utils::kUnsupported) {
2727
LOG_ERROR << "Unsupported OS or architecture: " << system_info.os << ", "
2828
<< system_info.arch;
29-
return;
29+
return false;
3030
}
3131
LOG_INFO << "OS: " << system_info.os << ", Arch: " << system_info.arch;
3232

3333
// check if engine is supported
3434
if (std::find(supportedEngines_.begin(), supportedEngines_.end(),
3535
engineName_) == supportedEngines_.end()) {
3636
LOG_ERROR << "Engine not supported";
37-
return;
37+
return false;
3838
}
3939

4040
constexpr auto gitHubHost = "https://api.github.com";
@@ -78,7 +78,7 @@ void EngineInitCmd::Exec() const {
7878
LOG_INFO << "Matched variant: " << matched_variant;
7979
if (matched_variant.empty()) {
8080
LOG_ERROR << "No variant found for " << os_arch;
81-
return;
81+
return false;
8282
}
8383

8484
for (auto& asset : assets) {
@@ -103,36 +103,45 @@ void EngineInitCmd::Exec() const {
103103
.path = path,
104104
}}};
105105

106-
DownloadService().AddDownloadTask(
107-
downloadTask, [](const std::string& absolute_path) {
108-
// try to unzip the downloaded file
109-
std::filesystem::path downloadedEnginePath{absolute_path};
110-
LOG_INFO << "Downloaded engine path: "
111-
<< downloadedEnginePath.string();
112-
113-
archive_utils::ExtractArchive(
114-
downloadedEnginePath.string(),
115-
downloadedEnginePath.parent_path()
116-
.parent_path()
117-
.string());
118-
119-
// remove the downloaded file
120-
std::filesystem::remove(absolute_path);
121-
LOG_INFO << "Finished!";
122-
});
123-
124-
return;
106+
DownloadService().AddDownloadTask(downloadTask, [](const std::string&
107+
absolute_path,
108+
bool unused) {
109+
// try to unzip the downloaded file
110+
std::filesystem::path downloadedEnginePath{absolute_path};
111+
LOG_INFO << "Downloaded engine path: "
112+
<< downloadedEnginePath.string();
113+
114+
archive_utils::ExtractArchive(
115+
downloadedEnginePath.string(),
116+
downloadedEnginePath.parent_path().parent_path().string());
117+
118+
// remove the downloaded file
119+
// TODO(any) Could not delete file on Windows because it is currently hold by httplib(?)
120+
// Not sure about other platforms
121+
try {
122+
std::filesystem::remove(absolute_path);
123+
} catch (const std::exception& e) {
124+
LOG_ERROR << "Could not delete file: " << e.what();
125+
}
126+
LOG_INFO << "Finished!";
127+
});
128+
129+
return true;
125130
}
126131
}
127132
} catch (const json::parse_error& e) {
128133
std::cerr << "JSON parse error: " << e.what() << std::endl;
134+
return false;
129135
}
130136
} else {
131137
LOG_ERROR << "HTTP error: " << res->status;
138+
return false;
132139
}
133140
} else {
134141
auto err = res.error();
135142
LOG_ERROR << "HTTP error: " << httplib::to_string(err);
143+
return false;
136144
}
145+
return true;
137146
}
138147
}; // namespace commands

engine/commands/engine_init_cmd.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class EngineInitCmd {
99
public:
1010
EngineInitCmd(std::string engineName, std::string version);
1111

12-
void Exec() const;
12+
bool Exec() const;
1313

1414
private:
1515
std::string engineName_;

engine/commands/model_pull_cmd.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@
66
#include "utils/model_callback_utils.h"
77

88
namespace commands {
ModelPullCmd::ModelPullCmd(std::string model_handle, std::string branch)
    : model_handle_(std::move(model_handle)), branch_(std::move(branch)) {}

// Resolves the download task for model_handle_/branch_ and runs it.
// Returns true when the model was found and the download ran to
// completion, false when no download task could be resolved.
bool ModelPullCmd::Exec() {
  auto downloadTask = cortexso_parser::getDownloadTask(model_handle_, branch_);
  if (downloadTask.has_value()) {
    DownloadService downloadService;
    // DownloadModelCb post-processes each downloaded artifact.
    downloadService.AddDownloadTask(downloadTask.value(),
                                    model_callback_utils::DownloadModelCb);
    std::cout << "Download finished" << std::endl;
    return true;
  } else {
    std::cout << "Model not found" << std::endl;
    return false;
  }
}
2325

engine/commands/model_pull_cmd.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ namespace commands {
66

77
// Command that downloads a model variant from the model hub.
class ModelPullCmd {
 public:
  // `branch` selects the model variant to pull (e.g. "main", "gguf").
  explicit ModelPullCmd(std::string model_handle, std::string branch);
  // Returns true when the model was found and downloaded, false otherwise.
  bool Exec();

 private:
  std::string model_handle_;
  std::string branch_;
};
1516
} // namespace commands
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
#include "start_model_cmd.h"
1+
#include "model_start_cmd.h"
22
#include "httplib.h"
33
#include "nlohmann/json.hpp"
44
#include "trantor/utils/Logger.h"
55

66
namespace commands {
7-
StartModelCmd::StartModelCmd(std::string host, int port,
7+
ModelStartCmd::ModelStartCmd(std::string host, int port,
88
const config::ModelConfig& mc)
99
: host_(std::move(host)), port_(port), mc_(mc) {}
1010

11-
void StartModelCmd::Exec() {
11+
bool ModelStartCmd::Exec() {
1212
httplib::Client cli(host_ + ":" + std::to_string(port_));
1313
nlohmann::json json_data;
1414
if (mc_.files.size() > 0) {
1515
// TODO(sang) support multiple files
1616
json_data["model_path"] = mc_.files[0];
1717
} else {
1818
LOG_WARN << "model_path is empty";
19-
return;
19+
return false;
2020
}
2121
json_data["model"] = mc_.name;
2222
json_data["system_prompt"] = mc_.system_template;
@@ -27,7 +27,7 @@ void StartModelCmd::Exec() {
2727
json_data["engine"] = mc_.engine;
2828

2929
auto data_str = json_data.dump();
30-
30+
cli.set_read_timeout(std::chrono::seconds(60));
3131
auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
3232
data_str.data(), data_str.size(), "application/json");
3333
if (res) {
@@ -37,7 +37,9 @@ void StartModelCmd::Exec() {
3737
} else {
3838
auto err = res.error();
3939
LOG_WARN << "HTTP error: " << httplib::to_string(err);
40+
return false;
4041
}
42+
return true;
4143
}
4244

4345
}; // namespace commands

engine/commands/start_model_cmd.h renamed to engine/commands/model_start_cmd.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
namespace commands {
77

8-
class StartModelCmd{
8+
class ModelStartCmd{
99
public:
10-
StartModelCmd(std::string host, int port, const config::ModelConfig& mc);
11-
void Exec();
10+
explicit ModelStartCmd(std::string host, int port, const config::ModelConfig& mc);
11+
bool Exec();
1212

1313
private:
1414
std::string host_;

engine/commands/run_cmd.cc

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
#include "run_cmd.h"
#include <filesystem>
#include "chat_cmd.h"
#include "cmd_info.h"
#include "config/yaml_config.h"
#include "engine_init_cmd.h"
#include "httplib.h"
#include "model_pull_cmd.h"
#include "model_start_cmd.h"
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"

namespace commands {

RunCmd::RunCmd(std::string host, int port, std::string model_id)
    : host_(std::move(host)), port_(port), model_id_(std::move(model_id)) {}

// Pulls the model and its engine when missing, loads the model into the
// inference server, then drops the user into an interactive chat.
// Returns early as soon as any step fails (the failing sub-command has
// already reported the error).
void RunCmd::Exec() {
  CmdInfo ci(model_id_);
  // Model configs for non-main branches are stored as "<name>-<branch>".
  std::string model_file =
      ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch;

  // TODO(any): should we clean up all resources if something fails?
  // Download the model if it is not present yet.
  if (!IsModelExisted(model_file)) {
    ModelPullCmd model_pull_cmd(ci.model_name, ci.branch);
    if (!model_pull_cmd.Exec()) {
      return;
    }
  }

  // Download the engine if it is not present yet.
  if (!IsEngineExisted(ci.engine_name)) {
    EngineInitCmd eic(ci.engine_name, "");
    if (!eic.Exec()) {
      return;
    }
  }

  // Load the model into the inference server.
  config::YamlHandler yaml_handler;
  yaml_handler.ModelConfigFromFile(cortex_utils::GetCurrentPath() + "/models/" +
                                   model_file + ".yaml");
  ModelStartCmd msc(host_, port_, yaml_handler.GetModelConfig());
  if (!msc.Exec()) {
    return;
  }

  // Start the interactive chat loop.
  ChatCmd cc(host_, port_, yaml_handler.GetModelConfig());
  cc.Exec("");
}

// Returns true when a "<model_id>.yaml" config file exists in the models
// folder. The model id is encoded in the config file name, so comparing
// the file stem is sufficient; the yaml contents are not parsed here.
bool RunCmd::IsModelExisted(const std::string& model_id) {
  auto models_path =
      cortex_utils::GetCurrentPath() + "/" + cortex_utils::models_folder;
  if (!std::filesystem::exists(models_path) ||
      !std::filesystem::is_directory(models_path)) {
    return false;
  }
  for (const auto& entry : std::filesystem::directory_iterator(models_path)) {
    if (entry.is_regular_file() && entry.path().extension() == ".yaml" &&
        entry.path().stem().string() == model_id) {
      return true;
    }
  }
  return false;
}

// Returns true when the engine folder "engines/<e>" exists next to the
// executable. A positive check of the child path implies the parent
// "engines" directory exists, so a single exists() call is enough.
bool RunCmd::IsEngineExisted(const std::string& e) {
  return std::filesystem::exists(cortex_utils::GetCurrentPath() +
                                 "/engines/" + e);
}

}  // namespace commands

0 commit comments

Comments
 (0)