@@ -119,27 +119,27 @@ bool llama_kv_cache_init(
119119
120120struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
121121 struct llama_kv_cache & cache,
122- const struct llama_ubatch & batch ) {
123- const uint32_t n_tokens = batch .n_tokens;
124- const uint32_t n_seqs = batch .n_seqs;
125- const uint32_t n_seq_tokens = batch .n_seq_tokens;
122+ const struct llama_ubatch & ubatch ) {
123+ const uint32_t n_tokens = ubatch .n_tokens;
124+ const uint32_t n_seqs = ubatch .n_seqs;
125+ const uint32_t n_seq_tokens = ubatch .n_seq_tokens;
126126
127127 if (cache.recurrent) {
128128 // For recurrent state architectures (like Mamba or RWKV),
129129 // each cache cell can store the state for a whole sequence.
130130 // A slot should always be contiguous.
131131
132132 // can only process batches with an equal number of new tokens in each sequence
133- GGML_ASSERT(batch .equal_seqs);
133+ GGML_ASSERT(ubatch .equal_seqs);
134134
135135 int32_t min = cache.size - 1;
136136 int32_t max = 0;
137137
138138 // everything should fit if all seq_ids are smaller than the max
139139 for (uint32_t s = 0; s < n_seqs; ++s) {
140- const uint32_t n_seq_id = batch .n_seq_id[s];
140+ const uint32_t n_seq_id = ubatch .n_seq_id[s];
141141 for (uint32_t j = 0; j < n_seq_id; ++j) {
142- const llama_seq_id seq_id = batch .seq_id[s][j];
142+ const llama_seq_id seq_id = ubatch .seq_id[s][j];
143143
144144 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
145145 // too big seq_id
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
198198
199199 // find usable cell range
200200 for (uint32_t s = 0; s < n_seqs; ++s) {
201- const llama_seq_id seq_id = batch .seq_id[s][0];
201+ const llama_seq_id seq_id = ubatch .seq_id[s][0];
202202 llama_kv_cell & seq_meta = cache.cells[seq_id];
203203 bool has_cell = false;
204204 if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
237237 // gather and re-order
238238 for (uint32_t s = 0; s < n_seqs; ++s) {
239239 int32_t dst_id = s + min;
240- int32_t src_id = cache.cells[batch .seq_id[s][0]].tail;
240+ int32_t src_id = cache.cells[ubatch .seq_id[s][0]].tail;
241241 if (dst_id != src_id) {
242242 llama_kv_cell & dst_cell = cache.cells[dst_id];
243243 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,20 +258,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
258258
259259 // update the pos of the used seqs
260260 for (uint32_t s = 0; s < n_seqs; ++s) {
261- const llama_pos last_pos = batch .pos[n_seq_tokens * s + n_seq_tokens - 1];
261+ const llama_pos last_pos = ubatch .pos[n_seq_tokens * s + n_seq_tokens - 1];
262262 int32_t cell_id = s + min;
263263 llama_kv_cell & cell = cache.cells[cell_id];
264264
265265 if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
266266 // What should happen when the pos backtracks or skips a value?
267267 // Clearing the state mid-batch would require special-casing which isn't done.
268268 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
269- __func__, last_pos, cell.pos, batch .seq_id[s][0], n_seq_tokens);
269+ __func__, last_pos, cell.pos, ubatch .seq_id[s][0], n_seq_tokens);
270270 }
271271 cell.pos = last_pos;
272272 cell.seq_id.clear();
273- for (int32_t j = 0; j < batch .n_seq_id[s]; ++j) {
274- const llama_seq_id seq_id = batch .seq_id[s][j];
273+ for (int32_t j = 0; j < ubatch .n_seq_id[s]; ++j) {
274+ const llama_seq_id seq_id = ubatch .seq_id[s][j];
275275 cell.seq_id.insert(seq_id);
276276 cache.cells[seq_id].tail = cell_id;
277277 }
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
325325 for (uint32_t s = 0; s < n_seqs; s++) {
326326 for (uint32_t i = 0; i < n_seq_tokens; ++i) {
327327 uint32_t k = s*n_seq_tokens + i;
328- cache.cells[cache.head + k].pos = batch .pos[k];
328+ cache.cells[cache.head + k].pos = ubatch .pos[k];
329329
330- for (int32_t j = 0; j < batch .n_seq_id[s]; j++) {
331- cache.cells[cache.head + k].seq_id.insert(batch .seq_id[s][j]);
330+ for (int32_t j = 0; j < ubatch .n_seq_id[s]; j++) {
331+ cache.cells[cache.head + k].seq_id.insert(ubatch .seq_id[s][j]);
332332 }
333333 }
334334 }
0 commit comments