@@ -121,8 +121,23 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
121121 printf (" \n " );
122122}
123123
124- void perplexity_lines (llama_context * ctx, const gpt_params & params) {
125- // Calculates perplexity over each line of the prompt
124+ void hellaswag_score (llama_context * ctx, const gpt_params & params) {
125+ // Calculates hellaswag score (acc_norm) from prompt
126+ //
127+ // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
128+ // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
129+ //
130+ // All 10042 tasks should be extracted to keep the results standardized like other implementations.
131+ //
132+ // Datafile layout:
133+ // ['??'] denotes json fields
134+ // 6 lines per task:
135+ // ['activity_label'] + ": " +['ctx'] - The first part of the query, the context
136+ // ['label'] - The index the best common sense ending aka gold ending
137+ // ['endings'][0] - Endings added to the first part of the query
138+ // ['endings'][1]
139+ // ['endings'][2]
140+ // ['endings'][3]
126141
127142 std::vector<std::string> prompt_lines;
128143 std::istringstream strstream (params.prompt );
@@ -132,63 +147,149 @@ void perplexity_lines(llama_context * ctx, const gpt_params & params) {
132147 prompt_lines.push_back (line);
133148 }
134149
135- const int n_vocab = llama_n_vocab (ctx);
150+ if ( prompt_lines.size () % 6 != 0 ) {
151+ fprintf (stderr, " %s : number of lines in prompt not a multiple of 6.\n " , __func__);
152+ return ;
153+ }
136154
137- int counttotal = 0 ;
138- size_t n_lines = prompt_lines. size ( );
155+ size_t hs_task_count = prompt_lines. size ()/ 6 ;
156+ fprintf (stderr, " %s : loaded %lu tasks from prompt. \n " , __func__, hs_task_count );
139157
140- double nll = 0.0 ;
158+ // This is needed as usual for LLaMA models
159+ bool prepend_bos = true ;
160+
161+ // Number of tasks to use when computing the score
162+ if ( params.hellaswag_tasks < hs_task_count ) {
163+ hs_task_count = params.hellaswag_tasks ;
164+ }
141165
142- fprintf (stderr, " %s: calculating perplexity over %lu lines\n " , __func__, n_lines);
166+ // The tasks should be randomized so the score stabilizes quickly.
167+ bool randomize_tasks = true ;
143168
144- printf (" \n Line\t PPL line\t PPL cumulative\n " );
169+ // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
170+ std::mt19937 rng (1 );
145171
146- for (size_t i = 0 ; i < n_lines; ++i) {
172+ // Dataholder for hellaswag tasks
173+ struct hs_data_t {
174+ std::string context;
175+ size_t gold_ending_idx;
176+ std::string ending[4 ];
177+ size_t ending_logprob_count[4 ];
178+ double ending_logprob[4 ];
179+ };
147180
148- // Tokenize and insert BOS at start
149- std::vector<int > batch_embd = ::llama_tokenize (ctx, prompt_lines[i], true );
181+ fprintf (stderr, " %s : selecting %lu %s tasks.\n " , __func__, hs_task_count, (randomize_tasks?" randomized" :" the first" ) );
150182
151- size_t batch_size = batch_embd.size ();
183+ // Select and read data from prompt lines
184+ hs_data_t *hs_data = new hs_data_t [hs_task_count];
185+ for (size_t i=0 ; i < hs_task_count; i++) {
186+ size_t idx = i;
152187
153- // Stop if line is too long
154- if ( batch_size > ( size_t )params. n_ctx ) {
155- fprintf (stderr, " %s : tokens in line %lu > n_ctxl \n " , __func__, i) ;
156- return ;
188+ // Select a random example of those left in the prompt
189+ if (randomize_tasks ) {
190+ std::uniform_int_distribution< size_t > dist ( 0 , prompt_lines. size ()/ 6 - 1 ) ;
191+ idx = dist (rng) ;
157192 }
158193
159- if (llama_eval (ctx, batch_embd.data (), batch_size, 0 , params.n_threads )) {
160- fprintf (stderr, " %s : failed to eval\n " , __func__);
161- return ;
194+ hs_data[i].context = prompt_lines[idx*6 ];
195+ hs_data[i].gold_ending_idx = std::stoi ( prompt_lines[idx*6 +1 ] );
196+ for (size_t j=0 ; j < 4 ; j++) {
197+ hs_data[i].ending [j] = " " + prompt_lines[idx*6 +2 +j];
162198 }
163199
164- const auto batch_logits = llama_get_logits (ctx);
165- std::vector<float > logits;
166- logits.insert (logits.end (), batch_logits, batch_logits + batch_size * n_vocab);
200+ // Delete the selected random example from the prompt
201+ if (randomize_tasks) {
202+ prompt_lines.erase ( std::next (prompt_lines.begin (),idx*6 ) , std::next (prompt_lines.begin (),idx*6 +6 ) );
203+ }
204+ }
167205
168- double nllline = 0.0 ;
169- int countline = 0 ;
206+ fprintf (stderr, " %s : calculating hellaswag score over selected tasks. \n " , __func__) ;
207+ printf ( " \n task \t acc_norm \n " ) ;
170208
171- // Perplexity over second half of the line
172- for (size_t j = batch_size/2 ; j < batch_size - 1 ; ++j) {
173- // Calculate probability of next token, given the previous ones.
174- const std::vector<float > tok_logits (
175- logits.begin () + (j + 0 ) * n_vocab,
176- logits.begin () + (j + 1 ) * n_vocab);
209+ double acc = 0 .0f ;
210+ const int n_vocab = llama_n_vocab (ctx);
211+
212+ for (size_t task_idx = 0 ; task_idx < hs_task_count; task_idx++) {
213+
214+ // Tokenize the context to count tokens
215+ std::vector<int > context_embd = ::llama_tokenize (ctx, hs_data[task_idx].context , prepend_bos);
216+ size_t context_size = context_embd.size ();
217+
218+ for (size_t ending_idx=0 ;ending_idx<4 ;ending_idx++) {
219+
220+ // Tokenize the query
221+ std::vector<int > query_embd = ::llama_tokenize (ctx, hs_data[task_idx].context + hs_data[task_idx].ending [ending_idx], prepend_bos);
222+ size_t query_size = query_embd.size ();
223+
224+ // Stop if query wont fit the ctx window
225+ if (query_size > (size_t )params.n_ctx ) {
226+ fprintf (stderr, " %s : number of tokens in query %lu > n_ctxl\n " , __func__, query_size);
227+ return ;
228+ }
177229
178- const float prob = softmax (tok_logits)[batch_embd[ j + 1 ]];
230+ // Speedup small evaluations by evaluating atleast 32 tokens
231+ if (query_size < 32 ) {
232+ query_embd.resize (32 );
233+ }
234+
235+ // Evaluate the query
236+ if (llama_eval (ctx, query_embd.data (), query_embd.size (), 0 , params.n_threads )) {
237+ fprintf (stderr, " %s : failed to eval\n " , __func__);
238+ return ;
239+ }
240+
241+ const auto query_logits = llama_get_logits (ctx);
242+ std::vector<float > logits;
243+ logits.insert (logits.end (), query_logits, query_logits + query_size * n_vocab);
244+
245+ hs_data[task_idx].ending_logprob_count [ending_idx] = 0 ;
246+ hs_data[task_idx].ending_logprob [ending_idx] = 0 .0f ;
247+
248+ // Calculate the logprobs over the ending
249+ for (size_t j = context_size-1 ; j < query_size - 1 ; j++) {
250+ // Calculate probability of next token, given the previous ones.
251+ const std::vector<float > tok_logits (
252+ logits.begin () + (j + 0 ) * n_vocab,
253+ logits.begin () + (j + 1 ) * n_vocab);
254+
255+ const float prob = softmax (tok_logits)[query_embd[ j + 1 ]];
256+
257+ hs_data[task_idx].ending_logprob [ending_idx] += std::log (prob);
258+ hs_data[task_idx].ending_logprob_count [ending_idx]++;
259+ }
260+
261+ // Calculate the mean token logprob for acc_norm
262+ hs_data[task_idx].ending_logprob [ending_idx] /= hs_data[task_idx].ending_logprob_count [ending_idx];
263+
264+
265+ // printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
266+ // task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
267+ }
179268
180- nllline += -std::log (prob);
181- ++countline;
269+ // Find the ending with maximum logprob
270+ size_t ending_logprob_max_idx = -1 ;
271+ double ending_logprob_max_val = -INFINITY;
272+ for (size_t j=0 ; j < 4 ; j++) {
273+ if (hs_data[task_idx].ending_logprob [j] > ending_logprob_max_val) {
274+ ending_logprob_max_idx = j;
275+ ending_logprob_max_val = hs_data[task_idx].ending_logprob [j];
276+ }
182277 }
183278
184- nll += nllline;
185- counttotal += countline;
279+ // printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
186280
187- // perplexity is e^(average negative log-likelihood)
188- printf (" %lu\t %.8lf\t %.8lf\n " , i + 1 , std::exp (nllline/countline), std::exp (nll / counttotal) );
281+ // If the gold ending got the maximum logprobe add one accuracy point
282+ if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx ) {
283+ acc += 1.0 ;
284+ }
285+
286+ // Print the accumulated accuracy mean x 100
287+ printf (" %li\t %.8lf\n " ,task_idx+1 , acc/double (task_idx+1 )*100.0 );
189288 fflush (stdout);
190289 }
191290
291+ delete [] hs_data;
292+
192293 printf (" \n " );
193294}
194295
@@ -240,8 +341,8 @@ int main(int argc, char ** argv) {
240341 params.n_threads , std::thread::hardware_concurrency (), llama_print_system_info ());
241342 }
242343
243- if (params.perplexity_lines ) {
244- perplexity_lines (ctx, params);
344+ if (params.hellaswag ) {
345+ hellaswag_score (ctx, params);
245346 } else {
246347 perplexity (ctx, params);
247348 }
0 commit comments