@@ -30,6 +30,8 @@ bool llama_batch_allocr::init(
3030
3131 batch = batch_inp;
3232
33+ this ->vocab = &vocab;
34+
3335 GGML_ASSERT (batch.n_tokens > 0 );
3436
3537 //
@@ -172,67 +174,39 @@ bool llama_batch_allocr::init(
172174
173175 if (debug > 0 ) {
174176 LLAMA_LOG_DEBUG (" %s: input batch info:\n " , __func__);
175- LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, batch.n_tokens );
176- LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) batch.token );
177- LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) batch.embd );
178- LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) batch.pos );
179- LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) batch.n_seq_id );
180- LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) batch.seq_id );
181- LLAMA_LOG_DEBUG (" %s: logits = %p\n " , __func__, (void *) batch.logits );
182- LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
183177
184- if (debug > 1 ) {
185- int seq_id_max = 0 ;
186- for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
187- for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
188- for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
189- seq_id_max = std::max (seq_id_max, batch.seq_id [i][s]);
190- }
191- }
178+ llama_ubatch ubatch {
179+ /* .equal_seqs =*/ false ,
180+ /* .n_tokens =*/ (uint32_t ) batch.n_tokens ,
181+ /* .n_seq_tokens =*/ (uint32_t ) 1 ,
182+ /* .n_seqs =*/ (uint32_t ) batch.n_tokens ,
183+ /* .token =*/ batch.token ,
184+ /* .embd =*/ batch.embd ,
185+ /* .pos =*/ batch.pos ,
186+ /* .n_seq_id =*/ batch.n_seq_id ,
187+ /* .seq_id =*/ batch.seq_id ,
188+ /* .output =*/ batch.logits ,
189+ };
190+
191+ ubatch_print (ubatch, debug);
192+
193+ LLAMA_LOG_DEBUG (" %s: seq = [\n " , __func__);
194+ for (int s0 = 0 ; s0 < (int ) seq_pos.size (); ++s0) {
195+ if (seq_pos[s0].empty ()) {
196+ continue ;
192197 }
193- ++seq_id_max;
194198
195- LLAMA_LOG_DEBUG (" %s: token = [\n " , __func__);
196- for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
197- std::vector<int8_t > seq_id (seq_id_max);
198-
199- for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
200- seq_id[batch.seq_id [i][s]] = 1 ;
201- }
202-
203- std::stringstream ss;
204- for (int s = 0 ; s < seq_id_max; ++s) {
205- if (seq_id[s]) {
206- ss << s%10 ;
207- } else {
208- ss << " ." ;
209- }
199+ std::stringstream ss;
200+ for (int s1 = 0 ; s1 < (int ) seq_cpl[s0].size (); ++s1) {
201+ if (seq_cpl[s0][s1]) {
202+ ss << s1 << " " ;
210203 }
211-
212- LLAMA_LOG_DEBUG (" %s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n " ,
213- __func__, i, batch.token [i], vocab.token_to_piece (batch.token [i]).c_str (),
214- batch.pos [i], batch.n_seq_id [i], ss.str ().c_str (), batch.logits [i]);
215204 }
216- LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
217-
218- LLAMA_LOG_DEBUG (" %s: seq = [\n " , __func__);
219- for (int s0 = 0 ; s0 < (int ) seq_pos.size (); ++s0) {
220- if (seq_pos[s0].empty ()) {
221- continue ;
222- }
223205
224- std::stringstream ss;
225- for (int s1 = 0 ; s1 < (int ) seq_cpl[s0].size (); ++s1) {
226- if (seq_cpl[s0][s1]) {
227- ss << s1 << " " ;
228- }
229- }
230-
231- LLAMA_LOG_DEBUG (" %s: %4d: pos = [%4d, %4d], cpl = %s\n " ,
232- __func__, s0, seq_pos_min (s0), seq_pos_max (s0), ss.str ().empty () ? " -" : ss.str ().c_str ());
233- }
234- LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
206+ LLAMA_LOG_DEBUG (" %s: %4d: pos = [%4d, %4d], cpl = %s\n " ,
207+ __func__, s0, seq_pos_min (s0), seq_pos_max (s0), ss.str ().empty () ? " -" : ss.str ().c_str ());
235208 }
209+ LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
236210 }
237211
238212 //
@@ -296,7 +270,7 @@ bool llama_batch_allocr::init(
296270 return true ;
297271}
298272
299- llama_ubatch llama_batch_allocr::reserve_one (uint32_t n_tokens) {
273+ llama_ubatch llama_batch_allocr::ubatch_reserve (uint32_t n_tokens) {
300274 clear ();
301275 split_reset ();
302276
@@ -389,7 +363,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
389363 }
390364 }
391365
392- return add_ubatch (idxs, idxs.size (), false );
366+ return ubatch_add (idxs, idxs.size (), false );
393367}
394368
395369llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch) {
@@ -470,7 +444,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
470444 idxs.insert (idxs.end (), idxs_per_seq[s].begin (), idxs_per_seq[s].end ());
471445 }
472446
473- return add_ubatch (idxs, n_seqs, true );
447+ return ubatch_add (idxs, n_seqs, true );
474448}
475449
476450llama_ubatch llama_batch_allocr::split_seq (uint32_t n_ubatch) {
@@ -507,7 +481,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
507481 cur_seq_set = seq_set[cur_idx];
508482 }
509483
510- return add_ubatch (idxs, 1 , true );
484+ return ubatch_add (idxs, 1 , true );
511485}
512486
513487void llama_batch_allocr::clear () {
@@ -533,11 +507,9 @@ void llama_batch_allocr::clear() {
533507 seq_set_map.clear ();
534508}
535509
536- llama_ubatch llama_batch_allocr::add_ubatch (const std::vector<int32_t > & idxs, uint32_t n_seqs, bool equal_seqs) {
510+ llama_ubatch llama_batch_allocr::ubatch_add (const std::vector<int32_t > & idxs, uint32_t n_seqs, bool equal_seqs) {
537511 const uint32_t n_tokens = idxs.size ();
538512
539- LLAMA_LOG_DEBUG (" add_ubatch: n_tokens = %d, n_seqs = %d, equal_seqs = %d" , n_tokens, n_seqs, equal_seqs);
540-
541513 assert (n_tokens%n_seqs == 0 );
542514
543515 ubatches.emplace_back ();
@@ -584,11 +556,67 @@ llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, u
584556 /* .output =*/ ubatch.output .data (),
585557 };
586558
587- LLAMA_LOG_DEBUG (" %s: added ubatch of size %d\n " , __func__, res.n_tokens );
559+ LLAMA_LOG_DEBUG (" %s: added ubatch %d in split\n " , __func__, (int ) ubatches.size () - 1 );
560+
561+ if (debug > 0 ) {
562+ ubatch_print (res, debug);
563+ }
588564
589565 return res;
590566}
591567
568+ void llama_batch_allocr::ubatch_print (const llama_ubatch & ubatch, int debug) {
569+ if (debug > 0 ) {
570+ LLAMA_LOG_DEBUG (" %s: equal_seqs = %d\n " , __func__, ubatch.equal_seqs );
571+ LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, ubatch.n_tokens );
572+ LLAMA_LOG_DEBUG (" %s: n_seq_tokens = %d\n " , __func__, ubatch.n_seq_tokens );
573+ LLAMA_LOG_DEBUG (" %s: n_seqs = %d\n " , __func__, ubatch.n_seqs );
574+
575+ LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) ubatch.token );
576+ LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) ubatch.embd );
577+ LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) ubatch.pos );
578+ LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) ubatch.n_seq_id );
579+ LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) ubatch.seq_id );
580+ LLAMA_LOG_DEBUG (" %s: output = %p\n " , __func__, (void *) ubatch.output );
581+ LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
582+
583+ if (debug > 1 ) {
584+ int seq_id_max = 0 ;
585+ for (uint32_t i = 0 ; i < ubatch.n_tokens ; ++i) {
586+ for (int s = 0 ; s < ubatch.n_seq_id [i]; ++s) {
587+ for (int s = 0 ; s < ubatch.n_seq_id [i]; ++s) {
588+ seq_id_max = std::max (seq_id_max, ubatch.seq_id [i][s]);
589+ }
590+ }
591+ }
592+ ++seq_id_max;
593+
594+ LLAMA_LOG_DEBUG (" %s: token = [\n " , __func__);
595+ for (uint32_t i = 0 ; i < ubatch.n_tokens ; ++i) {
596+ std::vector<int8_t > seq_id (seq_id_max);
597+
598+ for (int s = 0 ; s < ubatch.n_seq_id [i]; ++s) {
599+ seq_id[ubatch.seq_id [i][s]] = 1 ;
600+ }
601+
602+ std::stringstream ss;
603+ for (int s = 0 ; s < seq_id_max; ++s) {
604+ if (seq_id[s]) {
605+ ss << s%10 ;
606+ } else {
607+ ss << " ." ;
608+ }
609+ }
610+
611+ LLAMA_LOG_DEBUG (" %s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n " ,
612+ __func__, i, ubatch.token [i], vocab->token_to_piece (ubatch.token [i]).c_str (),
613+ ubatch.pos [i], ubatch.n_seq_id [i], ss.str ().c_str (), ubatch.output [i]);
614+ }
615+ LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
616+ }
617+ }
618+ }
619+
592620//
593621// interface implementation
594622//
0 commit comments