#include "build-info.h"

#include <cmath>
+#include <cstdio>
+#include <cstring>
#include <ctime>
#include <sstream>
-#include <cstring>
#include <thread>
#include <mutex>
+#include <tuple>
+#include <utility>
+#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -29,20 +33,20 @@ std::vector<float> softmax(const std::vector<float>& logits) {
    return probs;
}

-float log_softmax(int n_vocab, const float * logits, int tok) {
+std::tuple<double, float, float> log_softmax(int n_vocab, const float * logits, int tok) {
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
-    return logits[tok] - max_logit - log(sum_exp);
+    return std::make_tuple(-(logits[tok] - max_logit - log(sum_exp)), logits[tok], expf(logits[tok] - max_logit) / sum_exp);
}

-void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
-                    double & nll, double & nll2) {
+void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+                    double & nll, double & nll2, float * logit_history, float * prob_history) {

    std::mutex mutex;
    int counter = 0;
-    auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
+    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
        double local_nll = 0, local_nll2 = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
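Note: log_softmax now folds three per-token quantities into a single pass over the vocabulary. A minimal sketch of how the result is meant to be unpacked, assuming only the new signature shown above (the variable names are illustrative):

    const std::tuple<double, float, float> r = log_softmax(n_vocab, logits, tok);
    const double nll_term = std::get<0>(r); // -log p(tok); this is what gets accumulated into nll
    const float  logit    = std::get<1>(r); // raw logit of the target token
    const float  prob     = std::get<2>(r); // softmax probability of the target token, exp(-nll_term) up to rounding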
@@ -52,34 +56,44 @@ void process_logits(int n_vocab, const float * logits, const int * tokens, int n
                break;
            }
            lock.unlock();
-            double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
-            local_nll += v;
-            local_nll2 += v*v;
+            const std::tuple<double, float, float> v = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const double v0 = std::get<0>(v);
+            local_nll += v0;
+            local_nll2 += v0*v0;
+
+            logit_history[i] = std::get<1>(v);
+            prob_history[i]  = std::get<2>(v);
        }
    };
-    for (auto& w : workers) w = std::thread(compute);
+    for (auto & w : workers) w = std::thread(compute);
    compute();
-    for (auto& w : workers) w.join();
+    for (auto & w : workers) w.join();

}

-void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+std::tuple<std::vector<llama_token>, std::vector<float>, std::vector<float>, float>
+perplexity_v2(llama_context * ctx, const gpt_params & params) {
    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

-    if (params.ppl_stride <= 0) {
-        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
-        return;
-    }
-
    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    const bool add_bos = is_spm;

    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

-    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
+
+    logit_history.resize(tokens.size());
+    prob_history.resize(tokens.size());
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
+        return std::make_tuple(tokens, logit_history, prob_history, -1);
+    }

    const int calc_chunk = params.n_ctx;

@@ -88,7 +102,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
    if (int(tokens.size()) <= calc_chunk) {
        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                tokens.size(), params.n_ctx, params.ppl_stride);
-        return;
+        return std::make_tuple(tokens, logit_history, prob_history, -1);
    }

    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
@@ -120,7 +134,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
            // fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                // fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
+                return std::make_tuple(tokens, logit_history, prob_history, -1);
            }

            // save original token and restore it after eval
@@ -161,6 +175,8 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
                                                logits.begin() + (j + 1) * n_vocab);

            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+            logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
+            prob_history[start + j + 1]  = prob;

            nll += -std::log(prob);
            ++count;
@@ -174,12 +190,15 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
        fflush(stdout);
    }
    printf("\n");
+
+    return std::make_tuple(tokens, logit_history, prob_history, std::exp(nll / count));
}

-void perplexity(llama_context * ctx, const gpt_params & params) {
+std::tuple<std::vector<llama_token>, std::vector<float>, std::vector<float>, float>
+perplexity(llama_context * ctx, const gpt_params & params) {
+
    if (params.ppl_stride > 0) {
-        perplexity_v2(ctx, params);
-        return;
+        return perplexity_v2(ctx, params);
    }

    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
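Note: the fourth tuple element returned above is the perplexity itself, std::exp(nll / count), i.e. e raised to the mean negative log-likelihood over the count scored tokens; on the error paths it is set to -1. As a sanity check against the Output comment quoted earlier, a mean NLL of about 2.6036 gives exp(2.6036) ≈ 13.51.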
@@ -193,11 +212,17 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
    auto tim1 = std::chrono::high_resolution_clock::now();
    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

-    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

    auto tim2 = std::chrono::high_resolution_clock::now();
    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

+    std::vector<float> logit_history;
+    logit_history.resize(tokens.size());
+
+    std::vector<float> prob_history;
+    prob_history.resize(tokens.size());
+
    const int n_chunk_max = tokens.size() / params.n_ctx;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
@@ -236,7 +261,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {

            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
+                return std::make_tuple(tokens, logit_history, prob_history, -1);
            }

            // restore the original token in case it was set to BOS
@@ -272,7 +297,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
        // last 256 tokens. Then, we split the input up into context window size chunks to
        // process the entire prompt.
        const int first = std::min(512, params.n_ctx/2);
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+                       workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
        count += params.n_ctx - first - 1;

        // perplexity is e^(average negative log-likelihood)
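Note: only the tail of each chunk is scored. For example, with params.n_ctx = 512, first = std::min(512, 512/2) = 256, so process_logits scores params.n_ctx - 1 - first = 255 positions per chunk, and logit_history/prob_history are filled only from index start + first onward; entries for the skipped positions keep the 0 they were value-initialized with by resize().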
@@ -287,16 +313,19 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
        fflush(stdout);
    }
    printf("\n");
+
    nll2 /= count;
    nll /= count;
+    const double ppl = exp(nll);
    nll2 -= nll * nll;
    if (nll2 > 0) {
        nll2 = sqrt(nll2/(count-1));
-        double ppl = exp(nll);
        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
    } else {
        printf("Unexpected negative standard deviation of log(prob)\n");
    }
+
+    return std::make_tuple(tokens, logit_history, prob_history, ppl);
}

std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
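Note on the reported uncertainty: after the lines above, nll is the mean negative log-likelihood, while nll2 first becomes its sample variance and is then reduced to sqrt(nll2/(count-1)), the standard error of that mean. The printed +/- term, nll2*ppl, propagates this standard error through the exponential via the first-order approximation delta(e^x) ≈ e^x * delta(x). Hoisting ppl out of the if-branch is what lets the function return it even when the variance check fails.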
@@ -604,13 +633,56 @@ int main(int argc, char ** argv) {
                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
    }

+    std::vector<llama_token> tokens;
+    std::vector<float> logits;
+    std::vector<float> probs;
+    double perplexity_value = -1;
    if (params.hellaswag) {
        hellaswag_score(ctx, params);
    } else {
-        perplexity(ctx, params);
+        auto ret = perplexity(ctx, params);
+        tokens = std::get<0>(ret);
+        logits = std::get<1>(ret);
+        probs = std::get<2>(ret);
+        perplexity_value = std::get<3>(ret);
    }

    llama_print_timings(ctx);
+
+    if (params.hellaswag && !params.logdir.empty()) {
+        fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
+    }
+
+    if (!params.hellaswag && !params.logdir.empty()) {
+        const std::string timestamp = get_sortable_timestamp();
+
+        const bool success = create_directory_with_parents(params.logdir);
+        if (success) {
+
+            FILE * logfile = fopen((params.logdir + timestamp + ".yml").c_str(), "w");
+            fprintf(logfile, "binary: perplexity\n");
+            char model_type[128];
+            llama_model_desc(model, model_type, sizeof(model_type));
+            dump_non_result_info_yaml(logfile, params, ctx, timestamp, tokens, model_type);
+
+            fprintf(logfile, "\n");
+            fprintf(logfile, "######################\n");
+            fprintf(logfile, "# Perplexity Results #\n");
+            fprintf(logfile, "######################\n");
+            fprintf(logfile, "\n");
+
+            dump_vector_float_yaml(logfile, "logits", logits);
+            fprintf(logfile, "ppl_value: %f\n", perplexity_value);
+            dump_vector_float_yaml(logfile, "probs", probs);
+
+            llama_dump_timing_info_yaml(logfile, ctx);
+            fclose(logfile);
+        } else {
+            fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                __func__, params.logdir.c_str());
+        }
+    }
+
    llama_free(ctx);
    llama_free_model(model);

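Note: when params.logdir is set (and the HellaSwag path is not taken), the block above writes a file named <logdir><timestamp>.yml. Based only on the fprintf strings and helper names visible in this diff, and assuming each dump_* helper emits the key it is passed (their exact formatting is not shown here), the file is laid out roughly like:

    binary: perplexity
    # ...run/model/prompt metadata written by dump_non_result_info_yaml...

    ######################
    # Perplexity Results #
    ######################

    logits: ...one float per prompt token...
    ppl_value: ...final perplexity, or -1 if evaluation failed...
    probs: ...one float per prompt token...
    # ...timing info written by llama_dump_timing_info_yaml...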