Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit fb55754

Browse files
nguyenhoangthuan99, namchuai, vansangpfiev
authored
Model update command/api (#1309)
* Model update command/api * fix: resume download failed Signed-off-by: James <[email protected]> * fix: align github syntax for cuda (#1316) * fix: require sudo for cortex update (#1318) * fix: require sudo for cortex update * fix: comment * refactor code * Format code * Add clean up when finish test * remove model.list after finish test * Fix windows CI build --------- Signed-off-by: James <[email protected]> Co-authored-by: James <[email protected]> Co-authored-by: vansangpfiev <[email protected]>
1 parent 7c31b71 commit fb55754

File tree

11 files changed

+460
-25
lines changed

11 files changed

+460
-25
lines changed

engine/commands/model_upd_cmd.cc

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include "model_upd_cmd.h"
2+
3+
#include "utils/logging_utils.h"
4+
5+
namespace commands {
6+
7+
ModelUpdCmd::ModelUpdCmd(std::string model_handle)
8+
: model_handle_(std::move(model_handle)) {}
9+
10+
void ModelUpdCmd::Exec(
11+
const std::unordered_map<std::string, std::string>& options) {
12+
try {
13+
auto model_entry = model_list_utils_.GetModelInfo(model_handle_);
14+
yaml_handler_.ModelConfigFromFile(model_entry.path_to_model_yaml);
15+
model_config_ = yaml_handler_.GetModelConfig();
16+
17+
for (const auto& [key, value] : options) {
18+
if (!value.empty()) {
19+
UpdateConfig(key, value);
20+
}
21+
}
22+
23+
yaml_handler_.UpdateModelConfig(model_config_);
24+
yaml_handler_.WriteYamlFile(model_entry.path_to_model_yaml);
25+
CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!");
26+
} catch (const std::exception& e) {
27+
CLI_LOG("Failed to update model with model ID '" + model_handle_ +
28+
"': " + e.what());
29+
}
30+
}
31+
32+
void ModelUpdCmd::UpdateConfig(const std::string& key,
33+
const std::string& value) {
34+
static const std::unordered_map<
35+
std::string,
36+
std::function<void(ModelUpdCmd*, const std::string&, const std::string&)>>
37+
updaters = {
38+
{"name",
39+
[](ModelUpdCmd* self, const std::string&, const std::string& v) {
40+
self->model_config_.name = v;
41+
}},
42+
{"model",
43+
[](ModelUpdCmd* self, const std::string&, const std::string& v) {
44+
self->model_config_.model = v;
45+
}},
46+
{"version",
47+
[](ModelUpdCmd* self, const std::string&, const std::string& v) {
48+
self->model_config_.version = v;
49+
}},
50+
{"stop", &ModelUpdCmd::UpdateVectorField},
51+
{"top_p",
52+
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
53+
self->UpdateNumericField(
54+
k, v, [self](float f) { self->model_config_.top_p = f; });
55+
}},
56+
{"temperature",
57+
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
58+
self->UpdateNumericField(k, v, [self](float f) {
59+
self->model_config_.temperature = f;
60+
});
61+
}},
62+
{"frequency_penalty",
63+
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
64+
self->UpdateNumericField(k, v, [self](float f) {
65+
self->model_config_.frequency_penalty = f;
66+
});
67+
}},
68+
{"presence_penalty",
69+
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
70+
self->UpdateNumericField(k, v, [self](float f) {
71+
self->model_config_.presence_penalty = f;
72+
});
73+
}},
74+
{"max_tokens",
75+
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
76+
self->UpdateNumericField(k, v, [self](float f) {
77+
self->model_config_.max_tokens = static_cast<int>(f);
78+
});
79+
}},
80+
{"stream",
81+
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
82+
self->UpdateBooleanField(
83+
k, v, [self](bool b) { self->model_config_.stream = b; });
84+
}},
85+
// Add more fields here...
86+
};
87+
88+
if (auto it = updaters.find(key); it != updaters.end()) {
89+
it->second(this, key, value);
90+
LogUpdate(key, value);
91+
}
92+
}
93+
94+
/// Splits `value` on ',' and stores the pieces in `model_config_.stop`,
/// replacing whatever was there before.
///
/// Empty pieces between two commas are kept; a trailing comma does not add a
/// final empty piece, and an empty `value` yields an empty list.
///
/// @param key    Unused.
/// @param value  Comma-separated list of tokens.
void ModelUpdCmd::UpdateVectorField(const std::string& key,
                                    const std::string& value) {
  std::vector<std::string> pieces;

  std::size_t begin = 0;
  while (begin < value.size()) {
    const std::size_t comma = value.find(',', begin);
    if (comma == std::string::npos) {
      // Last piece: everything up to the end of the input.
      pieces.push_back(value.substr(begin));
      break;
    }
    pieces.push_back(value.substr(begin, comma - begin));
    begin = comma + 1;
  }

  model_config_.stop = pieces;
}
104+
105+
void ModelUpdCmd::UpdateNumericField(const std::string& key,
106+
const std::string& value,
107+
std::function<void(float)> setter) {
108+
try {
109+
float numericValue = std::stof(value);
110+
setter(numericValue);
111+
} catch (const std::exception& e) {
112+
CLI_LOG("Failed to parse numeric value for " << key << ": " << e.what());
113+
}
114+
}
115+
116+
void ModelUpdCmd::UpdateBooleanField(const std::string& key,
117+
const std::string& value,
118+
std::function<void(bool)> setter) {
119+
bool boolValue = (value == "true" || value == "1");
120+
setter(boolValue);
121+
}
122+
123+
/// Writes a one-line "Updated <key> to: <value>" message through CLI_LOG.
///
/// @param key    Name of the option.
/// @param value  Textual value as supplied by the caller.
void ModelUpdCmd::LogUpdate(const std::string& key,
                            const std::string& value) {
  CLI_LOG("Updated " << key << " to: " << value);
}
126+
127+
} // namespace commands

