@@ -26,11 +26,52 @@ void removeId(std::vector<int>& vec, int id)
 struct inferenceState
 {
     int prevPos{0};
+    std::string prevText;
     bool isFinished;
     std::queue<std::string> textsToStream;
     std::mutex queueMutex; // Mutex to protect access to textsToStream
+
+    size_t stopWordMatchLen = 0;
+    std::vector<std::string> sequence{"<", "|", "im", "_", "end", "|", ">"};
+
+    void reset()
+    {
+        stopWordMatchLen = 0;
+        prevText = "";
+    }
+
+    bool isComplete() const
+    {
+        return stopWordMatchLen >= sequence.size();
+    }
 };

+bool handleMatch(const std::string& rawText, std::shared_ptr<inferenceState> inferState)
+{
+    if (inferState->isComplete())
+    {
+        return true;
+    }
+
+    if (rawText == inferState->sequence[inferState->stopWordMatchLen])
+    {
+        inferState->stopWordMatchLen++; // Advance to the next expected fragment
+        inferState->prevText = rawText;
+        return true;
+    }
+    else if (inferState->stopWordMatchLen > 0 && rawText == inferState->sequence[0])
+    {
+        inferState->stopWordMatchLen = 1; // Restart from the first fragment if the sequence breaks but matches the start
+        inferState->prevText = rawText;
+        return true;
+    }
+    else
+    {
+        inferState->reset();
+        return false; // Reset to start if the sequence breaks
+    }
+}
+
 // Only support single token stopping point now
 std::string create_return_json(const std::string& id, const std::string& model, const std::string& content,
     Json::Value finish_reason = Json::Value())
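
Editor's sketch (not part of the commit): handleMatch() is a small state machine that consumes decoded chunks one at a time and returns true while the chunk extends a prefix of the ChatML stop string <|im_end|>, so matched fragments are withheld from the client. A minimal hypothetical driver, assuming chunks arrive split exactly as in inferenceState::sequence:

    #include <cassert>
    #include <memory>

    int main()
    {
        auto state = std::make_shared<inferenceState>();
        for (const char* chunk : {"<", "|", "im", "_", "end", "|", ">"})
            assert(handleMatch(chunk, state)); // each fragment extends the match and is withheld
        assert(state->isComplete()); // full "<|im_end|>" seen; the caller can stop streaming

        state->reset();
        assert(!handleMatch("hello", state)); // ordinary text breaks the match and streams normally
        return 0;
    }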
@@ -67,6 +108,13 @@ GenerationInput::TensorPtr tensorrtllm::getTensorSingleStopWordList(int stopToke
     return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 2}), MemoryType::kGPU);
 }

+GenerationInput::TensorPtr tensorrtllm::getTensorChatMLStopWordList()
+{
+    std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 32000, 7, 8, -1, -1, -1, -1,
+        -1, -1}; // Row 0: concatenated stop-word token ids; row 1: cumulative end offsets (7 and 8), padded with -1
+    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 8}), MemoryType::kGPU);
+}
+
 GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIdsHost)
 {
     int inputLen = inputIdsHost.size();
@@ -78,7 +126,7 @@ GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIds

     GenerationInput generationInput{0, 0, inputIds, inputLengths, modelConfig->usePackedInput()};

-    generationInput.stopWordsList = getTensorSingleStopWordList(32000);
+    generationInput.stopWordsList = getTensorChatMLStopWordList();
     return generationInput;
 }

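Editor's note (not part of the commit): TensorRT-LLM's stopWordsList packs all stop sequences into a single tensor with two rows per batch entry: row 0 holds the concatenated token ids, row 1 the cumulative end offset of each stop word, padded with -1. A hedged sketch of a general builder producing the same flattened layout that getTensorChatMLStopWordList() hand-codes; buildStopWordsRows is my own name:

    #include <cstdint>
    #include <vector>

    // Flatten stop sequences into the two-row layout for a {1, 2, N} tensor:
    // row 0 = token ids, row 1 = cumulative end offsets, padded with -1.
    std::vector<int32_t> buildStopWordsRows(const std::vector<std::vector<int32_t>>& stopWords)
    {
        std::vector<int32_t> tokens, offsets;
        for (const auto& word : stopWords)
        {
            tokens.insert(tokens.end(), word.begin(), word.end());
            offsets.push_back(static_cast<int32_t>(tokens.size())); // end offset of this word
        }
        offsets.resize(tokens.size(), -1); // pad the offsets row to the tokens row length
        std::vector<int32_t> rows = tokens;
        rows.insert(rows.end(), offsets.begin(), offsets.end());
        return rows;
    }

Called as buildStopWordsRows({{28789, 28766, 321, 28730, 416, 28766, 28767}, {32000}}), this reproduces the sixteen values in the hand-written vector above.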
@@ -117,35 +165,35 @@ void inferenceThread(std::shared_ptr<inferenceState> inferState, std::vector<int
     generationOutput.onTokenGenerated = [&inferState, inputLen, outputLen, self, &generationOutput](
         GenerationOutput::TensorPtr const& outputIds, SizeType step, bool finished)
     {
-        if (!finished)
+        // Assuming the shape of outputIds tensor is (1, 1, 160), where 160 is the number of tokens
+        int outputLength = outputIds->getShape().d[2]; // Get the length of output IDs based on the tensor shape
+        // Copy output IDs from GPU to host for decoding
+        std::vector<int32_t> outputIdsHost(outputLength);
+        self->gptSession->getBufferManager().copy(*outputIds, outputIdsHost.data(), MemoryType::kCPU);
+        // Keep only the generated part (drop the prompt tokens) and strip padding zeros before decoding
+        std::vector<int> outputIdsHostDecode(outputIdsHost.begin() + inputLen, outputIdsHost.end());
+        removeId(outputIdsHostDecode, 0);
+        std::string text = self->nitro_tokenizer->decode(outputIdsHostDecode);
+
+        if (inferState->prevPos > 0 && inferState->prevPos < text.size())
+        {
+            // Valid prevPos, proceed with slicing the string from prevPos to the end
+            std::string stringTok(text.begin() + inferState->prevPos, text.end());
+            std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
+            inferState->textsToStream.push(stringTok);
+        }
+        else if (inferState->prevPos >= text.size())
         {
-            // Assuming the shape of outputIds tensor is (1, 1, 160), where 160 is the number of tokens
-            int outputLength = outputIds->getShape().d[2]; // Get the length of output IDs based on the tensor shape
-            // Copy output IDs from GPU to host for printing
-            std::vector<int32_t> outputIdsHost(outputLength);
-            self->gptSession->getBufferManager().copy(*outputIds, outputIdsHost.data(), MemoryType::kCPU);
-            // Find the last non-zero value in the output IDs starting from the end of the input sequence
-            std::vector<int> outputIdsHostDecode(outputIdsHost.begin() + inputLen, outputIdsHost.end());
-            removeId(outputIdsHostDecode, 0);
-            removeId(outputIdsHostDecode, 32000);
-            std::string text = self->nitro_tokenizer->decode(outputIdsHostDecode);
-
-            if (inferState->prevPos > 0 && inferState->prevPos < text.size())
-            {
-                // Valid prevPos, proceed with slicing the string from prevPos to the end
-                std::string stringTok(text.begin() + inferState->prevPos, text.end());
-                std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
-                inferState->textsToStream.push(stringTok);
-            }
-            else if (inferState->prevPos >= text.size())
-            {
-                inferState->prevPos = text.size();
-            }
             inferState->prevPos = text.size();
+        }
+        inferState->prevPos = text.size();
+        if (finished)
+        {
+
+            std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
+            inferState->textsToStream.push("[DONE]");
             return;
         }
-        std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
-        inferState->textsToStream.push("[DONE]");
     };
     // The rest of the logic inside the `chat_completion` remains unchanged...
     // After finishing the setup, call the inference logic
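
Editor's note (not part of the commit): the callback above re-decodes the entire generated sequence on every step and streams only the suffix past the prevPos watermark, which avoids re-sending text that was already streamed. A condensed sketch of that technique, with enqueue() as a hypothetical stand-in for the locked textsToStream.push:

    #include <string>

    void enqueue(const std::string& delta); // hypothetical stand-in for the queue push

    void streamDelta(inferenceState& s, const std::string& text /* full decode so far */)
    {
        if (s.prevPos > 0 && s.prevPos < text.size())
            enqueue(text.substr(s.prevPos)); // only the newly generated suffix
        s.prevPos = text.size(); // advance the watermark for the next step
    }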
@@ -243,6 +291,12 @@ void tensorrtllm::chat_completion(
     {

         std::string rawText = inferState->textsToStream.front();
+        inferState->textsToStream.pop();
+        if (rawText != "[DONE]" && handleMatch(rawText, inferState)) // let the sentinel fall through to the check below
+        {
+            continue;
+        }
+
         if (rawText == "[DONE]")
         {
             LOG_INFO << "End of result";
@@ -257,14 +311,14 @@ void tensorrtllm::chat_completion(
         }
         const std::string textToStream
             = "data: " + create_return_json(nitro_utils::generate_random_string(20), "_", rawText) + "\n\n";
-        inferState->textsToStream.pop();
         lock.unlock(); // Unlock as soon as possible

         // Ensure we do not exceed the buffer size. Truncate if necessary.
         std::size_t bytesToWrite = std::min(nBuffSize, textToStream.size());

         // Copy the text to the provided buffer
         std::memcpy(pBuffer, textToStream.data(), bytesToWrite);
+        inferState->prevText = rawText;
         return bytesToWrite; // Return the number of bytes written to the buffer
     }
     else
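
One more editor's note (not part of the commit): handleMatch() buffers only the most recent chunk in prevText, so a partial match that later fails drops the chunks it already consumed. A small sketch of that edge case, under the sequence defined above:

    auto s = std::make_shared<inferenceState>();
    handleMatch("<", s);  // true: plausible start of "<|im_end|>", withheld from the client
    handleMatch("|", s);  // true: still a plausible prefix, also withheld
    handleMatch("hi", s); // false: the match breaks; "hi" streams, but "<" and "|" are never re-emitted

A caller that wants lossless output could accumulate the withheld fragments instead of keeping only prevText, and replay them whenever handleMatch() returns false after a partial match.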