From bf7f92428378f87d7312f4e0ad09ea01db6ff81d Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 28 Oct 2025 19:25:02 +0100
Subject: [PATCH 1/9] llama: store mrope data in KV cell

---
 src/llama-batch.cpp    |  3 +++
 src/llama-batch.h      |  4 ++++
 src/llama-kv-cache.cpp | 24 +++++++++++++++++++++++
 src/llama-kv-cells.h   | 43 ++++++++++++++++++++++++++++++++++++++----
 4 files changed, 70 insertions(+), 4 deletions(-)

diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 55d89eca0ad94..d8a0e8a474ba2 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -660,6 +660,9 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
+    // printf("ubatch_add: n_tokens=%d, n_seqs=%d, n_pos_cur=%d, n_embd_all=%lld, n_pos_all=%lld\n",
+    //     n_tokens, n_seqs, n_pos_cur, n_embd_all, n_pos_all);
+
     udata->token   .resize(n_tokens);
     udata->embd    .resize(n_embd_all);
     udata->pos     .resize(n_pos_all);
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 0dc8cebd2a7b3..34f964ef0f963 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -17,6 +17,10 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
+    bool has_mrope() const {
+        return data->pos.size() == data->token.size()*4;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            // otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index add74391f0c47..f80ade6970b58 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -900,6 +900,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.has_mrope()) {
+            cells.pos_mrope_set(idx, {
+                ubatch.pos[i + ubatch.n_tokens],   // x
+                ubatch.pos[i + ubatch.n_tokens*2], // y
+                ubatch.pos[i + ubatch.n_tokens*3], // t
+            });
+        }
+
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
@@ -1243,6 +1251,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
                 const llama_pos p1 = ubatch->pos[i];
 
+                // for M-RoPE
+                llama_kv_pos_mrope p1_mrope;
+                if (ubatch->has_mrope()) {
+                    p1_mrope.x = ubatch->pos[i + ubatch->n_tokens];
+                    p1_mrope.y = ubatch->pos[i + ubatch->n_tokens*2];
+                    p1_mrope.t = ubatch->pos[i + ubatch->n_tokens*3];
+                }
+
                 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
                 for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1262,6 +1278,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                         continue;
                     }
 
+                    // M-RoPE causal mask
+                    if (causal_attn && ubatch->has_mrope() && p0 == p1) {
+                        const auto & p0_mrope = cells.pos_mrope_get(j);
+                        if (p0_mrope.is_gt(p1_mrope)) {
+                            continue;
+                        }
+                    }
+
                     // apply SWA if any
                     if (is_masked_swa(p0, p1)) {
                         continue;
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index 8f6bf01456c8f..9c76ca7452cae 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -9,6 +9,18 @@
 #include <map>
 #include <set>
 
+struct llama_kv_pos_mrope {
+    llama_pos x;
+    llama_pos y;
+    llama_pos t;
+    // return true if this position is greater than the other position
+    bool is_gt(const llama_kv_pos_mrope & other) const {
+        return (t > other.t)
+            || (t == other.t && y > other.y)
+            || (t == other.t && y == other.y && x > other.x);
+    }
+};
+
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
 class llama_kv_cells {
@@ -43,6 +55,7 @@ class llama_kv_cells {
 
     void resize(uint32_t n) {
         pos.resize(n);
+        pos_mrope.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -107,8 +120,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < n; ++j) {
             const auto idx = i + j;
 
-            res.pos[j] = pos[idx];
-            res.seq[j] = seq[idx];
+            res.pos      [j] = pos[idx];
+            res.pos_mrope[j] = pos_mrope[idx];
+            res.seq      [j] = seq[idx];
 
             assert(shift[idx] == 0);
         }
@@ -125,8 +139,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < idxs.size(); ++j) {
             const auto idx = idxs[j];
 
-            res.pos[j] = pos[idx];
-            res.seq[j] = seq[idx];
+            res.pos      [j] = pos[idx];
+            res.pos_mrope[j] = pos_mrope[idx];
+            res.seq      [j] = seq[idx];
 
             assert(shift[idx] == 0);
         }
@@ -340,6 +355,13 @@ class llama_kv_cells {
         return pos[i];
     }
 
+    const llama_kv_pos_mrope & pos_mrope_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return pos_mrope[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -368,6 +390,16 @@ class llama_kv_cells {
 
         used.insert(i);
     }
 
+    void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
+        assert(i < pos.size());
+        assert(pos[i] == -1);
+        assert(seq[i].none());
+
+        pos_mrope[i] = p;
+
+        used.insert(i);
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
@@ -424,6 +456,9 @@ class llama_kv_cells {
 
     std::vector<llama_pos> pos;
 
+    // stores additional info for M-RoPE positions
+    std::vector<llama_kv_pos_mrope> pos_mrope;
+
    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
    //
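Note: patch 1 detects M-RoPE batches purely from the stride of the flattened position
buffer: `ubatch.pos` holds four back-to-back sections of `n_tokens` entries each, so
coordinate k of token i lives at `pos[i + n_tokens*k]`. A minimal sketch of that
struct-of-arrays layout (all values hypothetical; the meaning assigned to each section
is exactly what patch 2 below goes on to correct):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_tokens = 3;
        // flattened positions: section 0 = sequential pos, sections 1..3 = extra M-RoPE coordinates
        std::vector<int> pos = {
            0, 1, 2,   // section 0: sequential position in the sequence
            0, 0, 1,   // section 1
            0, 1, 0,   // section 2
            0, 0, 0,   // section 3
        };
        for (int i = 0; i < n_tokens; ++i) {
            printf("token %d: pos=%d extra=(%d, %d, %d)\n", i,
                   pos[i], pos[i + n_tokens], pos[i + n_tokens*2], pos[i + n_tokens*3]);
        }
        return 0;
    }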
From 90353eae92e4c928e7185264673bd4c586033bc4 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 28 Oct 2025 21:42:38 +0100
Subject: [PATCH 2/9] correct x,y ordering

---
 src/llama-batch.cpp    | 50 +++++++++++++++++-------------------------
 src/llama-kv-cache.cpp | 10 ++++-----
 src/llama-kv-cells.h   | 16 ++++----------
 tools/mtmd/mtmd.cpp    | 13 ++++++++++-
 tools/mtmd/mtmd.h      |  4 ++--
 5 files changed, 42 insertions(+), 51 deletions(-)

diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index d8a0e8a474ba2..5f4e71c85b801 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -251,46 +251,39 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
-        }
+    // TODO @ngxson : we currently can't check M-RoPE positions, as the position is increased based on image size
+    if (n_pos_per_embd == 1) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
 
-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
 
-        if (p0 >= 0) {
-            bool ok = true;
+            if (p0 >= 0) {
+                bool ok = true;
 
-            if (batch.token) {
                 if (seq_pos_min(s) != p0 + 1) {
                     ok = false;
                 }
-            } else {
-                assert(batch.embd);
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
                 }
             }
 
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                    "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                    " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                    " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                    " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                    __func__, s, s, p0, s, seq_pos_min(s));
-
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                 return false;
             }
         }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
     }
 
     if (memory) {
@@ -660,9 +653,6 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
-    // printf("ubatch_add: n_tokens=%d, n_seqs=%d, n_pos_cur=%d, n_embd_all=%lld, n_pos_all=%lld\n",
-    //     n_tokens, n_seqs, n_pos_cur, n_embd_all, n_pos_all);
-
     udata->token   .resize(n_tokens);
     udata->embd    .resize(n_embd_all);
     udata->pos     .resize(n_pos_all);
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index f80ade6970b58..e49ec19451496 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -902,9 +902,8 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         if (ubatch.has_mrope()) {
             cells.pos_mrope_set(idx, {
-                ubatch.pos[i + ubatch.n_tokens],   // x
-                ubatch.pos[i + ubatch.n_tokens*2], // y
-                ubatch.pos[i + ubatch.n_tokens*3], // t
+                ubatch.pos[i + ubatch.n_tokens],   // y
+                ubatch.pos[i + ubatch.n_tokens*2], // x
             });
         }
 
@@ -1254,9 +1253,8 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
         // for M-RoPE
         llama_kv_pos_mrope p1_mrope;
         if (ubatch->has_mrope()) {
-            p1_mrope.x = ubatch->pos[i + ubatch->n_tokens];
-            p1_mrope.y = ubatch->pos[i + ubatch->n_tokens*2];
-            p1_mrope.t = ubatch->pos[i + ubatch->n_tokens*3];
+            p1_mrope.y = ubatch->pos[i + ubatch->n_tokens];
+            p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2];
         }
 
         const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index 9c76ca7452cae..bcb08587deb72 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -10,14 +10,11 @@
 #include <set>
 
 struct llama_kv_pos_mrope {
-    llama_pos x;
-    llama_pos y;
-    llama_pos t;
+    llama_pos y = 0;
+    llama_pos x = 0;
     // return true if this position is greater than the other position
     bool is_gt(const llama_kv_pos_mrope & other) const {
-        return (t > other.t)
-            || (t == other.t && y > other.y)
-            || (t == other.t && y == other.y && x > other.x);
+        return (y > other.y) || (y == other.y && x > other.x);
     }
 };
 
@@ -391,13 +388,8 @@ class llama_kv_cells {
     }
 
     void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
-        assert(i < pos.size());
-        assert(pos[i] == -1);
-        assert(seq[i].none());
-
+        assert(i < pos_mrope.size());
         pos_mrope[i] = p;
-
-        used.insert(i);
     }
 
     // pos[i] = pos[i] + d
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 3b901bfac8215..1a48f903efac9 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -5,6 +5,15 @@
 
 #include "llama.h"
 
+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#    define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
 #include <algorithm>
 #include <cerrno>
 #include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 
 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        // for M-RoPE, temporal dimension = max(t,h,w)
+        // t is omitted as we don't support video input
+        return std::max(image_tokens->nx, image_tokens->ny);
     }
     return image_tokens->n_tokens();
 }
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index f4ea07d3ad521..0b5d2ba0c7634 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
 MTMD_API size_t       mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
 // returns nullptr for ID on text chunk
 MTMD_API const char * mtmd_input_chunk_get_id       (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equal to max(t,h,w) for M-RoPE; equal to n_tokens otherwise)
 MTMD_API llama_pos    mtmd_input_chunk_get_n_pos    (const mtmd_input_chunk * chunk);
 
 // in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
 MTMD_API size_t       mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
 MTMD_API size_t       mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
 MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equal to max(t,h,w) for M-RoPE; equal to n_tokens otherwise)
 MTMD_API llama_pos    mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
 
 // tokenize an input text prompt and a list of bitmaps (images/audio)
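Note: after this patch, every token of an image shares the same run of temporal
positions, so causality between two image tokens is decided by raster order on
(y, x); this is what `llama_kv_pos_mrope::is_gt()` encodes. A self-contained sketch
of the same comparison (names and values hypothetical):

    #include <cstdio>

    struct pos2d { int y = 0; int x = 0; };

    // lexicographic (y, x) ordering, as in is_gt() above
    static bool is_gt(const pos2d & a, const pos2d & b) {
        return (a.y > b.y) || (a.y == b.y && a.x > b.x);
    }

    int main() {
        const pos2d key   = {1, 0}; // cached image token
        const pos2d query = {0, 3}; // current image token at the same temporal position
        // a causal mask would drop the key, since it comes after the query in raster order
        printf("masked: %s\n", is_gt(key, query) ? "yes" : "no"); // prints "masked: yes"
        return 0;
    }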
From c3e1393f63d92937614d13dc024e0608a45cdcc1 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 29 Oct 2025 11:16:23 +0100
Subject: [PATCH 3/9] address review comments

---
 src/llama-batch.cpp    |  3 +++
 src/llama-batch.h      | 13 ++++++++++---
 src/llama-kv-cache.cpp | 23 ++++++++++-------------
 src/llama-kv-cells.h   | 40 +++++++++++++++++++++-------------------
 4 files changed, 44 insertions(+), 35 deletions(-)

diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 5f4e71c85b801..59fc7890a75bf 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
         /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ batch.token,
         /*.embd         =*/ batch.embd,
         /*.pos          =*/ batch.pos,
@@ -382,6 +383,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
@@ -703,6 +705,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ batch.token ? udata->token.data() : nullptr,
         /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 34f964ef0f963..209cf3699de23 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -17,8 +17,14 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
-    bool has_mrope() const {
-        return data->pos.size() == data->token.size()*4;
+    // typical for M-RoPE cases:
+    //   0 - sequential position of the tokens/embeddings in the sequence
+    //   1 - y position in the image
+    //   2 - x position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
     }
 
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            // otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
@@ -29,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
     //
     //                            // size               | idx | val
     llama_token  *  token;        // [n_tokens]          | i   | id, token
     float        *  embd;         // [n_embd, n_tokens]  | i   | embd
-    llama_pos    *  pos;          // [n_tokens]          | i   | pos
+    llama_pos    *  pos;          // [n_tokens*n_pos]    | i   | pos
     int32_t      *  n_seq_id;     // [n_tokens]          | i   | -
     llama_seq_id ** seq_id;       // [n_tokens]          | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq;   // [n_seqs_unq]        | s   | seq_id
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index e49ec19451496..1a36f7a7b8508 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -900,11 +900,11 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
-        if (ubatch.has_mrope()) {
-            cells.pos_mrope_set(idx, {
-                ubatch.pos[i + ubatch.n_tokens],   // y
-                ubatch.pos[i + ubatch.n_tokens*2], // x
-            });
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext;
+            ext.x = ubatch.pos[i + ubatch.n_tokens*2];
+            ext.y = ubatch.pos[i + ubatch.n_tokens];
+            cells.ext_set(idx, std::move(ext));
         }
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
@@ -1251,11 +1251,8 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
         const llama_pos p1 = ubatch->pos[i];
 
         // for M-RoPE
-        llama_kv_pos_mrope p1_mrope;
-        if (ubatch->has_mrope()) {
-            p1_mrope.y = ubatch->pos[i + ubatch->n_tokens];
-            p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2];
-        }
+        llama_pos p1_x = ubatch->pos[i + ubatch->n_tokens*2];
+        llama_pos p1_y = ubatch->pos[i + ubatch->n_tokens];
 
         const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
@@ -1277,9 +1274,9 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
         }
 
         // M-RoPE causal mask
-        if (causal_attn && ubatch->has_mrope() && p0 == p1) {
-            const auto & p0_mrope = cells.pos_mrope_get(j);
-            if (p0_mrope.is_gt(p1_mrope)) {
+        if (causal_attn && ubatch->is_pos_2d() && p0 == p1) {
+            const auto & p0_ext = cells.ext_get(j);
+            if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                 continue;
             }
         }
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index bcb08587deb72..3e5437c408960 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -9,12 +9,14 @@
 #include <map>
 #include <set>
 
-struct llama_kv_pos_mrope {
-    llama_pos y = 0;
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
     llama_pos x = 0;
-    // return true if this position is greater than the other position
-    bool is_gt(const llama_kv_pos_mrope & other) const {
-        return (y > other.y) || (y == other.y && x > other.x);
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
     }
 };
 
@@ -52,7 +54,7 @@ class llama_kv_cells {
 
     void resize(uint32_t n) {
         pos.resize(n);
-        pos_mrope.resize(n);
+        ext.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -117,9 +119,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < n; ++j) {
             const auto idx = i + j;
 
-            res.pos      [j] = pos[idx];
-            res.pos_mrope[j] = pos_mrope[idx];
-            res.seq      [j] = seq[idx];
+            res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
+            res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
         }
@@ -136,9 +138,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < idxs.size(); ++j) {
             const auto idx = idxs[j];
 
-            res.pos      [j] = pos[idx];
-            res.pos_mrope[j] = pos_mrope[idx];
-            res.seq      [j] = seq[idx];
+            res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
+            res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
         }
@@ -352,11 +354,11 @@ class llama_kv_cells {
         return pos[i];
     }
 
-    const llama_kv_pos_mrope & pos_mrope_get(uint32_t i) const {
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
         assert(i < pos.size());
         assert(pos[i] != -1);
 
-        return pos_mrope[i];
+        return ext[i];
     }
 
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -387,9 +389,9 @@ class llama_kv_cells {
         used.insert(i);
     }
 
-    void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
-        assert(i < pos_mrope.size());
-        pos_mrope[i] = p;
+    void ext_set(uint32_t i, llama_kv_cell_ext && p) {
+        assert(i < ext.size());
+        ext[i] = std::move(p);
     }
 
     // pos[i] = pos[i] + d
@@ -448,8 +450,8 @@ class llama_kv_cells {
 
     std::vector<llama_pos> pos;
 
-    // stores additional info for M-RoPE positions
-    std::vector<llama_kv_pos_mrope> pos_mrope;
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
 
    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
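Note: the review replaced the size-ratio heuristic `has_mrope()` with an explicit
per-ubatch `n_pos` field and `is_pos_2d()`. A hedged sketch of the difference (the
struct and sizes below are illustrative only, not the real llama_ubatch):

    #include <cstdint>

    struct ubatch_sketch {
        uint32_t n_tokens;
        uint32_t n_pos; // number of position rows per token, now stored explicitly

        // old heuristic (for reference): infer from buffer sizes,
        //   bool has_mrope() const { return pos.size() == token.size()*4; }
        // which compares two quantities that can both be degenerate (e.g. no tokens)

        // new check: 2D positions exist when there are at least 3 position rows
        bool is_pos_2d() const { return n_pos >= 3; }
    };

    int main() {
        const ubatch_sketch text  = {8, 1};
        const ubatch_sketch mrope = {8, 4};
        return (!text.is_pos_2d() && mrope.is_pos_2d()) ? 0 : 1;
    }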
From 9102a7cc598a3550bf5f32ffcd1911cd41ba8940 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 29 Oct 2025 11:34:42 +0100
Subject: [PATCH 4/9] add consistency checks

---
 src/llama-batch.cpp | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 59fc7890a75bf..6cb118f684e40 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -252,8 +252,27 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    // TODO @ngxson : we currently can't check M-RoPE positions, as the position is increased based on image size
-    if (n_pos_per_embd == 1) {
+    if (n_pos_per_embd > 1) {
+        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (p0 >= 0 && p0 >= seq_pos_min(s)) {
+                LLAMA_LOG_ERROR(
+                    "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                    " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                    " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                    " for M-RoPE, it is required that the position satisfies: X < Y\n",
+                    __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
+        }
+    } else {
         for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (seq_pos[s].empty()) {
                 continue;
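Note: with mtmd now reporting max(t,h,w) temporal positions per image (patch 2),
positions can legitimately skip ahead, so the M-RoPE branch only enforces that the
batch's minimum position Y lies strictly after the cached maximum X. A hedged
arithmetic sketch of one consistent timeline, using a hypothetical 24x17-patch image:

    #include <algorithm>
    #include <cassert>

    int main() {
        int x = 9;                          // cached seq_pos_max after a 10-token text prefix (0..9)
        const int y_img = x + 1;            // image may start right after the text: X < Y holds
        const int span  = std::max(24, 17); // temporal span claimed by the image: 24
        x = y_img + span - 1;               // cached max after the image: 33
        const int y_text = x + 1;           // following text resumes at 34
        assert(x < y_text);                 // the X < Y check from this patch
        return 0;
    }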
From f7063582c67bbd2be24285902315889f7d73609b Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Wed, 29 Oct 2025 14:20:11 +0100
Subject: [PATCH 5/9] Update src/llama-kv-cache.cpp

Co-authored-by: Georgi Gerganov
---
 src/llama-kv-cache.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index ada3086d6e34c..27d168f5fc73f 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -901,9 +901,10 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
         cells.pos_set(idx, ubatch.pos[i]);
 
         if (ubatch.is_pos_2d()) {
-            llama_kv_cell_ext ext;
-            ext.x = ubatch.pos[i + ubatch.n_tokens*2];
-            ext.y = ubatch.pos[i + ubatch.n_tokens];
+            llama_kv_cell_ext ext {
+                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+            };
             cells.ext_set(idx, std::move(ext));
         }
 

From 18842b60c5d0dcddb07f31923666e331456663f4 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 29 Oct 2025 14:24:02 +0100
Subject: [PATCH 6/9] add TODO

---
 src/llama-kv-cache.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 27d168f5fc73f..c7b024ac13c45 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1579,6 +1579,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
+            // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() supports loading it
+            // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
             for (const auto & seq_id : seq_ids) {
                 io.write(&seq_id, sizeof(seq_id));
             }
@@ -1724,6 +1727,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             return false;
         }
 
+        // TODO: we cannot restore llama_kv_cell_ext yet, as apply_ubatch() does not support it
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
         const auto head_cur = sinfo.head();

From 5ec41a107f913a8afa14c0b65b33856130d39207 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 29 Oct 2025 14:25:22 +0100
Subject: [PATCH 7/9] fix asan error

---
 src/llama-kv-cache.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index c7b024ac13c45..f2511faebef6d 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1256,8 +1256,9 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
         const llama_pos p1 = ubatch->pos[i];
 
         // for M-RoPE
-        llama_pos p1_x = ubatch->pos[i + ubatch->n_tokens*2];
-        llama_pos p1_y = ubatch->pos[i + ubatch->n_tokens];
+        const bool is_2d = ubatch->is_pos_2d();
+        const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+        const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
 
         const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
@@ -1279,7 +1280,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
         }
 
         // M-RoPE causal mask
-        if (causal_attn && ubatch->is_pos_2d() && p0 == p1) {
+        if (causal_attn && is_2d && p0 == p1) {
             const auto & p0_ext = cells.ext_get(j);
             if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                 continue;
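Note: the asan report is consistent with an out-of-bounds read rather than a logic
bug: for text-only batches `pos` holds only `n_tokens` entries, so computing
`ubatch->pos[i + n_tokens*2]` unconditionally read past the buffer even though the
value was never used. The fix evaluates those indices only behind the `is_2d` guard.
A minimal sketch of the guarded pattern (sizes hypothetical):

    #include <vector>

    int main() {
        const int n_tokens = 4;
        const bool is_2d = false;           // text-only batch: a single position row
        std::vector<int> pos(n_tokens, 0);  // size n_tokens, not n_tokens*4

        int acc = 0;
        for (int i = 0; i < n_tokens; ++i) {
            // the ternary keeps the out-of-range index from ever being evaluated
            const int p1_x = is_2d ? pos[i + n_tokens*2] : 0;
            const int p1_y = is_2d ? pos[i + n_tokens]   : 0;
            acc += p1_x + p1_y;
        }
        return acc;
    }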
From bed0f57fa23d5d88ac2dfd74ad9805e85b546273 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 29 Oct 2025 15:52:49 +0200
Subject: [PATCH 8/9] kv-cells : improve ext handling

---
 src/llama-kv-cache.cpp |  8 +++++++-
 src/llama-kv-cells.h   | 16 ++++++++++++++--
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index f2511faebef6d..17627b6ccbb1e 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
                 llama_pos pos   = v_cells[s0].pos_get(i);
                 llama_pos shift = v_cells[s0].get_shift(i);
 
+                llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
                 if (shift != 0) {
                     pos -= shift;
                     assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
                 if (shift != 0) {
                     v_cells[s1].pos_add(i, shift);
                 }
+
+                v_cells[s1].ext_set(i, ext);
             }
         }
 
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
     auto & head  = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
 
@@ -905,7 +911,7 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
                 /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
                 /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
             };
-            cells.ext_set(idx, std::move(ext));
+            cells.ext_set(idx, ext);
         }
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index 3e5437c408960..94374ed88d4c3 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -18,6 +18,12 @@ struct llama_kv_cell_ext {
     bool is_2d_gt(llama_pos ox, llama_pos oy) const {
         return (y > oy) || (y == oy && x > ox);
     }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
 };
 
 // meta information about KV cells that can be part of multiple sequences at the same time
@@ -27,6 +33,7 @@ class llama_kv_cells {
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
             pos[i]   = -1;
+            ext[i].reset();
             shift[i] = 0;
             seq[i].reset();
         }
@@ -168,6 +175,7 @@ class llama_kv_cells {
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -198,6 +206,7 @@ class llama_kv_cells {
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -217,6 +226,7 @@ class llama_kv_cells {
         seq[i].reset();
 
         pos[i]   = -1;
+        ext[i].reset();
         shift[i] = 0;
 
         used.erase(i);
@@ -235,6 +245,7 @@ class llama_kv_cells {
 
         if (seq[i].none()) {
             pos[i]   = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -264,6 +275,7 @@ class llama_kv_cells {
             seq[i].reset();
 
             pos[i]   = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -389,9 +401,9 @@ class llama_kv_cells {
         used.insert(i);
     }
 
-    void ext_set(uint32_t i, llama_kv_cell_ext && p) {
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
         assert(i < ext.size());
-        ext[i] = std::move(p);
+        ext[i] = p;
     }
 
     // pos[i] = pos[i] + d
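Note: `llama_kv_cell_ext::reset()` clears the struct with `memset`, which is only
well-defined because the type is trivially copyable; the `static_assert` turns a
future violation (say, someone adding a `std::string` member) into a compile error
instead of silent corruption. A standalone illustration of the idiom (the struct
below is a stand-in, not the real type):

    #include <cstring>
    #include <type_traits>

    struct cell_ext_sketch {
        int x = 0;
        int y = 0;

        void reset() {
            // memset is only valid on trivially copyable types; enforce that at compile time
            static_assert(std::is_trivially_copyable_v<cell_ext_sketch>);
            memset(this, 0, sizeof(*this));
        }
    };

    int main() {
        cell_ext_sketch e{3, 7};
        e.reset();
        return e.x + e.y; // 0
    }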
From 45d60e17aeb6145ffcf1f6b92102afed2b79e59e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 29 Oct 2025 16:13:05 +0200
Subject: [PATCH 9/9] cont : fix headers

---
 src/llama-kv-cells.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index 94374ed88d4c3..10063bf4272ef 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -5,9 +5,10 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <map>
+#include <cstring>
 #include <set>
+#include <type_traits>
+#include <vector>
 
 struct llama_kv_cell_ext {
     // 2D spatial positions, typically used for M-RoPE