10 changes: 7 additions & 3 deletions src/llama-graph.cpp
@@ -297,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
     float * data = (float *) kq_mask->data;
 
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
     for (int h = 0; h < 1; ++h) {
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -315,9 +318,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                     continue; // skip future tokens for causal attention
                 }
 
-                if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    continue; // skip masked tokens for SWA
-                }
+                // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                //    continue; // skip masked tokens for SWA
+                //}
 
                 // TODO: reimplement this like in llama_kv_cache_unified
                 if (hparams.use_alibi) {
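The check that was commented out above cannot simply be turned back on: llm_graph_input_attn_no_cache builds a single kq_mask shared by every layer, while an iSWA model needs full attention on some layers and a sliding window on the others (hence the new GGML_ASSERT and the [TAG_NO_CACHE_ISWA] markers). A minimal sketch of the direction the TODO points at, with a base/SWA mask split analogous to llama_kv_cache_iswa, is shown below. The helper and the data_base/data_swa buffers are illustrative only and not part of this PR; the per-sequence check from the real set_input() is omitted, and the sketch assumes the llama.cpp internal headers for llama_pos, llama_swa_type and llama_hparams::is_masked_swa.

```cpp
#include <cmath> // INFINITY

// Illustrative sketch (not in this PR): fill two [n_tokens x n_tokens] masks the way a
// cacheless iSWA path might, mirroring the base/SWA split used by llama_kv_cache_iswa.
// data_base would feed the full-attention layers, data_swa the sliding-window layers.
static void fill_cacheless_iswa_masks(
        const llama_pos * pos,       // ubatch->pos, one position per token
        int               n_tokens,
        uint32_t          n_swa,
        llama_swa_type    swa_type,
        float *           data_base,
        float *           data_swa) {
    for (int i1 = 0; i1 < n_tokens; ++i1) {     // query token
        for (int i0 = 0; i0 < n_tokens; ++i0) { // key token
            const bool causal_ok = pos[i0] <= pos[i1];
            const bool swa_ok    = !llama_hparams::is_masked_swa(n_swa, swa_type, pos[i0], pos[i1]);

            data_base[i1*n_tokens + i0] =  causal_ok            ? 0.0f : -INFINITY;
            data_swa [i1*n_tokens + i0] = (causal_ok && swa_ok) ? 0.0f : -INFINITY;
        }
    }
}
```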
2 changes: 1 addition & 1 deletion src/llama-hparams.cpp
@@ -180,7 +180,7 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
     assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
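The body of the switch is outside the visible hunk. For orientation only, the standard sliding-window case reduces to a distance check between the key position p0 and the query position p1, roughly as sketched below; the exact comparison and the chunked/symmetric variants used upstream are not reproduced here, and the LLAMA_SWA_TYPE_* names are used on the assumption that they match the public enum.

```cpp
#include <cassert>
#include <cstdint>

// Orientation-only sketch of the kind of check the real switch performs.
// Not the upstream implementation; assumes llama.h for llama_pos and llama_swa_type.
static bool is_masked_swa_sketch(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
    assert(p0 >= 0 && p1 >= 0);

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            return false;                        // no sliding window: nothing is masked
        case LLAMA_SWA_TYPE_STANDARD:
            return p1 - p0 >= (llama_pos) n_swa; // key is too far behind the query
        default:
            return false;                        // chunked/symmetric variants not sketched here
    }
}
```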
5 changes: 4 additions & 1 deletion src/llama-hparams.h
@@ -229,7 +229,10 @@ struct llama_hparams {
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+    // note that this function uses different SWA parameters from those in the hparams
+    // TODO: think of a better place for this function
+    // TODO: pack the SWA params in a struct?
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
[Review comment from the Member Author]

Changed this to a static function.

Maybe it should become a member like this:

Suggested change:
-    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+    bool is_masked_swa(uint32_t il, llama_pos p0, llama_pos p1) const;

But let's refactor this after master stabilizes.

 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
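For reference, if the member form from the suggestion above were adopted once master stabilizes, it would presumably just resolve the layer's SWA behavior and forward to the static helper. The sketch below is hypothetical: it assumes llama_hparams keeps a single (n_swa, swa_type) pair that applies only to SWA layers and exposes a per-layer predicate is_swa(il); neither detail is introduced by this PR.

```cpp
// Hypothetical member wrapper around the static helper, per the review suggestion above.
// Assumes a per-layer predicate is_swa(il) and model-wide n_swa / swa_type members.
bool llama_hparams::is_masked_swa(uint32_t il, llama_pos p0, llama_pos p1) const {
    if (!is_swa(il)) {
        return false; // this layer uses full attention and is never SWA-masked
    }
    return is_masked_swa(n_swa, swa_type, p0, p1);
}
```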
4 changes: 2 additions & 2 deletions src/llama-kv-cache-iswa.cpp
@@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     kv_base = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, filter_base, reuse);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, filter_swa, reuse);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
5 changes: 3 additions & 2 deletions src/llama-kv-cache.cpp
@@ -27,10 +27,11 @@ llama_kv_cache::llama_kv_cache(
         uint32_t n_seq_max,
         uint32_t n_pad,
         uint32_t n_swa,
+        llama_swa_type swa_type,
         const layer_filter_cb & filter,
         const layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -1392,7 +1393,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 }
 
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return hparams.is_masked_swa(p0, p1);
+    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
 }
 
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
4 changes: 4 additions & 0 deletions src/llama-kv-cache.h
@@ -89,6 +89,7 @@ class llama_kv_cache : public llama_memory_i {
             uint32_t n_seq_max,
             uint32_t n_pad,
             uint32_t n_swa,
+            llama_swa_type swa_type,
             const layer_filter_cb & filter,
             const layer_reuse_cb & reuse);
 
@@ -211,6 +212,9 @@ class llama_kv_cache : public llama_memory_i {
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // this is the SWA type of the cache - not to be confused with the model SWA type
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
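The comment on the new member is the main point of the plumbing: masking is now a property of each cache instance rather than of the model-wide hparams, so the two halves of an iSWA pair answer the same query differently. A hypothetical check, assuming is_masked_swa() is reachable from the calling code and a standard sliding-window type:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration, not part of the PR: kv_base was constructed with
// LLAMA_SWA_TYPE_NONE, kv_swa with the model's swa_type, so the same (p0, p1)
// pair is masked by one cache and not by the other.
static void check_iswa_pair(const llama_kv_cache & kv_base, const llama_kv_cache & kv_swa, uint32_t n_swa) {
    const llama_pos p0 = 0;                     // oldest key position
    const llama_pos p1 = 2 * (llama_pos) n_swa; // query well past one window

    assert(!kv_base.is_masked_swa(p0, p1)); // base cache: never masked by distance
    assert( kv_swa.is_masked_swa(p0, p1));  // SWA cache: old key falls outside the window
}
```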
2 changes: 2 additions & 0 deletions src/llama-memory-hybrid.cpp
@@ -17,6 +17,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         uint32_t kv_size,
         uint32_t n_pad,
         uint32_t n_swa,
+        llama_swa_type swa_type,
         /* recurrent */
         ggml_type type_r,
         ggml_type type_s,
@@ -40,6 +41,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
+        swa_type,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
1 change: 1 addition & 0 deletions src/llama-memory-hybrid.h
@@ -27,6 +27,7 @@ class llama_memory_hybrid : public llama_memory_i {
             uint32_t kv_size,
             uint32_t n_pad,
             uint32_t n_swa,
+            llama_swa_type swa_type,
             /* recurrent */
             ggml_type type_r,
             ggml_type type_s,
7 changes: 5 additions & 2 deletions src/llama-model.cpp
@@ -11084,7 +11084,8 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -18632,7 +18633,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        case LLM_ARCH_GEMMA_EMBEDDING:
+        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
             {
@@ -18681,6 +18682,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     /* attn_kv_size */ cparams.n_ctx,
                     /* attn_n_pad */ padding,
                     /* attn_n_swa */ hparams.n_swa,
+                    /* attn_swa_type */ hparams.swa_type,
                     /* recurrent_type_k */ GGML_TYPE_F32,
                     /* recurrent_type_v */ GGML_TYPE_F32,
                     /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18750,6 +18752,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.n_seq_max,
                         padding,
                         hparams.n_swa,
+                        hparams.swa_type,
                         nullptr,
                         nullptr);
             }