-#include "llama.h"
 #include <cstdio>
 #include <cstring>
 #include <iostream>
+#include <memory>
 #include <string>
 #include <vector>
 
+#include "llama.h"
+
+// Add a message to `messages` and store its content in `owned_content`.
+// `role` must point to storage that outlives `messages` (a string literal),
+// since llama_chat_message stores the raw pointer without copying it.
+static void add_message(const char * role, const std::string &text,
+                        std::vector<llama_chat_message> &messages,
+                        std::vector<std::unique_ptr<char[]>> &owned_content) {
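+    // llama_chat_message only stores raw pointers, so copy the text into
+    // heap storage that `owned_content` keeps alive alongside `messages`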
+    auto content = std::make_unique<char[]>(text.size() + 1);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role, content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model *model,
+                               const std::vector<llama_chat_message> &messages,
+                               std::vector<char> &formatted, bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
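+    // a result larger than the buffer is the size the template actually
+    // needs, so grow `formatted` and apply the template a second time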
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model *model, const std::string &prompt,
+                           std::vector<llama_token> &prompt_tokens) {
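+    // with a NULL output buffer llama_tokenize returns the negative of the
+    // required token count, so negate it to size the token vector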
+    const int n_prompt_tokens = -llama_tokenize(
+        model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(),
+                       prompt_tokens.data(), prompt_tokens.size(), true,
+                       true) < 0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+        return -1;
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
+static int check_context_size(const llama_context *ctx,
+                              const llama_batch &batch) {
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model *model,
+                                   const llama_token token_id,
+                                   std::string &piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+        return 1;
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string &piece,
+                                                   std::string &response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
+static int generate(const llama_model *model, llama_sampler *smpl,
+                    llama_context *ctx, const std::string &prompt,
+                    std::string &response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch =
+        llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        // stop if the next batch would not fit into the context window
+        if (check_context_size(ctx, batch)) {
+            return 1;
+        }
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+            return 1;
+        }
+
+        // sample the next token and check whether it is the end of generation
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
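+        // (llama_batch_get_one keeps a pointer to new_token_id rather than
+        // copying the token, so the variable must outlive the next decode)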
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n%s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
@@ -66,6 +188,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // This prints ........
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -88,107 +211,49 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
     llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
-
-    // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
-        std::string response;
-
-        // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
-            GGML_ABORT("failed to tokenize the prompt\n");
-        }
-
-        // prepare a batch for the prompt
-        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
-        llama_token new_token_id;
-        while (true) {
-            // check if we have enough space in the context to evaluate this batch
-            int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
-            if (n_ctx_used + batch.n_tokens > n_ctx) {
-                printf("\033[0m\n");
-                fprintf(stderr, "context size exceeded\n");
-                exit(0);
-            }
-
-            if (llama_decode(ctx, batch)) {
-                GGML_ABORT("failed to decode\n");
-            }
-
-            // sample the next token
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
-
-            // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id)) {
-                break;
-            }
-
-            // convert the token to a string, print it and add it to the response
-            char buf[256];
-            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
-            if (n < 0) {
-                GGML_ABORT("failed to convert token to piece\n");
-            }
-            std::string piece(buf, n);
-            printf("%s", piece.c_str());
-            fflush(stdout);
-            response += piece;
-
-            // prepare the next batch with the sampled token
-            batch = llama_batch_get_one(&new_token_id, 1);
-        }
-
-        return response;
-    };
-
     std::vector<llama_chat_message> messages;
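+    // owns the copied message contents referenced by `messages`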
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
         // get user input
         printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-
         if (user.empty()) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        // Add user input to messages
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        // remove previous messages to obtain the prompt to generate the
+        // response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt);
+        std::string response;
+        if (generate(model, smpl, ctx, prompt, response)) {
+            return 1;
+        }
+
         printf("\n\033[0m");
 
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        // Add the response to the messages
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);