diff --git a/common/sampling.cpp b/common/sampling.cpp
index a5824ebeedbaa..452cefee3b9ac 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
 
     return samplers;
 }
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
+    llama_sampler_apply(gsmpl->chain, cur_p);
+}
\ No newline at end of file
diff --git a/common/sampling.h b/common/sampling.h
index 2064421db4e80..b424d7d6d70ca 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -105,3 +105,5 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
                                        const char * grammar_kind, const char * grammar_data);
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
\ No newline at end of file
diff --git a/common/speculative.cpp b/common/speculative.cpp
index c1d9149ea13d2..77ed75913d5c7 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -370,56 +370,35 @@ llama_token mtp_speculative_gen_draft(
     int32_t n_past,
     int32_t last_tok_idx) {
 
-    llama_token token_data[] = { id_last };
-    llama_pos pos_data[] = { n_past };
-    int32_t n_seq_id_data[] = { 1 };
-    llama_seq_id seq_id_data_internal[] = { 0 };
-    llama_seq_id* seq_id_data[] = {seq_id_data_internal};
-    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
-
-    llama_batch batch = {
-        /*.n_tokens = */ 1,
-        /*.token    = */ token_data,
-        /*.embd     = */ nullptr,
-        /*.pos      = */ pos_data,
-        /*.n_seq_id = */ n_seq_id_data,
-        /*.seq_id   = */ seq_id_data,
-        /*.logits   = */ logits_data
-    };
-
-    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
-    //LOG_INF("updating kv cache for n_past: %d\n", n_past);
-
-    /*
     if (!smpl) {
         return -1;
     }
-    else {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-        const auto* cur_p = common_sampler_get_candidates(smpl);
-        //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        //}
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    common_batch_add(batch, id_last, n_past, {0}, true);
 
-        const llama_token id = cur_p->data[0].id;
-        return id;
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_n_vocab(vocab);
+
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id = i;
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
     }
-    */
-    // LOG_INF("cur_p->size: %d\n", cur_p->size);
+    cur_p->sorted = false;
 
+    common_sampler_apply_chain(smpl, cur_p);
 
-    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;
 
-    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-    // smpl will accept the token if it doesn't get rejected by main model later
-    // common_sampler_accept(smpl, id, true);
+    llama_batch_free(batch);
 
-    //llama_tokens result;
-    //result.reserve(1);
-    //result.push_back(id);
-    //return result;
+    return id;
 }
 
 
@@ -438,4 +417,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<llama_token> & tokens, size_t batch_start = 0, size_t n_tokens = -1);
diff --git a/include/llama.h b/include/llama.h
index 015c777763bf6..e43cd83468d0f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1454,8 +1454,8 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
 }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 1f04b72145b28..fb285a8d297c9 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
         callback_eval);
 }
 
-llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
@@ -3033,6 +3033,12 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
 
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to build the MTP graph (returned null)\n", __func__);
+        if (sched) ggml_backend_sched_free(sched);
+        return;
+    }
+
     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
@@ -3044,29 +3050,24 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
-    //struct ggml_tensor * logits_mtp = res_mtp->get_logits();
-
-    //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
-
-    //if (logits_mtp) {
-    //    ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
-    //}
-    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
-
-
-    llama_token token_id = 0; // The C++ variable to hold the result.
-
-    // ggml_backend_tensor_get is the function for GPU->CPU copies.
-    // We are copying a single 32-bit integer.
-    ggml_backend_tensor_get(
-        token_id_tensor,
-        &token_id,          // Pointer to our C++ variable
-        0,                  // Starting offset in bytes
-        sizeof(llama_token) // Number of bytes to copy
-    );
+    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+
+    if (logits_mtp) {
+        float * logits_dest = ctx->get_logits_ith(last_tok_idx);
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
+        if (backend_res) {
+            // ggml_backend_tensor_get is the function for GPU->CPU copies.
+            // Here it copies the full row of MTP logits into the context's output buffer.
+            ggml_backend_tensor_get(logits_mtp,
+                logits_dest,              // Destination: the context's logits buffer for last_tok_idx
+                0,                        // Starting offset in bytes
+                ggml_nbytes(logits_mtp)); // Number of bytes to copy
+        } else {
+            LLAMA_LOG_ERROR("%s: could not obtain the backend of the MTP logits tensor\n", __func__);
+        }
+    } else {
+        LLAMA_LOG_WARN("%s: the MTP graph did not produce a logits tensor\n", __func__);
+    }
 
     ggml_backend_sched_free(sched);
-
-    return token_id;
-}
-
+}
\ No newline at end of file
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f9921e4b6d448..dd4bf211b7e94 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13950,6 +13950,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
     // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
         llama_token last_token_id, int n_past
     ) : llm_graph_context(params) {
+
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13964,8 +13965,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
         auto * inp_attn = build_attn_inp_kv_unified();
 
-        ggml_tensor * cur;
-
         // get MTP embedding for last (conventionally sampled) token
         // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
         // LLAMA_LOG_INFO("step: '%d'\n", 5641);
@@ -13979,7 +13978,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
         //ggml_set_no_alloc(ctx0, true);
-
+
         ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
 
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
@@ -13994,9 +13993,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
 
-
-        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
-
+        ggml_tensor * cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
 
         // now proceed through last layer (skipped in main model)
         ggml_tensor * inpSA = cur;
@@ -14096,14 +14093,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
-
+
         res->t_logits = cur;
-        ggml_build_forward_expand(gf, res->t_logits);
-
-        struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
-        ggml_set_name(token_id_tensor, "mtp_argmax_result");
-        ggml_build_forward_expand(gf, token_id_tensor);
     }
 };
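
Note (reviewer sketch, not part of the patch): the reworked draft path fills the sampler's candidate array directly from the raw logits and then runs the sampler chain over it, instead of reading a pre-computed argmax tensor out of the MTP graph. The snippet below illustrates that same pattern against the public llama.h API; the function name sample_from_logits, the local buffer handling, and the fallback to candidate 0 when the chain does not set cur_p.selected are assumptions for illustration only.

    // Illustrative sketch only -- mirrors the "fill candidates from raw logits,
    // then apply the sampler chain" pattern used by mtp_speculative_gen_draft.
    #include "llama.h"
    #include <vector>

    static llama_token sample_from_logits(llama_context * ctx, const llama_vocab * vocab,
                                          llama_sampler * chain, int32_t idx) {
        const int     n_vocab = llama_n_vocab(vocab);           // same helper the patch uses
        const float * logits  = llama_get_logits_ith(ctx, idx); // logits row for output index idx

        // one candidate per vocabulary entry, probabilities left unset
        std::vector<llama_token_data> cand(n_vocab);
        for (int i = 0; i < n_vocab; ++i) {
            cand[i] = llama_token_data{ (llama_token) i, logits[i], 0.0f };
        }

        llama_token_data_array cur_p = { cand.data(), cand.size(), /*selected =*/ -1, /*sorted =*/ false };

        // counterpart of common_sampler_apply_chain(): run the whole chain in place
        llama_sampler_apply(chain, &cur_p);

        // a selecting sampler at the end of the chain sets 'selected'; otherwise
        // fall back to the first candidate, as the patch does
        return cur_p.data[cur_p.selected >= 0 ? cur_p.selected : 0].id;
    }

Moving the token choice out of the graph and into the sampler chain is what lets the drafted token respect the configured sampling parameters rather than always being a greedy argmax.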