@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
166
166
167
167
// note: tracking the other way around is not necessary for now
168
168
// seq_cpl[s0][s1] = true;
169
+
170
+ has_cpl = true ;
169
171
}
170
172
}
171
173
}
@@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
403
405
return n_outputs;
404
406
}
405
407
408
+ uint32_t llama_batch_allocr::get_n_used () const {
409
+ return n_used;
410
+ }
411
+
406
412
std::vector<int32_t > & llama_batch_allocr::get_out_ids () {
407
413
return out_ids;
408
414
}
@@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
418
424
void llama_batch_allocr::split_reset () {
419
425
out_ids.clear ();
420
426
427
+ n_used = 0 ;
428
+
421
429
used.clear ();
422
430
used.resize (get_n_tokens (), false );
423
431
@@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
442
450
idxs.push_back (cur_idx);
443
451
444
452
used[cur_idx] = true ;
453
+ ++n_used;
445
454
446
455
++cur_idx;
447
456
@@ -458,6 +467,12 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
458
467
}
459
468
460
469
llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch, bool sequential) {
470
+ if (sequential && has_cpl) {
471
+ LLAMA_LOG_ERROR (" %s: sequential split is not supported when there are coupled sequences in the input batch\n " , __func__);
472
+
473
+ return {};
474
+ }
475
+
461
476
std::vector<seq_set_t > cur_seq_set;
462
477
463
478
llama_seq_id last_seq_id = -1 ;
@@ -536,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential)
536
551
idxs_per_seq[s].push_back (idx);
537
552
538
553
used[idx] = true ;
554
+ ++n_used;
539
555
540
556
++cur_idx[s];
541
557
}
@@ -577,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
577
593
idxs.push_back (cur_idx);
578
594
579
595
used[cur_idx] = true ;
596
+ ++n_used;
580
597
581
598
if (idxs.size () >= n_ubatch) {
582
599
break ;
0 commit comments