70 changes: 55 additions & 15 deletions src/llama-kv-cache.cpp
@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm

     const int32_t ikv = map_layer_ids.at(il);

-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;

-    const int64_t n_tokens = k_cur->ne[2];
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
+
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);

-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);

-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }

     // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
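The TODO above points at the dim-0/dim-1 fusion that cpy_k (and cpy_v below) now performs by hand with ggml_row_size and ggml_view_2d. A minimal sketch of what such a ggml helper could look like, using only existing ggml calls; the name merge_dims_01 is hypothetical and not part of ggml:

// Sketch only: fuse dims 0 and 1 of a 3-D tensor [ne0, ne1, ne2] into a 2-D
// view [ne0*ne1, ne2]. Valid only when dim 1 is packed directly after dim 0,
// i.e. when the fused row is contiguous in memory.
static ggml_tensor * merge_dims_01(ggml_context * ctx, ggml_tensor * t) {
    GGML_ASSERT(ggml_row_size(t->type, t->ne[0]) == t->nb[1]);

    return ggml_view_2d(ctx, t, t->ne[0]*t->ne[1], t->ne[2], t->nb[2], 0);
}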

@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm

     auto * v = layers[ikv].v;

-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;

-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
+
+    const int64_t n_stream = v->ne[2];

     // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }

         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }

+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
+    }
+
     // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
     }

-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));

-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
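In the transposed (non-FA) path, v_view is reshaped to one element per row, so each entry of v_idxs addresses a single cache element rather than a whole row. A rough sketch of the flat-index arithmetic, assuming for illustration that the transposed V buffer is laid out as [kv_size, n_embd_v_gqa, n_stream] with the cell position varying fastest (this layout is an assumption, not taken from this diff):

// Illustration only, under the layout assumed above: flat index of the value
// stored for cache cell `cell`, embedding channel `j` and stream `s`.
static int64_t v_elem_idx(int64_t cell, int64_t j, int64_t s,
                          int64_t kv_size, int64_t n_embd_v_gqa) {
    return cell + kv_size*(j + n_embd_v_gqa*s);
}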
8 changes: 8 additions & 0 deletions src/llama-kv-cache.h
@@ -317,9 +317,17 @@ class llama_kv_cache_context : public llama_memory_context_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

     // store k_cur and v_cur in the cache based on the provided head location
+    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
+    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
+    //   - k_idxs [n_tokens]
+    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
+    //   - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending on whether the V cache is transposed
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

+    // create destination indices for each head of the current batch for where it would be written in the KV cache
+    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
+    // helps understand the implementation logic of cpy_k and cpy_v
     ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
     ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

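Taken together, a graph-build call site would first build the destination indices and then expand the copy ops into the graph. A rough usage sketch; ctx0, gf, mctx, k_cur, v_cur, ubatch and il are assumed placeholder names, not taken from this diff:

// Sketch only: wiring the index builders into cpy_k/cpy_v during graph build.
// The idxs are global across streams, which is why cpy_k/cpy_v merge the
// per-stream buffers before calling ggml_set_rows.
ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx0, ubatch); // [n_tokens]
ggml_tensor * v_idxs = mctx->build_input_v_idxs(ctx0, ubatch); // [n_tokens] or [n_tokens*n_embd_v_gqa]

ggml_build_forward_expand(gf, mctx->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx->cpy_v(ctx0, v_cur, v_idxs, il));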