4 changes: 4 additions & 0 deletions common/sampling.cpp
@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri

return samplers;
}

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
llama_sampler_apply(gsmpl->chain, cur_p);
}
2 changes: 2 additions & 0 deletions common/sampling.h
@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);

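// apply only the configured sampler chain to a caller-prepared candidate array,
// bypassing the logit gathering and grammar handling that common_sampler_sample performs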
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
117 changes: 65 additions & 52 deletions common/speculative.cpp
@@ -363,79 +363,92 @@ llama_tokens common_speculative_gen_draft(
}


llama_token mtp_speculative_gen_draft(
llama_tokens mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx) {

llama_token token_data[] = { id_last };
llama_pos pos_data[] = { n_past };
int32_t n_seq_id_data[] = { 1 };
llama_seq_id seq_id_data_internal[] = { 0 };
llama_seq_id* seq_id_data[] = {seq_id_data_internal};
int8_t logits_data[] = { (int8_t) (smpl != nullptr) };

llama_batch batch = {
/*.n_tokens = */ 1,
/*.token = */ token_data,
/*.embd = */ nullptr,
/*.pos = */ pos_data,
/*.n_seq_id = */ n_seq_id_data,
/*.seq_id = */ seq_id_data,
/*.logits = */ logits_data
};
int32_t last_tok_idx,
int32_t n_mtp_draft) {

llama_tokens draft_tokens;
draft_tokens.reserve(n_mtp_draft);

llama_token current_token = id_last;
int32_t current_n_past = n_past;

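// seed the draft loop with the main model's output embedding for the last accepted token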
float* prev_embedding_data = llama_get_embeddings_ith(ctx, last_tok_idx);
LOG_DBG("\n--- MTP total draft %d ---\n", n_mtp_draft);

// The same layer will draft multiple tokens before being validated
for (int i = 0; i < n_mtp_draft; ++i) {
if (prev_embedding_data == nullptr) {
LOG_DBG("ERROR: prev_embedding_data is null in iteration %d!\n", i);
break;
}
llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_add(batch, current_token, current_n_past, {0}, true);
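// each MTP step evaluates a single token at the current position on sequence 0, with logits requested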

return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
//LOG_INF("updating kv cache for n_past: %d\n", n_past);
float* next_embedding_data = llama_build_and_execute_mtp_graph(
ctx, batch, prev_embedding_data, i
);
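// the call writes the MTP head's logits into slot 0 of the context's logits buffer
// and returns a host pointer to the head's output embedding (nullptr on failure)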

/*
if (!smpl) {
return -1;
}
else {
common_sampler_sample(smpl, ctx, last_tok_idx, true);
const auto* cur_p = common_sampler_get_candidates(smpl);
if (next_embedding_data == nullptr) {
LOG_DBG("ERROR: next_embedding_data returned null from graph execution\n", i);
}

//for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
// LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
// k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
//}
// Apply logits + greedy: The main model has not yet selected
// the token as correct, so we cannot apply all samples.
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);

llama_token_data_array* cur_p = common_sampler_get_candidates(smpl);
cur_p->size = n_vocab;
for (int j = 0; j < n_vocab; ++j) {
cur_p->data[j].id = j;
// Place the MTP logits in the first slot of the context's logit buffer.
// This temporary storage is then read by the sampler.
cur_p->data[j].logit = llama_get_logits_ith(ctx, 0)[j];
}
cur_p->sorted = false;
common_sampler_apply_chain(smpl, cur_p);

const llama_token id = cur_p->data[0].id;
return id;
}
*/
// LOG_INF("cur_p->size: %d\n", cur_p->size);
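// greedy pick: take the first candidate left by the sampler chain as the draft token
// (assumes the chain orders candidates by probability)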
const llama_token new_id = cur_p->data[0].id;

draft_tokens.push_back(new_id);

// advance the loop state: the drafted token and its output embedding feed the next MTP step
current_token = new_id;
current_n_past++;
prev_embedding_data = next_embedding_data;

// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
// smpl will accept the token if it doesn't get rejected by main model later
// common_sampler_accept(smpl, id, true);
llama_batch_free(batch);

//llama_tokens result;
//result.reserve(1);
//result.push_back(id);
//return result;
if (!next_embedding_data) {
break;
}
}

return draft_tokens;
}


void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
mtp_kv_update_data token;

if (n_tokens < 0) {
n_tokens = tokens.size();
}

for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
token = tokens[i];
//fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
for (size_t i = 0; i < std::min((size_t)tokens.size(), n_tokens); ++i) {
mtp_kv_update_data& token = tokens[i];

llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_add(batch, token.id, token.n_past, {0}, true);

// Broken for now
// mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);

mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
llama_batch_free(batch);
}

tokens.clear();
}
}
13 changes: 7 additions & 6 deletions common/speculative.h
@@ -35,18 +35,19 @@ void common_speculative_add_replacement_tgt_dft(


// draft up to n_mtp_draft tokens using the model's MTP (multi-token prediction) head
llama_token mtp_speculative_gen_draft(
llama_tokens mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx);
int32_t last_tok_idx,
int32_t n_mtp_draft);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

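// replay accepted tokens through the MTP head so its KV cache stays in sync with the main model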
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
8 changes: 6 additions & 2 deletions include/llama.h
@@ -1454,8 +1454,12 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);

LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
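// Run one MTP head step for a single-token batch: writes the head's logits into slot 0
// of the context's logits buffer and returns a pointer to the head's output embedding
// (valid until the next call), or NULL on failure.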
LLAMA_API float* llama_build_and_execute_mtp_graph(
struct llama_context * ctx,
const llama_batch batch_inp,
float * prev_embedding_data,
int32_t mtp_head_idx
);

#ifdef __cplusplus
}
78 changes: 48 additions & 30 deletions src/llama-context.cpp
@@ -2995,16 +2995,19 @@ void llama_opt_epoch(
callback_eval);
}

llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

float* llama_build_and_execute_mtp_graph(
struct llama_context * ctx,
const llama_batch batch_inp,
float * prev_embedding_data,
int32_t mtp_head_idx
) {
const auto * model = llama_get_model(ctx);

auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);

std::vector<uint32_t> idxs;
idxs.push_back(n_past);
idxs.push_back(batch_inp.pos[0]);
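// the KV cache slot for this step is addressed directly by the token's position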
llama_kv_cache_unified::slot_info sinfo = {
/*.s0 =*/ 0,
/*.s1 =*/ 0,
@@ -3024,49 +3027,64 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
ggml_backend_sched_t sched = params_mtp->sched;

auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);

//if (mctx && !mctx->set_n_kv()) {
// LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
//}
static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();

auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
auto * gf = model->build_mtp_graph(*params_mtp, mtp_head_idx);

if (!gf) {
LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
if (sched) ggml_backend_sched_free(sched);
return nullptr;
}

ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it

llama_token token_id = batch_inp.token[0];
ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
ggml_backend_tensor_set(mtp_token_id_input, &token_id, 0, sizeof(token_id)); // copy data to the newly allocated graph tensors

ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors

ggml_backend_sched_graph_compute(sched, gf); // execute the graph
if (mtp_prev_embedding_input) {
ggml_backend_tensor_set(mtp_prev_embedding_input, prev_embedding_data, 0,
ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
} else {
LLAMA_LOG_WARN("%s: Could not find 'mtp_prev_embedding_input' tensor in the MTP graph.\n", __func__);
}

//struct ggml_tensor * logits_mtp = res_mtp->get_logits();
ggml_backend_sched_graph_compute(sched, gf); // execute the graph

//LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);

//if (logits_mtp) {
// ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
//}
struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
struct ggml_tensor * logits_mtp = res_mtp->get_logits();

if (logits_mtp) {
float * logits_dest = llama_get_logits_ith(ctx, 0);
// ggml_backend_tensor_get is the function for GPU->CPU copies.
// Here it copies the MTP head's full logits vector for this step.
ggml_backend_tensor_get(logits_mtp,
logits_dest, // destination: slot 0 of the context's logits buffer
0, // starting offset in bytes
ggml_nbytes(logits_mtp)); // number of bytes to copy
} else {
LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
}

llama_token token_id = 0; // The C++ variable to hold the result.
struct ggml_tensor * next_embedding_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_next_embedding_output");
float * next_embedding_data_ptr = nullptr;

// ggml_backend_tensor_get is the function for GPU->CPU copies.
// We are copying a single 32-bit integer.
ggml_backend_tensor_get(
token_id_tensor,
&token_id, // Pointer to our C++ variable
0, // Starting offset in bytes
sizeof(llama_token) // Number of bytes to copy
);
if (next_embedding_tensor) {
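// stage the embedding in a context-owned buffer so the returned pointer stays valid
// after the temporary scheduler and graph are freed below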
if (ctx->mtp_embedding_buffer.size() < ggml_nbytes(next_embedding_tensor)) {
ctx->mtp_embedding_buffer.resize(ggml_nbytes(next_embedding_tensor));
}
ggml_backend_tensor_get(next_embedding_tensor, ctx->mtp_embedding_buffer.data(), 0, ggml_nbytes(next_embedding_tensor));
next_embedding_data_ptr = reinterpret_cast<float *>(ctx->mtp_embedding_buffer.data());
} else {
LLAMA_LOG_ERROR("%s: The MTP graph did not produce an output embedding tensor.\n", __func__);
}

ggml_backend_sched_free(sched);

return token_id;
}

return next_embedding_data_ptr;
}
2 changes: 2 additions & 0 deletions src/llama-context.h
@@ -207,6 +207,8 @@ struct llama_context {
ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);

std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);

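// host-side staging buffer for the MTP head's output embedding, reused across draft steps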
std::vector<uint8_t> mtp_embedding_buffer;

private:
llm_graph_params graph_params(