@@ -39,73 +39,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
     return true;
 }
 
-// TODO: use common/sampling.h
-static llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
-    auto & sparams = params.sparams;
-
-    // out of user input, sample next token
-    const float   temp      = sparams.temp;
-    const int32_t top_k     = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : sparams.top_k;
-    const float   top_p     = sparams.top_p;
-    const float   tfs_z     = sparams.tfs_z;
-    const float   typical_p = sparams.typical_p;
-    // const int32_t repeat_last_n   = sparams.repeat_last_n < 0 ? n_ctx : sparams.repeat_last_n;
-    // const float   repeat_penalty  = sparams.repeat_penalty;
-    // const float   alpha_presence  = sparams.presence_penalty;
-    // const float   alpha_frequency = sparams.frequency_penalty;
-    const int     mirostat     = sparams.mirostat;
-    const float   mirostat_tau = sparams.mirostat_tau;
-    const float   mirostat_eta = sparams.mirostat_eta;
-    // const bool    penalize_nl     = sparams.penalize_nl;
-
-    llama_token id = 0;
-    {
-        auto logits  = llama_get_logits(ctx_llama);
-        auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
-
-        // Apply params.logit_bias map
-        for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
-            logits[it->first] += it->second;
-        }
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-        if (temp <= 0) {
-            // Greedy sampling
-            id = llama_sample_token_greedy(ctx_llama, &candidates_p);
-        } else {
-            if (mirostat == 1) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                const int mirostat_m = 100;
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-            } else if (mirostat == 2) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-            } else {
-                // Temperature sampling
-                llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
-                llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
-                llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
-                llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token(ctx_llama, &candidates_p);
-            }
-        }
-    }
-
-    return id;
-}
-
-static const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
-    int id = sample_id(ctx_llama, params);
+static const char * sample(struct llama_sampling_context * ctx_sampling,
+                           struct llama_context * ctx_llama,
+                           int * n_past) {
+    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
     if (id == llama_token_eos(llama_get_model(ctx_llama))) {
         ret = "</s>";
@@ -174,8 +112,8 @@ struct llava_context {
 };
 
 static void show_additional_info(int /*argc*/, char ** argv) {
-    printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
-    printf("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
+    fprintf(stderr, "\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    fprintf(stderr, "  note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }
 
 static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
@@ -185,7 +123,7 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
     auto prompt = params->prompt;
     if (prompt_contains_image(prompt)) {
         if (!params->image.empty()) {
-            printf("using base64 encoded image instead of command line image path\n");
+            fprintf(stderr, "using base64 encoded image instead of command line image path\n");
         }
         embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
         if (!embed) {
@@ -217,16 +155,19 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     // generate the response
 
-    printf("\n");
+    fprintf(stderr, "\n");
+
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
 
     for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(ctx_llava->ctx_llama, *params, &n_past);
+        const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
         if (strcmp(tmp, "</s>") == 0) break;
 
         printf("%s", tmp);
         fflush(stdout);
     }
 
+    llama_sampling_free(ctx_sampling);
     printf("\n");
 }
 
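
For readers unfamiliar with `common/sampling.h`, the sketch below shows the lifecycle this diff adopts in one place: build a `llama_sampling_context` from `gpt_params::sparams`, call `llama_sampling_sample` / `llama_sampling_accept` inside the generation loop, and free the context afterwards, so repetition penalties, mirostat state, and grammar handling live in the sampling context instead of the hand-rolled `sample_id()` removed above. This is a minimal illustration, not code from this commit: the `generate()` helper, its parameters, and the `llama_token_to_piece` / `llama_batch_get_one` / `llama_decode` feedback step are assumptions layered on top of the calls visible in the diff.

```cpp
#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <cstdio>

// Hypothetical helper: generates up to n_predict tokens from a context whose
// prompt has already been decoded (so logits for the last position exist).
// n_past must point just past the prompt; the updated n_past is returned.
static int generate(struct llama_context * ctx, gpt_params & params, int n_past, int n_predict) {
    // one sampling context per generation stream; it owns the sampler state
    // (recent tokens, mirostat mu, grammar) that sample_id() used to keep ad hoc
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);

    for (int i = 0; i < n_predict; i++) {
        // pick the next token from the most recent logits (no guidance context)
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);

        // record it so penalties and grammar constraints see it on the next step
        llama_sampling_accept(ctx_sampling, ctx, id, true);

        if (id == llama_token_eos(llama_get_model(ctx))) {
            break;
        }

        printf("%s", llama_token_to_piece(ctx, id).c_str());
        fflush(stdout);

        // feed the token back so the next iteration samples from fresh logits
        llama_token token = id;
        if (llama_decode(ctx, llama_batch_get_one(&token, 1, n_past, 0)) != 0) {
            break; // decode failed
        }
        n_past += 1;
    }

    llama_sampling_free(ctx_sampling);
    return n_past;
}
```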