@@ -55,6 +55,11 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
 
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
+// note 2: Normally, a batch's `pos` stores linearly increasing positions.
+//         However, some multi-modal models require a special position embedding (e.g. M-RoPE in qwen2vl and qwen2.5vl),
+//         while linearly increasing positions are still needed for proper causal attention masking.
+//         So we store both: the first n_tokens elements keep the usual linear positions, and the model-specific positions are appended after them.
+//         As a result, `pos` has `n_tokens * (n_pos_per_embd + 1)` elements.
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;
@@ -66,7 +71,7 @@ struct decode_embd_batch {
     std::vector<int8_t>        logits;
     llama_batch batch;
     decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
-        pos     .resize(n_tokens * n_pos_per_embd);
+        pos     .resize(n_tokens * (n_pos_per_embd + 1));
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
         logits  .resize(n_tokens);
@@ -100,13 +105,14 @@ struct decode_embd_batch {
         for (int y = 0; y < ny; y++) {
             for (int x = 0; x < nx; x++) {
                 int i = y * nx + x;
-                pos[i                     ] = pos_0;
-                pos[i + batch.n_tokens    ] = pos_0 + y;
-                pos[i + batch.n_tokens * 2] = pos_0 + x;
-                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+                pos[i + batch.n_tokens    ] = pos_0;
+                pos[i + batch.n_tokens * 2] = pos_0 + y;
+                pos[i + batch.n_tokens * 3] = pos_0 + x;
+                pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
             }
         }
         for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
             batch.logits  [i] = false;
@@ -118,12 +124,13 @@ struct decode_embd_batch {
         GGML_ASSERT(n_pos_per_embd == 4);
         seq_id_0[0] = seq_id;
         for (int i = 0; i < batch.n_tokens; i++) {
-            pos[i                     ] = pos_0 + i;
             pos[i + batch.n_tokens    ] = pos_0 + i;
             pos[i + batch.n_tokens * 2] = pos_0 + i;
-            pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            pos[i + batch.n_tokens * 3] = pos_0 + i;
+            pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
         }
         for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
             batch.logits  [i] = false;
@@ -133,12 +140,12 @@ struct decode_embd_batch {
     llama_batch get_view(int offset, int n_tokens) {
         llama_pos * pos_ptr;
         pos_view.clear();
-        pos_view.reserve(n_tokens * n_pos_per_embd);
+        pos_view.reserve(n_tokens * (n_pos_per_embd + 1));
         if (n_pos_per_embd > 1) {
             // mrope
             // for example, with layout of src: 1234...1234...1234...1234...
             // offset 2 will give us dst: 34...34...34...34...
-            for (int i = 0; i < n_pos_per_embd; i++) {
+            for (int i = 0; i <= n_pos_per_embd; i++) {
                 // assume n_tokens is less than or equal to batch.n_tokens
                 // batch.n_tokens is number of **total** tokens
                 // n_tokens is the number of viewed tokens
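
For readers following the layout change: below is a minimal standalone sketch of the new `pos` buffer, assuming `n_pos_per_embd == 4` (M-RoPE) and using plain `std::vector<int32_t>` in place of `llama_pos`/`llama_batch`. It mirrors the text-token (1D) fill pattern from the hunks above; it is an illustration under those assumptions, not the helper itself.

// Sketch of the pos layout: n_tokens * (n_pos_per_embd + 1) elements,
// first the linear positions (what batch.pos points at), then the M-RoPE blocks.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t n_tokens       = 3;
    const int     n_pos_per_embd = 4;  // M-RoPE uses 4 position dims
    const int32_t pos_0          = 10; // first position of this chunk

    std::vector<int32_t> pos(n_tokens * (n_pos_per_embd + 1));

    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i]                = pos_0 + i; // linear position, used for causal masking
        pos[i + n_tokens]     = pos_0 + i; // M-RoPE dim 0 for text tokens
        pos[i + n_tokens * 2] = pos_0 + i; // M-RoPE dim 1
        pos[i + n_tokens * 3] = pos_0 + i; // M-RoPE dim 2
        pos[i + n_tokens * 4] = 0;         // last pos dim is unused
    }

    // print one row per block: 10 11 12 (linear), then the four M-RoPE blocks
    for (size_t k = 0; k < pos.size(); k++) {
        printf("%d%s", pos[k], (k + 1) % n_tokens == 0 ? "\n" : " ");
    }
    return 0;
}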
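Likewise, a sketch of the slicing that `get_view()` performs after the change: it copies `n_tokens` positions starting at `offset` from every one of the `n_pos_per_embd + 1` blocks, hence the loop bound `i <= n_pos_per_embd`. The names `n_total` and `n_view` and the toy values are illustrative, not taken from the source.

// Sketch of view slicing over the block layout [linear][dim0][dim1][dim2][dim3].
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t n_total        = 4; // batch.n_tokens: total tokens in the batch
    const int     n_pos_per_embd = 4;
    const int32_t offset         = 2; // first token of the view
    const int32_t n_view         = 2; // tokens in the view (<= n_total)

    // src layout, each block n_total long, e.g. 1234 1234 1234 1234 1234
    std::vector<int32_t> pos(n_total * (n_pos_per_embd + 1));
    for (int b = 0; b <= n_pos_per_embd; b++) {
        for (int32_t i = 0; i < n_total; i++) {
            pos[b * n_total + i] = i + 1;
        }
    }

    // take n_view positions starting at offset from *every* block,
    // including the leading linear block
    std::vector<int32_t> pos_view;
    pos_view.reserve(n_view * (n_pos_per_embd + 1));
    for (int i = 0; i <= n_pos_per_embd; i++) {
        const int src = i * n_total + offset;
        pos_view.insert(pos_view.end(), pos.begin() + src, pos.begin() + src + n_view);
    }

    for (size_t k = 0; k < pos_view.size(); k++) {
        printf("%d%s", pos_view[k], (k + 1) % n_view == 0 ? "  " : " ");
    }
    printf("\n"); // prints: 3 4  3 4  3 4  3 4  3 4
    return 0;
}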