engine/commands/model_upd_cmd.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
#include <iostream>
3+
#include <optional>
4+
#include <string>
5+
#include <unordered_map>
6+
#include <vector>
7+
#include "config/model_config.h"
8+
#include "utils/modellist_utils.h"
9+
#include "config/yaml_config.h"
10+
namespace commands {
11+
class ModelUpdCmd {
12+
public:
13+
ModelUpdCmd(std::string model_handle);
14+
void Exec(const std::unordered_map<std::string, std::string>& options);
15+
16+
private:
17+
std::string model_handle_;
18+
config::ModelConfig model_config_;
19+
config::YamlHandler yaml_handler_;
20+
modellist_utils::ModelListUtils model_list_utils_;
21+
22+
void UpdateConfig(const std::string& key, const std::string& value);
23+
void UpdateVectorField(const std::string& key, const std::string& value);
24+
void UpdateNumericField(const std::string& key, const std::string& value,
25+
std::function<void(float)> setter);
26+
void UpdateBooleanField(const std::string& key, const std::string& value,
27+
std::function<void(bool)> setter);
28+
void LogUpdate(const std::string& key, const std::string& value);
29+
};
30+
} // namespace commands

engine/config/model_config.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,115 @@ struct ModelConfig {
5858
int n_probs = 0;
5959
int min_keep = 0;
6060
std::string grammar;
61+
62+
/// Overwrites fields of this config with the values present in `json`.
///
/// Only keys that exist in `json` are applied; every other field keeps its
/// current value. "id" and "model" are deliberately never read.
/// NOTE(review): "grammar" is not read here either - confirm that is intended.
void FromJson(const Json::Value& json) {
  // Do not allow updating the ID and model fields because they are the unique
  // identifier.
  // if (json.isMember("id"))
  //   id = json["id"].asString();
  if (json.isMember("name")) {
    name = json["name"].asString();
  }
  // if (json.isMember("model"))
  //   model = json["model"].asString();
  if (json.isMember("version")) {
    version = json["version"].asString();
  }

  // "stop" replaces the existing list, but only when it is a JSON array.
  if (json.isMember("stop")) {
    const Json::Value& stop_values = json["stop"];
    if (stop_values.isArray()) {
      stop.clear();
      for (const auto& item : stop_values) {
        stop.push_back(item.asString());
      }
    }
  }

  // Sampling / generation parameters.
  if (json.isMember("stream")) {
    stream = json["stream"].asBool();
  }
  if (json.isMember("top_p")) {
    top_p = json["top_p"].asFloat();
  }
  if (json.isMember("temperature")) {
    temperature = json["temperature"].asFloat();
  }
  if (json.isMember("frequency_penalty")) {
    frequency_penalty = json["frequency_penalty"].asFloat();
  }
  if (json.isMember("presence_penalty")) {
    presence_penalty = json["presence_penalty"].asFloat();
  }
  if (json.isMember("max_tokens")) {
    max_tokens = json["max_tokens"].asInt();
  }
  if (json.isMember("seed")) {
    seed = json["seed"].asInt();
  }
  if (json.isMember("dynatemp_range")) {
    dynatemp_range = json["dynatemp_range"].asFloat();
  }
  if (json.isMember("dynatemp_exponent")) {
    dynatemp_exponent = json["dynatemp_exponent"].asFloat();
  }
  if (json.isMember("top_k")) {
    top_k = json["top_k"].asInt();
  }
  if (json.isMember("min_p")) {
    min_p = json["min_p"].asFloat();
  }
  if (json.isMember("tfs_z")) {
    tfs_z = json["tfs_z"].asFloat();
  }
  if (json.isMember("typ_p")) {
    typ_p = json["typ_p"].asFloat();
  }
  if (json.isMember("repeat_last_n")) {
    repeat_last_n = json["repeat_last_n"].asInt();
  }
  if (json.isMember("repeat_penalty")) {
    repeat_penalty = json["repeat_penalty"].asFloat();
  }
  if (json.isMember("mirostat")) {
    mirostat = json["mirostat"].asBool();
  }
  if (json.isMember("mirostat_tau")) {
    mirostat_tau = json["mirostat_tau"].asFloat();
  }
  if (json.isMember("mirostat_eta")) {
    mirostat_eta = json["mirostat_eta"].asFloat();
  }
  if (json.isMember("penalize_nl")) {
    penalize_nl = json["penalize_nl"].asBool();
  }
  if (json.isMember("ignore_eos")) {
    ignore_eos = json["ignore_eos"].asBool();
  }
  if (json.isMember("n_probs")) {
    n_probs = json["n_probs"].asInt();
  }
  if (json.isMember("min_keep")) {
    min_keep = json["min_keep"].asInt();
  }

  // Loading / engine parameters.
  if (json.isMember("ngl")) {
    ngl = json["ngl"].asInt();
  }
  if (json.isMember("ctx_len")) {
    ctx_len = json["ctx_len"].asInt();
  }
  if (json.isMember("engine")) {
    engine = json["engine"].asString();
  }
  if (json.isMember("prompt_template")) {
    prompt_template = json["prompt_template"].asString();
  }
  if (json.isMember("system_template")) {
    system_template = json["system_template"].asString();
  }
  if (json.isMember("user_template")) {
    user_template = json["user_template"].asString();
  }
  if (json.isMember("ai_template")) {
    ai_template = json["ai_template"].asString();
  }
  if (json.isMember("os")) {
    os = json["os"].asString();
  }
  if (json.isMember("gpu_arch")) {
    gpu_arch = json["gpu_arch"].asString();
  }
  if (json.isMember("quantization_method")) {
    quantization_method = json["quantization_method"].asString();
  }
  if (json.isMember("precision")) {
    precision = json["precision"].asString();
  }

  // "files" replaces the existing list, but only when it is a JSON array.
  if (json.isMember("files")) {
    const Json::Value& file_values = json["files"];
    if (file_values.isArray()) {
      files.clear();
      for (const auto& item : file_values) {
        files.push_back(item.asString());
      }
    }
  }

  // Metadata.
  if (json.isMember("created")) {
    created = json["created"].asUInt64();
  }
  if (json.isMember("object")) {
    object = json["object"].asString();
  }
  if (json.isMember("owned_by")) {
    owned_by = json["owned_by"].asString();
  }
  if (json.isMember("text_model")) {
    text_model = json["text_model"].asBool();
  }

  // These two keys are honoured only for the TensorRT-LLM engine. `engine` may
  // have just been overwritten above, so this check must stay after it.
  if (engine == "cortex.tensorrt-llm") {
    if (json.isMember("trtllm_version")) {
      trtllm_version = json["trtllm_version"].asString();
    }
    if (json.isMember("tp")) {
      tp = json["tp"].asInt();
    }
  }
}
62170
Json::Value ToJson() const {
63171
Json::Value obj;
64172

engine/controllers/command_line_parser.cc

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "commands/model_pull_cmd.h"
1515
#include "commands/model_start_cmd.h"
1616
#include "commands/model_stop_cmd.h"
17+
#include "commands/model_upd_cmd.h"
1718
#include "commands/run_cmd.h"
1819
#include "commands/server_start_cmd.h"
1920
#include "commands/server_stop_cmd.h"
@@ -256,10 +257,8 @@ void CommandLineParser::SetupModelCommands() {
256257
commands::ModelAliasCmd mdc;
257258
mdc.Exec(cml_data_.model_id, cml_data_.model_alias);
258259
});
259-
260-
auto model_update_cmd =
261-
models_cmd->add_subcommand("update", "Update configuration of a model");
262-
model_update_cmd->group(kSubcommands);
260+
// Model update parameters comment
261+
ModelUpdate(models_cmd);
263262

264263
std::string model_path;
265264
auto model_import_cmd = models_cmd->add_subcommand(
@@ -373,6 +372,12 @@ void CommandLineParser::SetupSystemCommands() {
373372
update_cmd->group(kSystemGroup);
374373
update_cmd->add_option("-v", cml_data_.cortex_version, "");
375374
update_cmd->callback([this] {
375+
#if !defined(_WIN32)
376+
if (getuid()) {
377+
CLI_LOG("Error: Not root user. Please run with sudo.");
378+
return;
379+
}
380+
#endif
376381
commands::CortexUpdCmd cuc;
377382
cuc.Exec(cml_data_.cortex_version);
378383
cml_data_.check_upd = false;
@@ -442,3 +447,69 @@ void CommandLineParser::EngineGet(CLI::App* parent) {
442447
[engine_name] { commands::EngineGetCmd().Exec(engine_name); });
443448
}
444449
}
450+
451+
/// Registers the "update" subcommand on `parent`.
///
/// The subcommand has a required `--model_id` option plus one `--<field>`
/// option per entry in the list below. Each field option is bound to the
/// matching slot of `cml_data_.model_update_options`; when the subcommand runs,
/// that map is handed to commands::ModelUpdCmd::Exec.
void CommandLineParser::ModelUpdate(CLI::App* parent) {
  auto* update_cmd =
      parent->add_subcommand("update", "Update configuration of a model");
  update_cmd->group(kSubcommands);

  auto* model_id_option =
      update_cmd->add_option("--model_id", cml_data_.model_id, "Model ID");
  model_id_option->required();

  // Field names exposed as "--<name>" options; the name doubles as the
  // option's description.
  static const char* const kUpdatableFields[] = {
      "name",
      "model",
      "version",
      "stop",
      "top_p",
      "temperature",
      "frequency_penalty",
      "presence_penalty",
      "max_tokens",
      "stream",
      "ngl",
      "ctx_len",
      "engine",
      "prompt_template",
      "system_template",
      "user_template",
      "ai_template",
      "os",
      "gpu_arch",
      "quantization_method",
      "precision",
      "tp",
      "trtllm_version",
      "text_model",
      "files",
      "created",
      "object",
      "owned_by",
      "seed",
      "dynatemp_range",
      "dynatemp_exponent",
      "top_k",
      "min_p",
      "tfs_z",
      "typ_p",
      "repeat_last_n",
      "repeat_penalty",
      "mirostat",
      "mirostat_tau",
      "mirostat_eta",
      "penalize_nl",
      "ignore_eos",
      "n_probs",
      "min_keep",
      "grammar",
  };

  for (const char* raw_name : kUpdatableFields) {
    const std::string field_name{raw_name};
    // operator[] creates the map slot so the option has storage to bind to.
    std::string& slot = cml_data_.model_update_options[field_name];
    update_cmd->add_option("--" + field_name, slot, field_name);
  }

  update_cmd->callback([this] {
    commands::ModelUpdCmd command(cml_data_.model_id);
    command.Exec(cml_data_.model_update_options);
  });
}

0 commit comments

Comments
 (0)