 #include <thread>
 #include <trantor/utils/Logger.h>
 #include <vector>
+#include <chrono>

 using json = nlohmann::json;
 using namespace tensorrtllm;

+namespace {
+constexpr const int k200OK = 200;
+constexpr const int k400BadRequest = 400;
+constexpr const int k409Conflict = 409;
+constexpr const int k500InternalServerError = 500;
+
+// https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+// stopWordsList
+// 'im', '_', 'end', '</s>', '<|im_end|>'
+const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+const std::string kOhSystemPrompt = "<|im_start|>system\n";
+const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000}, {"<|im_start|>", 32001}};
+
+// '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
+const std::vector<int32_t> kMistral_V0_3_StopWords
+    = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
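+// Note: both stop-word lists follow TensorRT-LLM's flattened stopWordsList layout
+// (see the runtime docs linked above): the first half is the concatenated stop-word
+// token ids, the second half holds each stop word's end offset, and -1 pads unused
+// slots; GetTensorChatMLStopWordList() below reshapes them to {1, 2, size/2}. The
+// MistralTemplate enum below names the Mistral v0.3 special-token ids (BOS = 1,
+// EOS = 2, [INST] = 3, [/INST] = 4) used to assemble the prompt token-by-token in
+// HandleChatCompletion().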
+
+enum class MistralTemplate : int32_t {
+  kBos = 1,
+  kEos = 2,
+  kBeginInst = 3,
+  kEndInst = 4
+};

-constexpr const int k200OK = 200;
-constexpr const int k400BadRequest = 400;
-constexpr const int k409Conflict = 409;
-constexpr const int k500InternalServerError = 500;
-
+// TODO(sang): This is fragile, just a temporary solution. Maybe we can use a config file or the model architecture instead.
+bool IsOpenhermes(const std::string& s) {
+  if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
+    return false;
+  }
+  return true;
+}
+}
 TensorrtllmEngine::~TensorrtllmEngine() {}

 void RemoveId(std::vector<int>& vec, int id) {
   vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
 }

-bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
-  if (infer_state->IsComplete()) {
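+// Tracks a running partial match against the model-specific stop sequence
+// ("<|im_end|>" for OpenHermes, "[INST]"/"[/INST]" for Mistral); returning true
+// tells the caller to hold back that fragment instead of streaming it.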
+bool HandleMatch(std::string const& rew_text,
+                 std::shared_ptr<InferenceState> infer_state,
+                 std::function<void(Json::Value&&, Json::Value&&)> cb,
+                 bool is_openhermes) {
+  if (infer_state->IsComplete(is_openhermes)) {
     return false;
   }
   if (infer_state->stop_word_match_len == 0) {
-    if (rew_text.find('<') != std::string::npos) {  // Found "<" anywhere in the text
+    if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
+        (!is_openhermes && rew_text.find('[') != std::string::npos)) {
       infer_state->stop_word_match_len++;  // Move to next state
-      infer_state->prev_text = rew_text;
       return true;
     }
-  }
-  else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
+  } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
     infer_state->stop_word_match_len++;  // Move to next state
-    infer_state->prev_text = rew_text;
     return true;
-  }
-  else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
+  } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
     infer_state->stop_word_match_len = 1;  // Restart from first match if sequence breaks but matches start
-    infer_state->prev_text = rew_text;
     return true;
-  }
-  else {
+  } else {
     infer_state->Reset();
     return false;  // Reset to start if sequence breaks
   }
@@ -66,19 +93,21 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
 }

 GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
-  std::vector<int32_t> stop_words_tokens
-      = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};  // Extend with -1 for increased length
-  return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
+  if (is_openhermes_) {
+    return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size() / 2)}), MemoryType::kGPU);
+  } else {
+    return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size() / 2)}), MemoryType::kGPU);
+  }
 }

 GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
   int input_len = input_ids_host.size();
-  std::vector<int32_t> input_lengths_host(batchSize, input_len);
+  std::vector<int32_t> input_lengths_host(batch_size_, input_len);
   GenerationInput::TensorPtr input_lengths
-      = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batchSize}), MemoryType::kGPU);
+      = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batch_size_}), MemoryType::kGPU);
   GenerationInput::TensorPtr input_ids = gpt_session->getBufferManager().copyFrom(
-      input_ids_host, ITensor::makeShape({batchSize, input_len}), MemoryType::kGPU);
-  GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config->usePackedInput()};
+      input_ids_host, ITensor::makeShape({batch_size_, input_len}), MemoryType::kGPU);
+  GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config_->usePackedInput()};
   generation_input.stopWordsList = GetTensorChatMLStopWordList();

   LOG_INFO << "Create generation input successfully";
@@ -101,27 +130,34 @@ void InferenceThread(
     TensorrtllmEngine* self,
     SamplingConfig sampling_config,
     int input_len,
-    int outputLen) {
+    int outputLen, bool is_openhermes) {

   // Input preparation
   LOG_INFO << "Inference thread started";
   GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
   GenerationOutput generation_output = self->CreateGenerationOutput();

   // Define the callback to stream each generated token
-  generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
+  generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
       GenerationOutput::TensorPtr const& output_ids, SizeType32 step, bool finished) {
-    LOG_INFO << "Generating tokenizer in thread";
+    // LOG_INFO << "Generating tokenizer in thread";
     // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
     int output_length = output_ids->getShape().d[2];  // Get the length of output IDs based on the tensor shape
     // Copy output IDs from GPU to host for printing
     std::vector<int32_t> output_idsHost(output_length);
     self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
     // Find the last non-zero value in the output IDs starting from the end of the input sequence
     std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());
+
     RemoveId(output_idsHostDecode, 0);
-    RemoveId(output_idsHostDecode, 32000);
-    RemoveId(output_idsHostDecode, 32001);
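+    // Strip the chat-template special tokens for the active model so they are
+    // not decoded into the streamed text.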
+    if (is_openhermes) {
+      for (auto const& [_, v] : kOpenhermesTemplate) {
+        RemoveId(output_idsHostDecode, v);
+      }
+    } else {
+      RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+      RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
+    }
     std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);

     if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
@@ -191,29 +227,47 @@ bool TensorrtllmEngine::CheckModelLoaded(std::function<void(Json::Value&&, Json:

 void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
   inferences::ChatCompletionRequest request = inferences::fromJson(json_body);
-  std::string formatted_input = pre_prompt;
+  std::string formatted_input = pre_prompt_;
   nlohmann::json data;
   // data["stream"] = completion.stream;
   // data["n_predict"] = completion.max_tokens;
   data["presence_penalty"] = request.presence_penalty;
   Json::Value const& messages = request.messages;

+  // Tokens for Mistral v0.3
+  // TODO(sang): too much hard-coding here; this needs to be refactored soon
+  std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
   // Format the input from user
+  int msg_count = 0;
   for (auto const& message : messages) {
     std::string input_role = message["role"].asString();
     std::string role;
     if (input_role == "user") {
-      role = user_prompt;
+      role = user_prompt_;
       std::string content = message["content"].asString();
       formatted_input += role + content;
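+      // Mistral v0.3 prompts are built directly as token ids: each user turn is
+      // wrapped in [INST] ... [/INST] rather than going through the string template.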
+      if (!is_openhermes_) {
+        auto new_tokens = cortex_tokenizer->Encode(content);
+        new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+        new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+        tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+      }
     }
     else if (input_role == "assistant") {
-      role = ai_prompt;
+      role = ai_prompt_;
       std::string content = message["content"].asString();
       formatted_input += role + content;
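+      // For Mistral, assistant turns are appended as raw token ids; EOS is added
+      // only when this assistant message is the last one in the list.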
+      if (!is_openhermes_) {
+        auto new_tokens = cortex_tokenizer->Encode(content);
+        if (msg_count == messages.size() - 1) {
+          new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+        }
+        tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+      }
     }
     else if (input_role == "system") {
-      role = system_prompt;
+      role = system_prompt_;
       std::string content = message["content"].asString();
       formatted_input = role + content + formatted_input;
     }
@@ -222,13 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
       std::string content = message["content"].asString();
       formatted_input += role + content;
     }
+    msg_count++;
   }
-  formatted_input += ai_prompt;
+  formatted_input += ai_prompt_;
+  // LOG_INFO << formatted_input;
   // Format the input from user

   std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();

-  std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
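+  // OpenHermes still encodes the formatted prompt string; Mistral uses the token
+  // ids assembled above.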
+  std::vector<int32_t> input_ids_host;
+  if (is_openhermes_) {
+    input_ids_host = cortex_tokenizer->Encode(formatted_input);
+  } else {
+    input_ids_host = tokens;
+  }
+
   int const input_len = input_ids_host.size();
   int const outputLen = request.max_tokens - input_len;

@@ -242,23 +304,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
   sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
   // Input preparation

-  std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
+  std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
   inference_thread.detach();  // Detach the thread to allow it to run independently

-  q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
+  q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
+    // std::string res_str;
     LOG_INFO << "Preparing to run inference task queue...";
     while (true) {  // Continuously check if the queue is not empty
       std::unique_lock<std::mutex> lock(infer_state->queue_mutex);  // Lock the queue for exclusive access
       if (!infer_state->texts_to_stream.empty()) {
         std::string rew_text = infer_state->texts_to_stream.front();
+        // res_str += rew_text;
         infer_state->texts_to_stream.pop();
-        if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
+        if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
           continue;
         };

         if (rew_text == "[DONE]") {
           const std::string str
-              = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
+              = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
               + "\n\n" + "data: [DONE]" + "\n\n";

           infer_state->is_finished = true;
@@ -274,10 +338,10 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
           break;
         }
         const std::string text_to_stream
-            = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
+            = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";

         lock.unlock();  // Unlock as soon as possible
-        infer_state->prev_text = rew_text;
+        // std::cout << rew_text;

         Json::Value resp_data;
         resp_data["data"] = text_to_stream;
@@ -292,6 +356,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
         lock.unlock();
       }
     }
+    // LOG_INFO << res_str;
   });

   LOG_INFO << "Inference completed";
@@ -301,16 +366,20 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
 void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
   model::LoadModelRequest request = model::fromJson(json_body);
   std::filesystem::path model_dir = request.model_path;
+  is_openhermes_ = IsOpenhermes(request.model_path);

   int ctx_len = request.ctx_len;
-  this->user_prompt = request.user_prompt;
-  this->ai_prompt = request.ai_prompt;
-  this->system_prompt = request.system_prompt;
-  this->model_id_ = GetModelId(*json_body);
+  // We only support 2 models for now; it is ugly, but it works :(
+  if (is_openhermes_) {
+    user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+    ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+    system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+  }
+  model_id_ = GetModelId(*json_body);

-  logger = std::make_shared<TllmLogger>();
-  logger->setLevel(nvinfer1::ILogger::Severity::kINFO);
-  initTrtLlmPlugins(logger.get());
+  logger_ = std::make_shared<TllmLogger>();
+  logger_->setLevel(nvinfer1::ILogger::Severity::kINFO);
+  initTrtLlmPlugins(logger_.get());

   std::filesystem::path tokenizer_model_name = model_dir / "tokenizer.model";
   cortex_tokenizer = std::make_unique<Tokenizer>(tokenizer_model_name.string());
@@ -319,20 +388,20 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
   std::filesystem::path json_file_name = model_dir / "config.json";
   auto json = GptJsonConfig::parse(json_file_name);
   auto config = json.getModelConfig();
-  model_config = std::make_unique<ModelConfig>(config);
+  model_config_ = std::make_unique<ModelConfig>(config);
   auto world_config = WorldConfig::mpi(1, json.getTensorParallelism(), json.getPipelineParallelism());
   LOG_INFO << "Loaded config from " << json_file_name.string();
   // auto dtype = model_config->getDataType();

   // Currently doing fixed session config
-  session_config.maxBatchSize = batchSize;
-  session_config.maxBeamWidth = 1;  // Fixed for simplicity
-  session_config.maxSequenceLength = ctx_len;
-  session_config.cudaGraphMode = true;  // Fixed for simplicity
+  session_config_.maxBatchSize = batch_size_;
+  session_config_.maxBeamWidth = 1;  // Fixed for simplicity
+  session_config_.maxSequenceLength = ctx_len;
+  session_config_.cudaGraphMode = true;  // Fixed for simplicity

   // Init gpt_session
   auto model_path = model_dir / json.engineFilename(world_config, model_id_);
-  gpt_session = std::make_unique<GptSession>(session_config, *model_config, world_config, model_path.string(), logger);
+  gpt_session = std::make_unique<GptSession>(session_config_, *model_config_, world_config, model_path.string(), logger_);

   model_loaded_ = true;
   if (q_ == nullptr) {
@@ -346,7 +415,8 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
   Json::Value status_resp;
   status_resp["status_code"] = k200OK;
   callback(std::move(status_resp), std::move(json_resp));
-  return;
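+  // Record when the model finished loading, in milliseconds since the epoch
+  // (the <chrono> include added above is used for this).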
+  start_time_ = std::chrono::system_clock::now().time_since_epoch() /
+                std::chrono::milliseconds(1);
 };

 void TensorrtllmEngine::UnloadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
@@ -363,8 +433,8 @@ void TensorrtllmEngine::UnloadModel(std::shared_ptr<Json::Value> json_body, std:
   gpt_session.reset();
   cortex_tokenizer.reset();
   q_.reset();
-  model_config.reset();
-  logger.reset();
+  model_config_.reset();
+  logger_.reset();
   model_loaded_ = false;

   Json::Value json_resp;