66#include < ctime>
77#include < sstream>
88#include < cstring>
9+ #include < thread>
10+ #include < mutex>
911
1012#if defined(_MSC_VER)
1113#pragma warning(disable: 4244 4267) // possible loss of data
@@ -27,6 +29,40 @@ std::vector<float> softmax(const std::vector<float>& logits) {
2729 return probs;
2830}
2931
// Returns log(softmax(logits)[tok]) in a numerically stable way:
// the max logit is subtracted before exponentiating so expf() cannot
// overflow, and the sum of exponentials is accumulated in double.
float log_softmax(int n_vocab, const float * logits, int tok) {
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
    }
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }
    return logits[tok] - max_logit - log(sum_exp);
}

// Accumulates into `nll` the negative log-likelihood of predicting
// tokens[i+1] from the logits at position i, for i in [0, n_token),
// and into `nll2` the sum of the squared per-token values (used by the
// caller to estimate the standard deviation of log(prob)).
//
// Work is shared between the threads in `workers` plus the calling
// thread: a mutex-guarded counter hands out token indices one at a
// time. Each thread accumulates into thread-local sums and folds them
// into nll/nll2 exactly once, while still holding the lock, when it
// sees the counter run past n_token.
void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
                    double & nll, double & nll2) {

    std::mutex mutex;
    int counter = 0;
    auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
        double local_nll = 0, local_nll2 = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int i = counter++;
            if (i >= n_token) {
                // still holding the lock: publishing the local sums here is race-free
                nll += local_nll; nll2 += local_nll2;
                break;
            }
            lock.unlock();
            // cast to size_t: i*n_vocab as int*int can overflow for large
            // contexts and vocabularies before the pointer offset is applied
            const double v = -log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
            local_nll  += v;
            local_nll2 += v*v;
        }
    };
    for (auto & w : workers) {
        w = std::thread(compute);
    }
    compute(); // the calling thread participates as well
    for (auto & w : workers) {
        w.join();
    }
}
65+
3066void perplexity_v2 (llama_context * ctx, const gpt_params & params) {
3167 // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
3268 // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
@@ -166,9 +202,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
166202
167203 int count = 0 ;
168204 double nll = 0.0 ;
205+ double nll2 = 0.0 ;
169206
170207 fprintf (stderr, " %s: calculating perplexity over %d chunks, batch_size=%d\n " , __func__, n_chunk, n_batch);
171208
209+ std::vector<std::thread> workers (std::thread::hardware_concurrency () - 1 );
210+
172211 for (int i = 0 ; i < n_chunk; ++i) {
173212 const int start = i * params.n_ctx ;
174213 const int end = start + params.n_ctx ;
@@ -228,26 +267,32 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
228267 // Example, we have a context window of 512, we will compute perplexity for each of the
229268 // last 256 tokens. Then, we split the input up into context window size chunks to
230269 // process the entire prompt.
231- for (int j = std::min (512 , params.n_ctx / 2 ); j < params.n_ctx - 1 ; ++j) {
232- // Calculate probability of next token, given the previous ones.
233- const std::vector<float > tok_logits (
234- logits.begin () + (j + 0 ) * n_vocab,
235- logits.begin () + (j + 1 ) * n_vocab);
236-
237- const float prob = softmax (tok_logits)[tokens[start + j + 1 ]];
270+ const int first = std::min (512 , params.n_ctx /2 );
271+ process_logits (n_vocab, logits.data () + first*n_vocab, tokens.data () + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
272+ count += params.n_ctx - first - 1 ;
238273
239- nll += -std::log (prob);
240- ++count;
241- }
242274 // perplexity is e^(average negative log-likelihood)
243275 if (params.ppl_output_type == 0 ) {
244276 printf (" [%d]%.4lf," , i + 1 , std::exp (nll / count));
245277 } else {
246- printf (" %8d %.4lf\n " , i*params.n_ctx , std::exp (nll / count));
278+ double av = nll/count;
279+ double av2 = nll2/count - av*av;
280+ if (av2 > 0 ) av2 = sqrt (av2/(count-1 ));
281+ printf (" %8d %.4lf %4lf %4lf\n " , i*params.n_ctx , std::exp (nll / count), av, av2);
247282 }
248283 fflush (stdout);
249284 }
250285 printf (" \n " );
286+ nll2 /= count;
287+ nll /= count;
288+ nll2 -= nll * nll;
289+ if (nll2 > 0 ) {
290+ nll2 = sqrt (nll2/(count-1 ));
291+ double ppl = exp (nll);
292+ printf (" Final estimate: PPL = %.4lf +/- %.5lf\n " , ppl, nll2*ppl);
293+ } else {
294+ printf (" Unexpected negative standard deviation of log(prob)\n " );
295+ }
251296}
252297
253298std::vector<float > hellaswag_evaluate_tokens (llama_context * ctx, const std::vector<int >& tokens, int n_past, int n_batch,
0 commit comments