This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 186eee3

tikikun authored and hiento09 committed
Merge pull request #14 from janhq/10-epic-add-proper-handler-for-stop-words
Add naive hiding stop words case
1 parent f9ba94b commit 186eee3

File tree

2 files changed: +82 -27 lines changed


cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc

Lines changed: 81 additions & 27 deletions
@@ -26,11 +26,52 @@ void removeId(std::vector<int>& vec, int id)
 struct inferenceState
 {
     int prevPos{0};
+    std::string prevText;
     bool isFinished;
     std::queue<std::string> textsToStream;
     std::mutex queueMutex; // Mutex to protect access to textsToStream
+
+    size_t stopWordMatchLen = 0;
+    std::vector<std::string> sequence{"<", "|", "im", "_", "end", "|", ">"};
+
+    void reset()
+    {
+        stopWordMatchLen = 0;
+        prevText = "";
+    }
+
+    bool isComplete() const
+    {
+        return stopWordMatchLen >= sequence.size();
+    }
 };
 
+bool handleMatch(const std::string& rawText, std::shared_ptr<inferenceState> inferState)
+{
+    if (inferState->isComplete())
+    {
+        return true;
+    }
+
+    if (rawText == inferState->sequence[inferState->stopWordMatchLen])
+    {
+        inferState->stopWordMatchLen++; // Move to next state
+        inferState->prevText = rawText;
+        return true;
+    }
+    else if (inferState->stopWordMatchLen > 0 && rawText == inferState->sequence[0])
+    {
+        inferState->stopWordMatchLen = 1; // Restart from first match if sequence breaks but matches start
+        inferState->prevText = rawText;
+        return true;
+    }
+    else
+    {
+        inferState->reset();
+        return false; // Reset to start if sequence breaks
+    }
+}
+
 // Only support single token stopping point now
 std::string create_return_json(const std::string& id, const std::string& model, const std::string& content,
     Json::Value finish_reason = Json::Value())
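
The new inferenceState fields and handleMatch() form a small state machine: each decoded chunk is compared against the next expected piece of the ChatML stop word <|im_end|>, and matching chunks are withheld from the client. Below is a minimal, self-contained sketch of the same idea (not code from this commit; MatcherState and shouldHide are illustrative names):

#include <iostream>
#include <string>
#include <vector>

struct MatcherState
{
    size_t stopWordMatchLen = 0;
    std::vector<std::string> sequence{"<", "|", "im", "_", "end", "|", ">"};
    bool isComplete() const { return stopWordMatchLen >= sequence.size(); }
    void reset() { stopWordMatchLen = 0; }
};

// Returns true when the chunk should be hidden from the client.
bool shouldHide(const std::string& chunk, MatcherState& st)
{
    if (st.isComplete())
        return true;                                         // already saw the full stop word
    if (chunk == st.sequence[st.stopWordMatchLen])           // next expected piece
    {
        st.stopWordMatchLen++;
        return true;
    }
    if (st.stopWordMatchLen > 0 && chunk == st.sequence[0])  // sequence broke but restarts
    {
        st.stopWordMatchLen = 1;
        return true;
    }
    st.reset();                                              // ordinary text
    return false;
}

int main()
{
    MatcherState st;
    // Simulated decoded chunks: normal text followed by a ChatML stop word.
    std::vector<std::string> chunks{"Hello", " world", "<", "|", "im", "_", "end", "|", ">"};
    for (const auto& c : chunks)
        if (!shouldHide(c, st))
            std::cout << c;                                  // prints "Hello world"
    std::cout << std::endl;
}

One limitation of this naive approach is visible in the sketch: a chunk that happens to equal "<" (or another prefix piece) is swallowed even if the full stop word never arrives, because withheld chunks are never flushed back out.
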
@@ -67,6 +108,13 @@ GenerationInput::TensorPtr tensorrtllm::getTensorSingleStopWordList(int stopToke
     return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 2}), MemoryType::kGPU);
 }
 
+GenerationInput::TensorPtr tensorrtllm::getTensorChatMLStopWordList()
+{
+    std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 32000, 6, 8, -1, -1, -1, -1,
+        -1, -1}; // Extend with -1 for increased length
+    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 8}), MemoryType::kGPU);
+}
+
 GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIdsHost)
 {
     int inputLen = inputIdsHost.size();
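
getTensorChatMLStopWordList() hard-codes the stop words in the flattened layout TensorRT-LLM expects for stopWordsList: in a {1, 2, N} tensor, the first row of N entries holds the concatenated token IDs of all stop sequences and the second row holds the cumulative end offsets, padded with -1. A hedged sketch of a generic builder for that layout follows (the helper name is mine; the real code then copies the result to GPU through BufferManager::copyFrom, as above):

#include <cstdint>
#include <vector>

// Row 0: token IDs of all stop sequences, concatenated.
// Row 1: cumulative end offset of each sequence, padded with -1 to the same width.
// The result is interpreted as a tensor of shape {1, 2, N}.
std::vector<int32_t> buildStopWordsFlat(const std::vector<std::vector<int32_t>>& stopSequences)
{
    std::vector<int32_t> ids;
    std::vector<int32_t> offsets;
    for (const auto& seq : stopSequences)
    {
        ids.insert(ids.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(ids.size())); // cumulative end offset
    }
    const size_t n = ids.size();                              // row width N
    offsets.resize(n, -1);                                    // pad the offsets row
    std::vector<int32_t> flat;
    flat.reserve(2 * n);
    flat.insert(flat.end(), ids.begin(), ids.end());
    flat.insert(flat.end(), offsets.begin(), offsets.end());
    return flat;
}

The hard-coded vector in the commit follows this pattern: eight token IDs, then the offsets {6, 8}, then -1 padding out to the row width of 8.
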
@@ -78,7 +126,7 @@ GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIds
 
     GenerationInput generationInput{0, 0, inputIds, inputLengths, modelConfig->usePackedInput()};
 
-    generationInput.stopWordsList = getTensorSingleStopWordList(32000);
+    generationInput.stopWordsList = getTensorChatMLStopWordList();
     return generationInput;
 }
 
@@ -117,35 +165,35 @@ void inferenceThread(std::shared_ptr<inferenceState> inferState, std::vector<int
     generationOutput.onTokenGenerated = [&inferState, inputLen, outputLen, self, &generationOutput](
         GenerationOutput::TensorPtr const& outputIds, SizeType step, bool finished)
     {
-        if (!finished)
+        // Assuming the shape of outputIds tensor is (1, 1, 160), where 160 is the number of tokens
+        int outputLength = outputIds->getShape().d[2]; // Get the length of output IDs based on the tensor shape
+        // Copy output IDs from GPU to host for printing
+        std::vector<int32_t> outputIdsHost(outputLength);
+        self->gptSession->getBufferManager().copy(*outputIds, outputIdsHost.data(), MemoryType::kCPU);
+        // Find the last non-zero value in the output IDs starting from the end of the input sequence
+        std::vector<int> outputIdsHostDecode(outputIdsHost.begin() + inputLen, outputIdsHost.end());
+        removeId(outputIdsHostDecode, 0);
+        std::string text = self->nitro_tokenizer->decode(outputIdsHostDecode);
+
+        if (inferState->prevPos > 0 && inferState->prevPos < text.size())
+        {
+            // Valid prevPos, proceed with slicing the string from prevPos to the end
+            std::string stringTok(text.begin() + inferState->prevPos, text.end());
+            std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
+            inferState->textsToStream.push(stringTok);
+        }
+        else if (inferState->prevPos >= text.size())
         {
-            // Assuming the shape of outputIds tensor is (1, 1, 160), where 160 is the number of tokens
-            int outputLength = outputIds->getShape().d[2]; // Get the length of output IDs based on the tensor shape
-            // Copy output IDs from GPU to host for printing
-            std::vector<int32_t> outputIdsHost(outputLength);
-            self->gptSession->getBufferManager().copy(*outputIds, outputIdsHost.data(), MemoryType::kCPU);
-            // Find the last non-zero value in the output IDs starting from the end of the input sequence
-            std::vector<int> outputIdsHostDecode(outputIdsHost.begin() + inputLen, outputIdsHost.end());
-            removeId(outputIdsHostDecode, 0);
-            removeId(outputIdsHostDecode, 32000);
-            std::string text = self->nitro_tokenizer->decode(outputIdsHostDecode);
-
-            if (inferState->prevPos > 0 && inferState->prevPos < text.size())
-            {
-                // Valid prevPos, proceed with slicing the string from prevPos to the end
-                std::string stringTok(text.begin() + inferState->prevPos, text.end());
-                std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
-                inferState->textsToStream.push(stringTok);
-            }
-            else if (inferState->prevPos >= text.size())
-            {
-                inferState->prevPos = text.size();
-            }
             inferState->prevPos = text.size();
+        }
+        inferState->prevPos = text.size();
+        if (finished)
+        {
+
+            std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
+            inferState->textsToStream.push("[DONE]");
             return;
         }
-        std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
-        inferState->textsToStream.push("[DONE]");
     };
     // The rest of the logic inside the `chat_completion` remains unchanged...
     // After finishing the setup, call the inference logic
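
The reworked onTokenGenerated callback no longer wraps everything in if (!finished): on every step it re-decodes the full generated sequence, pushes only the text past prevPos onto the stream queue, and pushes "[DONE]" once finished is set. A stripped-down sketch of that decode-everything, emit-the-delta pattern follows (toy decoder, single-threaded; unlike the real callback, this toy also emits the very first chunk while prevPos is still 0):

#include <iostream>
#include <string>
#include <vector>

std::string decodeAll(const std::vector<int>& tokens)
{
    // Toy decoder: one character per token ID, just to drive the example.
    std::string out;
    for (int t : tokens)
        out.push_back(static_cast<char>('a' + (t % 26)));
    return out;
}

int main()
{
    std::vector<int> generated;
    size_t prevPos = 0; // mirrors inferState->prevPos
    for (int step = 0; step < 5; ++step)
    {
        generated.push_back(step);               // pretend one new token arrives per callback
        std::string text = decodeAll(generated); // re-decode the full sequence so far
        if (prevPos < text.size())
            std::cout << text.substr(prevPos);   // stream only the new suffix
        prevPos = text.size();                   // remember how much has been emitted
    }
    std::cout << std::endl; // prints "abcde", one new character per step
}

Re-decoding the whole sequence and slicing by character position sidesteps tokens that do not decode cleanly on their own, at the cost of repeating the decode on every callback.
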
@@ -243,6 +291,12 @@ void tensorrtllm::chat_completion(
         {
 
             std::string rawText = inferState->textsToStream.front();
+            inferState->textsToStream.pop();
+            if (handleMatch(rawText, inferState))
+            {
+                continue;
+            };
+
             if (rawText == "[DONE]")
             {
                 LOG_INFO << "End of result";
@@ -257,14 +311,14 @@ void tensorrtllm::chat_completion(
             }
             const std::string textToStream
                 = "data: " + create_return_json(nitro_utils::generate_random_string(20), "_", rawText) + "\n\n";
-            inferState->textsToStream.pop();
             lock.unlock(); // Unlock as soon as possible
 
             // Ensure we do not exceed the buffer size. Truncate if necessary.
             std::size_t bytesToWrite = std::min(nBuffSize, textToStream.size());
 
             // Copy the text to the provided buffer
             std::memcpy(pBuffer, textToStream.data(), bytesToWrite);
+            inferState->prevText = rawText;
             return bytesToWrite; // Return the number of bytes written to the buffer
         }
         else
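
In chat_completion, the chunk is now popped before the stop-word filter and the "[DONE]" check, so hidden chunks are consumed rather than re-read on the next pass, and prevText is recorded after a successful write. Below is a compact sketch (not repo code) of the producer/consumer handoff these changes rely on, with the inference thread pushing chunks and the response callback draining them under the shared mutex:

#include <chrono>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

int main()
{
    std::queue<std::string> textsToStream;
    std::mutex queueMutex;

    // Producer: stands in for the inference thread's onTokenGenerated callback.
    std::thread producer(
        [&]
        {
            const std::vector<std::string> chunks{"Hello", " world"};
            for (const auto& c : chunks)
            {
                std::lock_guard<std::mutex> guard(queueMutex); // same locking discipline as inferenceState
                textsToStream.push(c);
            }
            std::lock_guard<std::mutex> guard(queueMutex);
            textsToStream.push("[DONE]"); // end-of-stream marker, as in the commit
        });

    // Consumer: stands in for the streaming response callback in chat_completion.
    while (true)
    {
        std::unique_lock<std::mutex> lock(queueMutex);
        if (textsToStream.empty())
        {
            lock.unlock();
            std::this_thread::sleep_for(std::chrono::milliseconds(1)); // keep the toy loop simple with polling
            continue;
        }
        std::string rawText = textsToStream.front();
        textsToStream.pop();   // pop immediately, before any early exit (the ordering introduced here)
        if (rawText == "[DONE]")
            break;
        lock.unlock();         // unlock as soon as possible
        std::cout << rawText;  // stands in for memcpy into the response buffer
    }
    producer.join();
    std::cout << std::endl;    // prints "Hello world"
}
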

cpp/tensorrt_llm/nitro/controllers/tensorrtllm.h

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ class tensorrtllm : public drogon::HttpController<tensorrtllm>
     GenerationInput createGenerationInput(std::vector<int32_t> inputIds);
     GenerationOutput createGenerationOutput();
     std::unique_ptr<Tokenizer> nitro_tokenizer;
+    GenerationInput::TensorPtr getTensorChatMLStopWordList();
 
 private:
     GptSession::Config sessionConfig{1, 1, 1};
