4 changes: 4 additions & 0 deletions common/sampling.cpp
@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri

return samplers;
}

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
llama_sampler_apply(gsmpl->chain, cur_p);
}
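
For context, the new common_sampler_apply_chain helper just forwards to llama_sampler_apply on the sampler's internal chain, so callers can build their own candidate array from raw logits and run the configured chain over it. A minimal caller sketch (hypothetical helper, not part of this diff; assumes common/sampling.h is included and cands is pre-sized to n_vocab):

static llama_token apply_chain_greedy(struct common_sampler * smpl,
                                      struct llama_context * ctx,
                                      std::vector<llama_token_data> & cands, // pre-sized to n_vocab
                                      int32_t idx) {
    const float * logits = llama_get_logits_ith(ctx, idx);
    for (size_t i = 0; i < cands.size(); ++i) {
        cands[i] = llama_token_data{ (llama_token) i, logits[i], 0.0f }; // id, logit, p
    }
    llama_token_data_array cur_p = { cands.data(), cands.size(), /*selected*/ -1, /*sorted*/ false };
    common_sampler_apply_chain(smpl, &cur_p);
    // as in the diff below: the first candidate after the chain is taken as the sampled id
    return cur_p.data[0].id;
}
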
2 changes: 2 additions & 0 deletions common/sampling.h
@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
61 changes: 20 additions & 41 deletions common/speculative.cpp
@@ -370,56 +370,35 @@ llama_token mtp_speculative_gen_draft(
int32_t n_past,
int32_t last_tok_idx) {

llama_token token_data[] = { id_last };
llama_pos pos_data[] = { n_past };
int32_t n_seq_id_data[] = { 1 };
llama_seq_id seq_id_data_internal[] = { 0 };
llama_seq_id* seq_id_data[] = {seq_id_data_internal};
int8_t logits_data[] = { (int8_t) (smpl != nullptr) };

llama_batch batch = {
/*.n_tokens = */ 1,
/*.token = */ token_data,
/*.embd = */ nullptr,
/*.pos = */ pos_data,
/*.n_seq_id = */ n_seq_id_data,
/*.seq_id = */ seq_id_data,
/*.logits = */ logits_data
};

return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
//LOG_INF("updating kv cache for n_past: %d\n", n_past);

/*
if (!smpl) {
return -1;
}
else {
common_sampler_sample(smpl, ctx, last_tok_idx, true);
const auto* cur_p = common_sampler_get_candidates(smpl);

//for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
// LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
// k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
//}
llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_add(batch, id_last, n_past, {0}, true);

const llama_token id = cur_p->data[0].id;
return id;
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);

llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);

cur_p->size = n_vocab;
for (int i = 0; i < n_vocab; ++i) {
cur_p->data[i].id = i;
cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
}
*/
// LOG_INF("cur_p->size: %d\n", cur_p->size);
cur_p->sorted = false;

common_sampler_apply_chain(smpl, cur_p);

// take the top candidate after the sampler chain as the draft token
const llama_token id = cur_p->data[0].id;

// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
// smpl will accept the token if it doesn't get rejected by the main model later
// common_sampler_accept(smpl, id, true);
llama_batch_free(batch);

//llama_tokens result;
//result.reserve(1);
//result.push_back(id);
//return result;
return id;
}


@@ -438,4 +417,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
}

tokens.clear();
}
}
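
A hypothetical call-site sketch for this single-token MTP draft path (the argument order of mtp_speculative_gen_draft is assumed from the function body above; common/speculative.h has the authoritative signature): draft one token with the MTP head, then queue both the last accepted token and the draft so the main model can verify them on the next decode.

static llama_token mtp_draft_and_queue(struct common_sampler * smpl,
                                       struct llama_context * ctx,
                                       llama_batch & verify_batch, // built with llama_batch_init(...)
                                       llama_token id_last,
                                       int32_t n_past,
                                       int32_t last_tok_idx) {
    const llama_token id_draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);

    // the draft extends the sequence by exactly one position; request logits for both
    // rows so the verification step can compare them against the main model's output
    common_batch_add(verify_batch, id_last,  n_past,     { 0 }, true);
    common_batch_add(verify_batch, id_draft, n_past + 1, { 0 }, true);

    return id_draft;
}
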
8 changes: 4 additions & 4 deletions common/speculative.h
@@ -44,9 +44,9 @@ llama_token mtp_speculative_gen_draft(

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
4 changes: 2 additions & 2 deletions include/llama.h
@@ -1454,8 +1454,8 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);

LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);

#ifdef __cplusplus
}
51 changes: 26 additions & 25 deletions src/llama-context.cpp
@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
callback_eval);
}

llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

const auto * model = llama_get_model(ctx);
@@ -3033,6 +3033,12 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,

auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);

if (!gf) {
LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
if (sched) ggml_backend_sched_free(sched);
return;
}

ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it

@@ -3044,29 +3050,24 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,

ggml_backend_sched_graph_compute(sched, gf); // execute the graph

//struct ggml_tensor * logits_mtp = res_mtp->get_logits();

//LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);

//if (logits_mtp) {
// ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
//}
struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");


llama_token token_id = 0; // The C++ variable to hold the result.

// ggml_backend_tensor_get is the function for GPU->CPU copies.
// We are copying a single 32-bit integer.
ggml_backend_tensor_get(
token_id_tensor,
&token_id, // Pointer to our C++ variable
0, // Starting offset in bytes
sizeof(llama_token) // Number of bytes to copy
);
struct ggml_tensor * logits_mtp = res_mtp->get_logits();

if (logits_mtp) {
float * logits_dest = ctx->get_logits_ith(last_tok_idx);
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
if (backend_res) {
// ggml_backend_tensor_get performs the backend (e.g. GPU) -> CPU copy.
// Here it copies the full logits row for this token into the context's output buffer.
ggml_backend_tensor_get(logits_mtp,
logits_dest, // destination row in the context's logits buffer
0, // starting offset in bytes
ggml_nbytes(logits_mtp)); // number of bytes to copy
} else {
LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.\n", __func__);
}
} else {
LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logits tensor.\n", __func__);
}

ggml_backend_sched_free(sched);

return token_id;
}

}
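
The readback above follows the usual ggml scheduler pattern: look up which backend ended up owning the tensor, then do a synchronous device-to-host copy of ggml_nbytes() bytes. A generic sketch of that pattern (hypothetical helper, not part of this diff):

static bool mtp_read_tensor_to_host(ggml_backend_sched_t sched,
                                    struct ggml_tensor * t,
                                    void * dst) {
    if (t == nullptr || dst == nullptr) {
        return false;
    }
    // the scheduler reports which backend was assigned this tensor;
    // a null result means it was never allocated/scheduled
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, t);
    if (backend == nullptr) {
        return false;
    }
    // synchronous device -> host copy of the whole tensor
    ggml_backend_tensor_get(t, dst, 0, ggml_nbytes(t));
    return true;
}
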
16 changes: 4 additions & 12 deletions src/llama-model.cpp
@@ -13950,6 +13950,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
// For v0, rebuild the computational graph on every step; this also mimics the vLLM implementation's parameterization
llama_token last_token_id, int n_past
) : llm_graph_context(params) {

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

@@ -13964,8 +13965,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
//llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
auto * inp_attn = build_attn_inp_kv_unified();

ggml_tensor * cur;

// get MTP embedding for last (conventionally sampled) token
// ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
// LLAMA_LOG_INFO("step: '%d'\n", 5641);
@@ -13979,7 +13978,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {

//ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
//ggml_set_no_alloc(ctx0, true);

ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);

@@ -13994,9 +13993,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {

ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat


cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj

ggml_tensor * cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj

// now proceed through last layer (skipped in main model)
ggml_tensor * inpSA = cur;
@@ -14096,14 +14093,9 @@

cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);

res->t_logits = cur;

ggml_build_forward_expand(gf, res->t_logits);

struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
ggml_set_name(token_id_tensor, "mtp_argmax_result");
ggml_build_forward_expand(gf, token_id_tensor);
}
};
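
Conceptually, the MTP head's input projection built above reduces to: embed the last sampled token, RMS-normalize both that embedding (enorm) and the incoming hidden state (hnorm), concatenate them along the embedding dimension, and project back to n_embd with eh_proj before the extra transformer layer and the shared output head. A plain-ggml sketch of that data flow (placeholder tensor names and shapes; the real builder goes through build_norm/build_lora_mm):

static ggml_tensor * mtp_project_input(ggml_context * ctx0,
                                       ggml_tensor * embed_tokens, // [n_embd, n_vocab]
                                       ggml_tensor * enorm,        // [n_embd]
                                       ggml_tensor * hnorm,        // [n_embd]
                                       ggml_tensor * eh_proj,      // [2*n_embd, n_embd]
                                       ggml_tensor * inp_token_id, // [1], GGML_TYPE_I32
                                       ggml_tensor * hidden_state, // [n_embd, 1]
                                       float rms_eps) {
    ggml_tensor * emb   = ggml_get_rows(ctx0, embed_tokens, inp_token_id);                   // [n_embd, 1]
    ggml_tensor * emb_n = ggml_mul(ctx0, ggml_rms_norm(ctx0, emb,          rms_eps), enorm); // token_emb_norm
    ggml_tensor * hid_n = ggml_mul(ctx0, ggml_rms_norm(ctx0, hidden_state, rms_eps), hnorm); // hidden_state_norm
    ggml_tensor * cat   = ggml_concat(ctx0, emb_n, hid_n, 0);                                // [2*n_embd, 1]
    return ggml_mul_mat(ctx0, eh_proj, cat);                                                 // [n_embd, 1]
}
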
