From aad330a9759e9f951def74895ceda16eac0d2fea Mon Sep 17 00:00:00 2001
From: Tao Peng
Date: Mon, 17 Nov 2025 15:45:53 +0800
Subject: [PATCH 1/2] feat: revert the original code before refactoring the
 multi-stream[1/2].

Signed-off-by: Tao Peng
---
 xllm/core/framework/model/causal_lm.h         |  28 +-
 xllm/core/framework/model/causal_vlm.h        |  14 +-
 xllm/core/framework/model_context.cpp         |  13 +-
 xllm/core/layers/npu/npu_base_layer.cpp       |  38 +-
 xllm/core/layers/npu/npu_base_layer.h         |  10 +-
 .../npu_deepseek_v2_decoder_layer_impl.cpp    | 332 ++++--------
 .../npu/npu_deepseek_v2_decoder_layer_impl.h  |  33 +-
 .../layers/npu/npu_glm4_moe_decoder_layer.cpp |  21 +-
 .../layers/npu/npu_glm4_moe_decoder_layer.h   |   5 +-
 .../npu/npu_qwen2_decoder_layer_impl.cpp      |  41 ++-
 .../layers/npu/npu_qwen2_decoder_layer_impl.h |  15 +-
 .../npu/npu_qwen3_decoder_layer_impl.cpp      | 155 ++------
 .../layers/npu/npu_qwen3_decoder_layer_impl.h |  29 +-
 .../npu/npu_qwen3_moe_decoder_layer_impl.cpp  |   4 +-
 xllm/core/runtime/acl_graph_executor_impl.cpp |   6 +-
 xllm/core/runtime/acl_graph_executor_test.cpp |  25 +-
 xllm/core/runtime/base_executor_impl.cpp      |   2 +-
 xllm/core/runtime/llm_worker_impl.h           |   4 +-
 xllm/models/llm/deepseek_v2.h                 | 165 ++++-----
 xllm/models/llm/deepseek_v2_mtp.h             | 170 ++++-----
 xllm/models/llm/embedding_model_base.h        |  16 +-
 xllm/models/llm/glm4_moe.h                    |  36 +-
 xllm/models/llm/glm4_moe_mtp.h                |  34 +-
 xllm/models/llm/llama.h                       |  20 +-
 xllm/models/llm/llm_model_base.h              | 299 +++++---------
 xllm/models/llm/qwen2.h                       |   7 +-
 xllm/models/llm/qwen3.h                       | 235 ++++-------
 xllm/models/llm/qwen3_embedding.h             |  14 +-
 xllm/models/llm/qwen3_moe.h                   |  20 +-
 xllm/models/vlm/minicpmv.h                    |  21 +-
 xllm/models/vlm/qwen2_5_vl.h                  |  18 +-
 xllm/models/vlm/qwen3_vl.h                    |  18 +-
 xllm/models/vlm/qwen3_vl_moe.h                |  18 +-
 33 files changed, 709 insertions(+), 1157 deletions(-)
 mode change 100755 => 100644 xllm/models/vlm/qwen2_5_vl.h
 mode change 100755 => 100644 xllm/models/vlm/qwen3_vl.h

diff --git a/xllm/core/framework/model/causal_lm.h b/xllm/core/framework/model/causal_lm.h
index eca725082..8eb9abaa9 100644
--- a/xllm/core/framework/model/causal_lm.h
+++ b/xllm/core/framework/model/causal_lm.h
@@ -43,11 +43,10 @@ class CausalLM : public torch::nn::Module {
   // tokens: [num_tokens]
   // positions: [num_tokens]
   // returns: [num_tokens, hidden_size]
-  virtual torch::Tensor forward(
-      const std::vector<torch::Tensor>& tokens,
-      const std::vector<torch::Tensor>& positions,
-      std::vector<KVCache>& kv_caches,
-      const std::vector<ModelInputParams>& parameters) = 0;
+  virtual torch::Tensor forward(const torch::Tensor& tokens,
+                                const torch::Tensor& positions,
+                                std::vector<KVCache>& kv_caches,
+                                const ModelInputParams& parameters) = 0;
 
   // hidden_states: [num_tokens, hidden_size]
   // seleted_idxes: [num_tokens]
@@ -68,9 +67,8 @@
   virtual layer::LmHead get_lm_head() = 0;
   virtual void set_lm_head(layer::LmHead& head) = 0;
 
-  virtual std::vector<layer::WordEmbedding> get_word_embedding() = 0;
-  virtual void set_word_embedding(
-      std::vector<layer::WordEmbedding>& embedding) = 0;
+  virtual layer::WordEmbedding get_word_embedding() = 0;
+  virtual void set_word_embedding(layer::WordEmbedding& embedding) = 0;
 };
 
 template <typename Model>
@@ -79,11 +77,10 @@ class CausalLMImpl : public CausalLM {
   CausalLMImpl(Model model, const torch::TensorOptions& options)
       : model_(std::move(model)), options_(options) {}
 
-  torch::Tensor forward(
-      const std::vector<torch::Tensor>& tokens,
-      const std::vector<torch::Tensor>& positions,
-      std::vector<KVCache>& kv_caches,
-      const std::vector<ModelInputParams>& parameters) override {
+  torch::Tensor forward(const torch::Tensor& tokens,
+                        const torch::Tensor& positions,
+                        std::vector<KVCache>& kv_caches,
+ const ModelInputParams& parameters) override { return model_->forward(tokens, positions, kv_caches, parameters); } @@ -109,12 +106,11 @@ class CausalLMImpl : public CausalLM { void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); }; - std::vector get_word_embedding() override { + layer::WordEmbedding get_word_embedding() override { return model_->get_word_embedding(); }; - void set_word_embedding( - std::vector& embedding) override { + void set_word_embedding(layer::WordEmbedding& embedding) override { model_->set_word_embedding(embedding); }; diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index 906b0b818..2f90e47cc 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ b/xllm/core/framework/model/causal_vlm.h @@ -40,11 +40,10 @@ class CausalVLMImpl : public CausalVLM { CausalVLMImpl(Model model, const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - torch::Tensor forward( - const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& parameters) override { + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& parameters) override { return model_->forward(tokens, positions, kv_caches, parameters); } @@ -68,12 +67,11 @@ class CausalVLMImpl : public CausalVLM { void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); }; - std::vector get_word_embedding() override { + layer::WordEmbedding get_word_embedding() override { return model_->get_word_embedding(); }; - void set_word_embedding( - std::vector& embedding) override { + void set_word_embedding(layer::WordEmbedding& embedding) override { model_->set_word_embedding(embedding); }; diff --git a/xllm/core/framework/model_context.cpp b/xllm/core/framework/model_context.cpp index 1b69d7797..b0b15390a 100644 --- a/xllm/core/framework/model_context.cpp +++ b/xllm/core/framework/model_context.cpp @@ -40,17 +40,8 @@ ModelContext::ModelContext(const ParallelArgs& input_parallel_args, int32_t device_id = tensor_options.device().index(); aclError ret = aclrtSetDevice(device_id); atb::CreateContext(&context_); - std::vector streams; - streams.push_back(c10_npu::getCurrentNPUStream(device_id).stream()); - for (int i = 0; i < 1; ++i) { - aclrtStream sub_stream; - aclError ret = aclrtCreateStream(&sub_stream); - if (ret != ACL_ERROR_NONE) { - ATB_SPEED_LOG_ERROR("Failed to create aclrtStream: " << ret); - } - streams.push_back(sub_stream); - } - context_->SetExecuteStreams(streams); + void* stream = c10_npu::getCurrentNPUStream(device_id).stream(); + context_->SetExecuteStream(stream); context_->SetAsyncTilingCopyStatus(true); #endif } diff --git a/xllm/core/layers/npu/npu_base_layer.cpp b/xllm/core/layers/npu/npu_base_layer.cpp index d6c73d1eb..72165dde9 100644 --- a/xllm/core/layers/npu/npu_base_layer.cpp +++ b/xllm/core/layers/npu/npu_base_layer.cpp @@ -32,11 +32,10 @@ NpuBaseLayer::NpuBaseLayer(const ModelContext& context) : BaseLayer(context) { work_space_ = AtbWorkspace(device_); } -atb::Status NpuBaseLayer::execute_node( - atb_speed::Model::Node& node, - int node_id, - std::vector event, - std::vector*> event_flag) { +atb::Status NpuBaseLayer::execute_node(atb_speed::Model::Node& node, + int node_id, + aclrtEvent* event, + std::atomic* event_flag) { // TODO(by zhangminchao1@jd.com): Stream management needs to be refactored // for better separation of concerns Current issues: // 1. 
ACLGraph capture requires execution on a non-default stream, so we @@ -93,28 +92,25 @@ atb::Status NpuBaseLayer::execute_node( return st; } -atb::Status NpuBaseLayer::execute_plan( - const atb_speed::Model::Node& node, - const std::string& op_name, - std::vector event, - std::vector*> event_flag) { +atb::Status NpuBaseLayer::execute_plan(const atb_speed::Model::Node& node, + const std::string& op_name, + aclrtEvent* event, + std::atomic* event_flag) { atb::Status st = node.operation->Execute( node.variantPack, (uint8_t*)node.workspace, node.workspaceSize, context_); LOG_IF(ERROR, st != 0) << name_ << " execute plan fail, error code: " << st; - for (auto i = 0; i < event.size(); ++i) { - if (st == 0 && event[i] != nullptr) { - aclrtStream stream = context_->GetExecuteStream(); + if (st == 0 && event != nullptr) { + aclrtStream stream = context_->GetExecuteStream(); - aclrtEvent* aclrt_event = reinterpret_cast(event[i]); + aclrtEvent* aclrt_event = reinterpret_cast(event); - auto ret = aclrtRecordEvent(*aclrt_event, stream); - if (ret != ACL_SUCCESS) { - LOG(ERROR) << "Record event failed."; - return st; - } - - event_flag[i]->store(true, std::memory_order_release); + auto ret = aclrtRecordEvent(*aclrt_event, stream); + if (ret != ACL_SUCCESS) { + LOG(ERROR) << "Record event failed."; + return st; } + + event_flag->store(true, std::memory_order_release); } return st; diff --git a/xllm/core/layers/npu/npu_base_layer.h b/xllm/core/layers/npu/npu_base_layer.h index 3d102b23b..cb87e2274 100644 --- a/xllm/core/layers/npu/npu_base_layer.h +++ b/xllm/core/layers/npu/npu_base_layer.h @@ -61,15 +61,13 @@ class NpuBaseLayer : public BaseLayer { atb::Status execute_node(atb_speed::Model::Node& node, int nodeId = 0, - std::vector event = {nullptr, nullptr}, - std::vector*> event_flag = { - nullptr, - nullptr}); + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr); atb::Status execute_plan(const atb_speed::Model::Node& node, const std::string& op_name, - std::vector event, - std::vector*> event_flag); + aclrtEvent* event, + std::atomic* event_flag); virtual void run_task(std::string taskName, std::function task) const override; diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index 3a8351f4d..f321d1401 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -1473,8 +1473,6 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_node( bool eplb_enabled = FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace && !param.isPrefill; - bool multi_stream_parallel_enabled = - param.isPrefill && FLAGS_enable_multi_stream_parallel; atb::Operation* operation = nullptr; atb_speed::deepseekV2::DecoderLayer(param, &operation); node.operation.reset(operation); @@ -1488,7 +1486,7 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_node( } node.inTensors.resize(node.operation->GetInputNum()); - if (eplb_enabled || multi_stream_parallel_enabled) { + if (eplb_enabled) { node.outTensors.resize(2); } else { node.outTensors.resize(1); @@ -1506,7 +1504,7 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_node( // eplb used in decode stage, while multi stream parallel used in prefill // stage - if (eplb_enabled || multi_stream_parallel_enabled) { + if (eplb_enabled) { node.variantPack.outTensors.reserve(2); node.variantPack.outTensors.resize(2); // TODO } else { @@ -1517,41 +1515,39 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_node( } torch::Tensor 
NpuDeepseekV2DecoderLayerImpl::forward( - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - const std::vector& input_params, - std::vector event, - std::vector*> event_flag, + const ModelInputParams& input_params, + aclrtEvent* event, + std::atomic* event_flag, int node_id) { atb::Status st; - // all micro batches are in same prefill/decode stage, - // so, to judge empty_kv_cache, use input_params[0] here - if (input_params[0].global_empty_kv_cache) { + ModelInputParams& input_params_new = + const_cast(input_params); + if (input_params.global_empty_kv_cache) { build_node_variant_pack(prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, - input_params, + input_params_new, true); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "excute prefill layer fail, error code: " << st; } else { - std::vector attn_mask{tensor_placeholder_, - tensor_placeholder_}; if (!FLAGS_enable_customize_mla_kernel) { build_node_variant_pack(decode_node_, x, cos_pos, sin_pos, - attn_mask, + /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params, + input_params_new, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) @@ -1561,9 +1557,9 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( x, cos_pos, sin_pos, - attn_mask, + /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params, + input_params_new, false); st = execute_node(decode_mla_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) @@ -1575,17 +1571,17 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( atb_speed::Model::Node& node, - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - const std::vector& input_params, + ModelInputParams& input_params, bool is_prefill) { - internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x[0]); + internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); // final_hidden_states_ = torch::zeros_like(x); int32_t input_idx = 0; - auto& dp_ep_padding = input_params[0].dp_ep_padding_data; + auto& dp_ep_padding = input_params.dp_ep_padding_data; // set micro batch 0 input part node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensor_; @@ -1600,11 +1596,11 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 5) = atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) = - atb_speed::Utils::AtTensor2Tensor(cos_pos[0]); + atb_speed::Utils::AtTensor2Tensor(cos_pos); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7) = - atb_speed::Utils::AtTensor2Tensor(sin_pos[0]); + atb_speed::Utils::AtTensor2Tensor(sin_pos); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8) = - atb_speed::Utils::AtTensor2Tensor(attn_mask[0]); + atb_speed::Utils::AtTensor2Tensor(attn_mask); if (!FLAGS_enable_continuous_kvcache) { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9) = @@ -1618,8 +1614,8 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( XTensor2Tensor(kv_cache.get_v_xtensor()); } - if ((!input_params[0].block_tables.defined() || - input_params[0].block_tables.storage().data() == 
nullptr) && + if ((!input_params.block_tables.defined() || + input_params.block_tables.storage().data() == nullptr) && !FLAGS_enable_continuous_kvcache) { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) = atb_speed::Utils::AtTensor2Tensor(int_tensor_placeholder_); @@ -1627,9 +1623,9 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( const_cast(placeholder_vec_.data()); } else { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) = - atb_speed::Utils::AtTensor2Tensor(input_params[0].kv_seq_lens); + atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11).hostData = - const_cast(input_params[0].kv_seq_lens_vec.data()); + const_cast(input_params.kv_seq_lens_vec.data()); } node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 12) = @@ -1642,30 +1638,28 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); if (!FLAGS_enable_continuous_kvcache) { - if (!input_params[0].block_tables.defined() || - input_params[0].block_tables.storage().data() == nullptr) { + if (!input_params.block_tables.defined() || + input_params.block_tables.storage().data() == nullptr) { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 15) = atb_speed::Utils::AtTensor2Tensor(block_tables_placeholder_); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) = atb_speed::Utils::AtTensor2Tensor(slot_tensor_placeholder_); } else { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 15) = - atb_speed::Utils::AtTensor2Tensor(input_params[0].block_tables); + atb_speed::Utils::AtTensor2Tensor(input_params.block_tables); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) = - atb_speed::Utils::AtTensor2Tensor(input_params[0].new_cache_slots); + atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); } } else { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 15) = - atb_speed::Utils::AtTensor2Tensor( - input_params[0].kv_cache_start_offsets); + atb_speed::Utils::AtTensor2Tensor(input_params.kv_cache_start_offsets); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) = - atb_speed::Utils::AtTensor2Tensor( - input_params[0].new_cache_slot_offsets); + atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slot_offsets); } if (num_speculative_tokens_ > 0 && !is_prefill) { - if ((!input_params[0].block_tables.defined() || - input_params[0].block_tables.storage().data() == nullptr) && + if ((!input_params.block_tables.defined() || + input_params.block_tables.storage().data() == nullptr) && !FLAGS_enable_continuous_kvcache) { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17) = atb_speed::Utils::AtTensor2Tensor(int_tensor_placeholder_); @@ -1673,224 +1667,47 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( const_cast(placeholder_vec_.data()); } else { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17) = - atb_speed::Utils::AtTensor2Tensor(input_params[0].q_seq_lens); + atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17).hostData = - const_cast(input_params[0].q_seq_lens_vec.data()); + const_cast(input_params.q_seq_lens_vec.data()); } } else { node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17) = atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); } - if (is_prefill && FLAGS_enable_multi_stream_parallel) { - internal_tensor_auxiliary_ = atb_speed::Utils::AtTensor2Tensor(x[1]); - auto& dp_ep_padding_auxiliary = 
input_params[1].dp_ep_padding_data; - - // set micro batch 1 input part - auto offset = 18; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset) = - internal_tensor_auxiliary_; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.expert_array()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2 + offset) = - atb_speed::Utils::AtTensor2Tensor(expert_group_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3 + offset) = - atb_speed::Utils::AtTensor2Tensor(one_hot_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 4 + offset) = - atb_speed::Utils::AtTensor2Tensor(zero_hot_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 5 + offset) = - atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6 + offset) = - atb_speed::Utils::AtTensor2Tensor(cos_pos[1]); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7 + offset) = - atb_speed::Utils::AtTensor2Tensor(sin_pos[1]); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8 + offset) = - atb_speed::Utils::AtTensor2Tensor(attn_mask[1]); - - if (!FLAGS_enable_continuous_kvcache) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9 + offset) = - atb_speed::Utils::AtTensor2Tensor(kv_cache.get_k_cache()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10 + offset) = - atb_speed::Utils::AtTensor2Tensor(kv_cache.get_v_cache()); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9 + offset) = - XTensor2Tensor(kv_cache.get_k_xtensor()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10 + offset) = - XTensor2Tensor(kv_cache.get_v_xtensor()); - } - - if ((!input_params[1].block_tables.defined() || - input_params[1].block_tables.storage().data() == nullptr) && - !FLAGS_enable_continuous_kvcache) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11 + offset) = - atb_speed::Utils::AtTensor2Tensor(int_tensor_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11 + offset) - .hostData = const_cast(placeholder_vec_.data()); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].kv_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11 + offset) - .hostData = - const_cast(input_params[1].kv_seq_lens_vec.data()); - } - - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 12 + offset) = - atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 13 + offset) = - atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 13 + offset) - .hostData = const_cast(placeholder_vec_.data()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 14 + offset) = - atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); - - if (!FLAGS_enable_continuous_kvcache) { - if (!input_params[1].block_tables.defined() || - input_params[1].block_tables.storage().data() == nullptr) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 15 + offset) = - atb_speed::Utils::AtTensor2Tensor(block_tables_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16 + offset) = - atb_speed::Utils::AtTensor2Tensor(slot_tensor_placeholder_); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 15 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].block_tables); - 
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].new_cache_slots); - } - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 15 + offset) = - atb_speed::Utils::AtTensor2Tensor( - input_params[1].kv_cache_start_offsets); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16 + offset) = - atb_speed::Utils::AtTensor2Tensor( - input_params[1].new_cache_slot_offsets); - } - - if (num_speculative_tokens_ > 0 && !is_prefill) { - if ((!input_params[1].block_tables.defined() || - input_params[1].block_tables.storage().data() == nullptr) && - !FLAGS_enable_continuous_kvcache) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17 + offset) = - atb_speed::Utils::AtTensor2Tensor(int_tensor_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17 + offset) - .hostData = const_cast(placeholder_vec_.data()); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].q_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17 + offset) - .hostData = - const_cast(input_params[1].q_seq_lens_vec.data()); - } - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 17 + offset) = - atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); - } - - // set micro batch 0 dp_ep_padding part - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 18 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.attn_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 19 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.attn_unpadding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 20 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.ffn_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 21 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.ffn_unpadding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 22 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding.lm_head_skip_padding_token_indices()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.gather_prenorm_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 24 + offset) = - atb_speed::Utils::AtTensor2Tensor(at_start_expert_id_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 25 + offset) = - atb_speed::Utils::AtTensor2Tensor(at_in_device_expert_count_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 26 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 27 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.un_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 28 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.dynamic_ep_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 29 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.moe_idx()); - - // set micro batch 1 dp_ep_padding part - offset = 30; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 18 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.attn_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 19 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.attn_unpadding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 20 + offset) = - 
atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.ffn_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 21 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.ffn_unpadding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 22 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.lm_head_skip_padding_token_indices()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.gather_prenorm_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 24 + offset) = - atb_speed::Utils::AtTensor2Tensor(at_start_expert_id_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 25 + offset) = - atb_speed::Utils::AtTensor2Tensor(at_in_device_expert_count_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 26 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 27 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.un_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 28 + offset) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding_auxiliary.dynamic_ep_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 29 + offset) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding_auxiliary.moe_idx()); - - if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace) { - // set micro batch 0 eplb part - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 30 + offset) = - atb_speed::Utils::AtTensor2Tensor(expert_routing_map_); - // set micro batch 1 eplb part - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 31 + offset) = - atb_speed::Utils::AtTensor2Tensor(expert_routing_map_); - } - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 18) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.attn_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 19) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.attn_unpadding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 20) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.ffn_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 21) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.ffn_unpadding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 22) = - atb_speed::Utils::AtTensor2Tensor( - dp_ep_padding.lm_head_skip_padding_token_indices()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.gather_prenorm_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 24) = - atb_speed::Utils::AtTensor2Tensor(at_start_expert_id_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 25) = - atb_speed::Utils::AtTensor2Tensor(at_in_device_expert_count_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 26) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 27) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.un_padding_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 28) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.dynamic_ep_idx()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 29) = - atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.moe_idx()); - if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace) { - 
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 30) = - atb_speed::Utils::AtTensor2Tensor(expert_routing_map_); - if (!is_prefill) { - node.variantPack.outTensors.at(1) = atb_speed::Utils::AtTensor2Tensor( - input_params[0].expert_load_data[layer_id_ - - decode_param_.firstKDenseReplace]); - } + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 18) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.attn_padding_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 19) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.attn_unpadding_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 20) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.ffn_padding_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 21) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.ffn_unpadding_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 22) = + atb_speed::Utils::AtTensor2Tensor( + dp_ep_padding.lm_head_skip_padding_token_indices()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.gather_prenorm_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 24) = + atb_speed::Utils::AtTensor2Tensor(at_start_expert_id_); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 25) = + atb_speed::Utils::AtTensor2Tensor(at_in_device_expert_count_); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 26) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.padding_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 27) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.un_padding_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 28) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.dynamic_ep_idx()); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 29) = + atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.moe_idx()); + if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace) { + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 30) = + atb_speed::Utils::AtTensor2Tensor(expert_routing_map_); + if (!is_prefill) { + node.variantPack.outTensors.at(1) = atb_speed::Utils::AtTensor2Tensor( + input_params + .expert_load_data[layer_id_ - decode_param_.firstKDenseReplace]); } } @@ -1901,9 +1718,6 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( } node.variantPack.outTensors.at(0) = internal_tensor_; - if (is_prefill && FLAGS_enable_multi_stream_parallel) { - node.variantPack.outTensors.at(1) = internal_tensor_auxiliary_; - } } } // namespace layer diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h index 98996f2d3..c57964882 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h @@ -126,15 +126,14 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer { virtual int64_t init_layer() override; - torch::Tensor forward(std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - const std::vector& input_params, - std::vector event = {nullptr, nullptr}, - std::vector*> event_flag = {nullptr, - nullptr}, + const ModelInputParams& input_params, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr, int node_id = 0); private: @@ -268,15 +267,14 @@ class 
NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer { int64_t init_node(atb_speed::Model::Node& node, atb_speed::deepseekV2::DecoderLayerParam& param); - void build_node_variant_pack( - atb_speed::Model::Node& node, - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, - KVCache& kv_cache, - const std::vector& input_params, - bool is_prefill); + void build_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill); torch::Tensor block_tables_placeholder_; std::string model_name_; @@ -320,7 +318,6 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer { atb_speed::Model::Node decode_mla_node_; atb::Tensor internal_tensor_; - atb::Tensor internal_tensor_auxiliary_; torch::Tensor at_cumsum_; torch::Tensor tensor_placeholder_; diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp index 927b98063..ed4ea1bf1 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp @@ -1073,17 +1073,16 @@ int64_t Glm4MoeDecoderImpl::init_node(atb_speed::Model::Node& node, return atb::NO_ERROR; } -torch::Tensor Glm4MoeDecoderImpl::forward( - torch::Tensor& x, - torch::Tensor& cos_pos, - torch::Tensor& sin_pos, - torch::Tensor& attn_mask, - KVCache& kv_cache, - const ModelInputParams& input_params, - torch::Tensor& expert_array, - std::vector event, - std::vector*> event_flag, - int node_id) { +torch::Tensor Glm4MoeDecoderImpl::forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + const ModelInputParams& input_params, + torch::Tensor& expert_array, + aclrtEvent* event, + std::atomic* event_flag, + int node_id) { atb::Status st; if (input_params.decode_seq_range.second != input_params.q_seq_lens.size(0) - 1) { diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h index 5b6e5ce8d..d0c2732b6 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h @@ -53,9 +53,8 @@ class Glm4MoeDecoderImpl : public NpuBaseLayer { KVCache& kv_cache, const ModelInputParams& input_params, torch::Tensor& expert_array, - std::vector event = {nullptr, nullptr}, - std::vector*> event_flag = {nullptr, - nullptr}, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr, int node_id = 0); private: diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp index 0f788904e..ffbf45bd0 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp @@ -394,27 +394,26 @@ int64_t NpuQwen2DecoderLayerImpl::init_node( return atb::NO_ERROR; } -torch::Tensor NpuQwen2DecoderLayerImpl::forward( - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, - KVCache& kv_cache, - std::vector& input_params, - std::vector event, - std::vector*> event_flag, - int node_id) { +torch::Tensor NpuQwen2DecoderLayerImpl::forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + aclrtEvent* event, + std::atomic* event_flag, + int node_id) { atb::Status st; - if 
(input_params[0].decode_seq_range.second != - input_params[0].q_seq_lens.size(0) - 1) { + if (input_params.decode_seq_range.second != + input_params.q_seq_lens.size(0) - 1) { // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); build_node_variant_pack(prefill_node_, - x[0], - cos_pos[0], - sin_pos[0], - attn_mask[0], + x, + cos_pos, + sin_pos, + attn_mask, kv_cache, - input_params[0], + input_params, true); // mstxRangeEnd(id); st = execute_node(prefill_node_, node_id, event, event_flag); @@ -422,12 +421,12 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward( << "excute prefill layer fail, error code: " << st; } else { build_node_variant_pack(decode_node_, - x[0], - cos_pos[0], - sin_pos[0], + x, + cos_pos, + sin_pos, decode_attn_mask_, kv_cache, - input_params[0], + input_params, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h index c70bcfd74..17d2b15ac 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h @@ -120,15 +120,14 @@ class NpuQwen2DecoderLayerImpl : public NpuBaseLayer { virtual int64_t init_layer() override; - torch::Tensor forward(std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - std::vector& input_params, - std::vector event = {nullptr, nullptr}, - std::vector*> event_flag = {nullptr, - nullptr}, + ModelInputParams& input_params, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr, int node_id = 0); private: diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index d48379fb4..f7ae89231 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -187,7 +187,7 @@ void NpuQwen3DecoderLayerImpl::param_from_args( // NOTE: Currently, single-process startup requires setting enableLcoc to // false, which leads to performance degradation. param.enableLcoc = false; // //isPrefill - param.enableLcoc = FLAGS_enable_multi_stream_parallel ? 
isPrefill : false; + param.enableLcoc = false; param.rmsnormQKNorm = true; param.isPrefill = isPrefill; param.isBF16 = args.dtype() == "bfloat16"; @@ -295,10 +295,6 @@ NpuQwen3DecoderLayerImpl::NpuQwen3DecoderLayerImpl(const ModelContext& context) for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { at_weight_tensors_[i] = torch::zeros({1}).to(options); } - int_tensor_placeholder_ = torch::ones({1}).to(torch::kInt32).to(device_); - slot_tensor_placeholder_ = torch::full({1}, 0).to(torch::kInt32).to(device_); - block_tables_placeholder_ = - torch::zeros({1, 1}).to(torch::kInt32).to(device_); } void NpuQwen3DecoderLayerImpl::verify_loaded_weights() const { @@ -484,8 +480,6 @@ int64_t NpuQwen3DecoderLayerImpl::init_attn_mask() { int64_t NpuQwen3DecoderLayerImpl::init_node( atb_speed::Model::Node& node, atb_speed::qwen::QwenLayerParam& param) { - bool multi_stream_parallel_enabled = - param.isPrefill && FLAGS_enable_multi_stream_parallel; atb::Operation* operation = nullptr; atb_speed::qwen::QwenDecoderLayer decoder_layer(param); decoder_layer.BuildGraph(&operation); @@ -499,7 +493,7 @@ int64_t NpuQwen3DecoderLayerImpl::init_node( return -1; } node.inTensors.resize(node.operation->GetInputNum()); - node.outTensors.resize(multi_stream_parallel_enabled ? 2 : 1); + node.outTensors.resize(1); size_t inTensorId = 1; for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER; @@ -509,25 +503,24 @@ int64_t NpuQwen3DecoderLayerImpl::init_node( node.variantPack.inTensors.reserve(node.inTensors.size()); node.variantPack.inTensors.resize(node.inTensors.size()); - node.variantPack.outTensors.reserve(multi_stream_parallel_enabled ? 2 : 1); - node.variantPack.outTensors.resize(multi_stream_parallel_enabled ? 2 : 1); + node.variantPack.outTensors.reserve(1); + node.variantPack.outTensors.resize(1); return atb::NO_ERROR; } -torch::Tensor NpuQwen3DecoderLayerImpl::forward( - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, - KVCache& kv_cache, - std::vector& input_params, - std::vector event, - std::vector*> event_flag, - int node_id) { +torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + aclrtEvent* event, + std::atomic* event_flag, + int node_id) { atb::Status st; - if (input_params[0].decode_seq_range.second != - input_params[0].q_seq_lens.size(0) - 1) { + if (input_params.decode_seq_range.second != + input_params.q_seq_lens.size(0) - 1) { // if (input_params.empty_kv_cache) { // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); build_node_variant_pack(prefill_node_, @@ -543,13 +536,11 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward( LOG_IF(FATAL, st != 0) << model_name_ << "excute prefill layer fail, error code: " << st; } else { - std::vector decode_attn_masks = {this->decode_attn_mask_, - this->decode_attn_mask_}; build_node_variant_pack(decode_node_, x, cos_pos, sin_pos, - decode_attn_masks, + decode_attn_mask_, kv_cache, input_params, false); @@ -563,117 +554,44 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward( void NpuQwen3DecoderLayerImpl::build_node_variant_pack( atb_speed::Model::Node& node, - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + at::Tensor& attn_mask, KVCache& kv_cache, - std::vector& input_params, + ModelInputParams& input_params, bool is_prefill) { 
- internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x[0]); + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); // std::cout<<"node.variantPack.inTensors.size:"<(placeholder_vec_.data()); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].kv_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6 + offset) - .hostData = input_params[1].kv_seq_lens_vec.data(); - } - - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7 + offset) = - placeholder_; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7 + offset) - .hostData = placeholder_vec_.data(); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8 + offset) = - placeholder_; - - if (!input_params[1].block_tables.defined() || - input_params[1].block_tables.storage().data() == nullptr) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9 + offset) = - atb_speed::Utils::AtTensor2Tensor(block_tables_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10 + offset) = - atb_speed::Utils::AtTensor2Tensor(slot_tensor_placeholder_); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].block_tables); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10 + offset) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].new_cache_slots); - } - - if (is_prefill && - (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 22) = - atb_speed::Utils::AtTensor2Tensor(input_params[0].q_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 22).hostData = - input_params[0].q_seq_lens_vec.data(); - - if (!input_params[1].block_tables.defined() || - input_params[1].block_tables.storage().data() == nullptr) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23) = - atb_speed::Utils::AtTensor2Tensor(int_tensor_placeholder_); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23).hostData = - const_cast(placeholder_vec_.data()); - } else { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23) = - atb_speed::Utils::AtTensor2Tensor(input_params[1].q_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 23).hostData = - input_params[1].q_seq_lens_vec.data(); - } - } + atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); + if (is_prefill && + (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) { + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) = + atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11).hostData = + input_params.q_seq_lens_vec.data(); } for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { @@ -686,9 +604,6 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( } node.variantPack.outTensors.at(0) = internal_tensors_; - if (is_prefill && FLAGS_enable_multi_stream_parallel) { - node.variantPack.outTensors.at(1) = internal_tensors_auxiliary; - } } } // namespace layer diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h index 03ec17d10..785f43a16 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h @@ -57,15 +57,14 @@ class NpuQwen3DecoderLayerImpl : public NpuBaseLayer { virtual int64_t init_layer() override; - torch::Tensor forward(std::vector& x, - std::vector& cos_pos, - 
std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - std::vector& input_params, - std::vector event = {nullptr, nullptr}, - std::vector*> event_flag = {nullptr, - nullptr}, + ModelInputParams& input_params, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr, int node_id = 0); private: @@ -75,12 +74,12 @@ class NpuQwen3DecoderLayerImpl : public NpuBaseLayer { bool isPrefill); void build_node_variant_pack(atb_speed::Model::Node& node, - std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - std::vector& input_params, + ModelInputParams& input_params, bool is_prefill); void initialize_quantization_parameters( @@ -97,11 +96,7 @@ class NpuQwen3DecoderLayerImpl : public NpuBaseLayer { atb_speed::qwen::QwenLayerParam prefill_param_; atb_speed::qwen::QwenLayerParam decode_param_; atb::Tensor internal_tensors_; - atb::Tensor internal_tensors_auxiliary; atb::Tensor placeholder_; - torch::Tensor int_tensor_placeholder_; - torch::Tensor block_tables_placeholder_; - torch::Tensor slot_tensor_placeholder_; at::Tensor decode_attn_mask_; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index f7d74347d..c2a840f91 100755 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -875,7 +875,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( input_params, expert_array, true); - st = execute_node(prefill_node_, node_id, {event}, {event_flag}); + st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "excute prefill layer fail, error code: " << st; } else { @@ -888,7 +888,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( input_params, expert_array, false); - st = execute_node(decode_node_, node_id + 1000, {event}, {event_flag}); + st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "excute decode layer fail, error code: " << st; } diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 4a603945a..1100a9bc2 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -207,7 +207,7 @@ torch::Tensor AclGraphExecutorImpl::run( // If not in decode phase, use eager mode directly without acl graph if (!in_decoding_phase) { COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(tokens[0], positions[0], kv_caches, params[0]); } // Only use acl graph in decode phase for performance optimization @@ -237,7 +237,7 @@ torch::Tensor AclGraphExecutorImpl::run( // Early return if conditions are not suitable for graph operations if (!capture_supported) { COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(tokens[0], positions[0], kv_caches, params[0]); } // Check if captured graph exists for this bucket size @@ -273,7 +273,7 @@ torch::Tensor AclGraphExecutorImpl::run( // Fallback to eager mode if capture fails LOG(ERROR) << "Failed to capture ACL graph for bucket size: " << bucket_size; 
COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(tokens[0], positions[0], kv_caches, params[0]); } void AclGraph::copy_data_to_graph_buffer(const torch::Tensor& tokens, diff --git a/xllm/core/runtime/acl_graph_executor_test.cpp b/xllm/core/runtime/acl_graph_executor_test.cpp index 5a330d31f..988cd669f 100644 --- a/xllm/core/runtime/acl_graph_executor_test.cpp +++ b/xllm/core/runtime/acl_graph_executor_test.cpp @@ -187,19 +187,11 @@ class SimpleCausalLM : public CausalLM { } // Adapter method to match CausalLM base class interface - torch::Tensor forward( - const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& parameters) override { - // For SimpleCausalLM, we expect single tensor inputs - CHECK_EQ(tokens.size(), 1) << "SimpleCausalLM expects single token tensor"; - CHECK_EQ(positions.size(), 1) - << "SimpleCausalLM expects single position tensor"; - CHECK_EQ(parameters.size(), 1) - << "SimpleCausalLM expects single parameter set"; - - return forward_impl(tokens[0], positions[0], kv_caches, parameters[0]); + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& parameters) override { + return forward_impl(tokens, positions, kv_caches, parameters); } const torch::TensorOptions& options() const override { @@ -243,13 +235,12 @@ class SimpleCausalLM : public CausalLM { // Simple implementation for testing } - std::vector get_word_embedding() override { + layer::WordEmbedding get_word_embedding() override { // Simple implementation for testing - return std::vector{layer::WordEmbedding(nullptr)}; + return layer::WordEmbedding(nullptr); } - void set_word_embedding( - std::vector& embedding) override { + void set_word_embedding(layer::WordEmbedding& embedding) override { // Simple implementation for testing } diff --git a/xllm/core/runtime/base_executor_impl.cpp b/xllm/core/runtime/base_executor_impl.cpp index 37f931163..4216ed7bf 100644 --- a/xllm/core/runtime/base_executor_impl.cpp +++ b/xllm/core/runtime/base_executor_impl.cpp @@ -36,7 +36,7 @@ torch::Tensor BaseExecutorImpl::run( const std::vector& positions, std::vector& kv_caches, const std::vector& params) { - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(tokens[0], positions[0], kv_caches, params[0]); } } // namespace xllm diff --git a/xllm/core/runtime/llm_worker_impl.h b/xllm/core/runtime/llm_worker_impl.h index 7da99eb81..4387e5761 100644 --- a/xllm/core/runtime/llm_worker_impl.h +++ b/xllm/core/runtime/llm_worker_impl.h @@ -49,11 +49,11 @@ class LLMWorkerImpl : public WorkerImpl { void set_lm_head(layer::LmHead& head) { model_->set_lm_head(head); }; - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); }; - void set_word_embedding(std::vector& embedding) { + void set_word_embedding(layer::WordEmbedding& embedding) { model_->set_word_embedding(embedding); }; diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h index 174dde774..5d85f8af8 100644 --- a/xllm/models/llm/deepseek_v2.h +++ b/xllm/models/llm/deepseek_v2.h @@ -54,14 +54,14 @@ class DeepseekV2DecoderLayerImpl : public torch::nn::Module { "decoder_layer", layer::DeepseekV2DecoderLayer(context, i, sm_scale)); } - torch::Tensor forward(std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + torch::Tensor 
forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - const std::vector& input_params, - std::vector event, - std::vector*> event_flag) { + const ModelInputParams& input_params, + aclrtEvent* event, + std::atomic* event_flag) { return decoder_layer_(x, cos_pos, sin_pos, @@ -117,17 +117,17 @@ class DeepseekV2ModelImpl : public torch::nn::Module { model_args.rotary_dim(), model_args.rope_theta(), model_args.rope_scaling_original_max_position_embeddings()); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); float sm_scale = 1.0f; - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_.push_back(layer::WordEmbedding(context)); - pos_embs_.push_back(create_rotary_embedding(model_args, - model_args.rotary_dim(), - inv_freq, - /*interleaved=*/false, - sm_scale, - options)); - atb_pos_embs_.push_back(layer::PosEmbedding(context)); - } + pos_emb_ = create_rotary_embedding(model_args, + model_args.rotary_dim(), + inv_freq, + /*interleaved=*/false, + sm_scale, + options); + atb_pos_emb_ = layer::PosEmbedding(context); + max_seq_len_ = model_args.max_position_embeddings(); int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984; attn_mask_ = layer::AttentionMask(options.device(), @@ -154,84 +154,61 @@ class DeepseekV2ModelImpl : public torch::nn::Module { } } - torch::Tensor forward(std::vector tokens, - std::vector positions, + torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, std::vector& kv_caches, - const std::vector& input_params) { - auto micro_batch_num = tokens.size(); - std::vector hs; - hs.reserve(micro_batch_num); - std::vector cos_poss; - cos_poss.reserve(micro_batch_num); - std::vector sin_poss; - sin_poss.reserve(micro_batch_num); - std::vector attn_masks; - attn_masks.reserve(micro_batch_num); - - for (auto i = 0; i < micro_batch_num; ++i) { - if (dp_size_ > 1) { - if (tokens[i].sizes() == 0) { - tokens[i] = torch::tensor({1}).to(torch::kInt32).to(device_); - positions[i] = torch::tensor({0}).to(torch::kInt32).to(device_); - } + const ModelInputParams& input_params) { + if (dp_size_ > 1) { + if (tokens.sizes() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(device_); + positions = torch::tensor({0}).to(torch::kInt32).to(device_); } + } - hs.push_back(std::move(embed_tokens_[i](tokens[i], i))); - auto cos_sin = - atb_pos_embs_[i](pos_embs_[i]->get_cos_sin_cache(), positions[i], 0); - auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = cos_sin_chunks[0].contiguous(); - auto sin_pos = cos_sin_chunks[1].contiguous(); - - torch::Tensor attn_mask; - if (num_speculative_tokens_ == 0 || - input_params[i].global_empty_kv_cache) { - attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); - } else { - attn_mask = attn_mask_.gen_free_mask( - num_speculative_tokens_ + 1, dtype_, device_); - } - cos_poss.push_back(std::move(cos_pos)); - sin_poss.push_back(std::move(sin_pos)); - attn_masks.push_back(std::move(attn_mask)); + auto h = embed_tokens_(tokens, 0); + auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0); + auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = cos_sin_chunks[0].contiguous(); + auto sin_pos = cos_sin_chunks[1].contiguous(); + + torch::Tensor attn_mask; + if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { + attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); + } else { + attn_mask = attn_mask_.gen_free_mask( + 
num_speculative_tokens_ + 1, dtype_, device_); } for (size_t i = 0; i < layers_.size(); i++) { - std::vector events(micro_batch_num, nullptr); - std::vector*> event_flags(micro_batch_num, nullptr); - for (auto j = 0; j < micro_batch_num; ++j) { - if (input_params[j].layer_synchronizer != nullptr) { - events[j] = input_params[j].layer_synchronizer->get_event(i); - event_flags[j] = - input_params[j].layer_synchronizer->get_event_flag(i); - } - if (input_params[j].layer_wise_load_synchronizer != nullptr) { - if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer( - i)) { - return torch::Tensor(); - } + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); } } + auto& layer = layers_[i]; - layer(hs, - cos_poss, - sin_poss, - attn_masks, + layer(h, + cos_pos, + sin_pos, + attn_mask, kv_caches[i], input_params, - events, - event_flags); + event, + event_flag); } - auto cancated_h = torch::cat(hs, 0); - return norm_(cancated_h, 0); + return norm_(h, 0); } // load the weight from the checkpoint void load_state_dict(const StateDict& state_dict) { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i]->load_state_dict( - state_dict.get_dict_with_prefix("embed_tokens.")); - } + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); // call each layer's load_state_dict function for (int i = 0; i < layers_.size(); i++) { layers_[i]->load_state_dict( @@ -241,9 +218,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module { } void verify_loaded_weights(const std::string& prefix) const { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i]->verify_loaded_weights(prefix + "embed_tokens."); - } + embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); for (int i = 0; i < layers_.size(); i++) { layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + "."); @@ -252,9 +227,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module { } void merge_loaded_weights() { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i]->merge_loaded_weights(); - } + embed_tokens_->merge_loaded_weights(); for (int i = 0; i < layers_.size(); i++) { layers_[i]->merge_loaded_weights(); } @@ -270,11 +243,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module { layers_[layer_id]->update_expert_weight(); } - std::vector get_word_embedding() { - return embed_tokens_; - } + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { embed_tokens_ = word_embedding; } @@ -291,9 +262,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module { int32_t num_speculative_tokens_ = 0; at::Device device_; torch::Dtype dtype_; - std::vector embed_tokens_; - std::vector> pos_embs_; - std::vector atb_pos_embs_; + layer::WordEmbedding embed_tokens_{nullptr}; + std::shared_ptr pos_emb_{nullptr}; + layer::PosEmbedding atb_pos_emb_{nullptr}; layer::AttentionMask attn_mask_; layer::RmsNorm norm_{nullptr}; }; @@ -310,10 +281,10 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { + const ModelInputParams& input_params) { return model_(tokens, positions, kv_caches, input_params); } @@ -353,11 +324,11 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/llm/deepseek_v2_mtp.h b/xllm/models/llm/deepseek_v2_mtp.h index b1c27c149..be17951de 100644 --- a/xllm/models/llm/deepseek_v2_mtp.h +++ b/xllm/models/llm/deepseek_v2_mtp.h @@ -73,16 +73,16 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { layers_.push_back(block); blocks_->push_back(block); } - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - pos_embs_.push_back(create_rotary_embedding(model_args, - model_args.rotary_dim(), - inv_freq, - /*interleaved=*/false, - sm_scale, - options)); - atb_pos_embs_.push_back(layer::PosEmbedding(context)); - eh_projs_.push_back(layer::ColumnParallelLinear(context)); - } + + pos_emb_ = create_rotary_embedding(model_args, + model_args.rotary_dim(), + inv_freq, + /*interleaved=*/false, + sm_scale, + options); + atb_pos_emb_ = layer::PosEmbedding(context); + eh_proj_ = register_module("eh_proj", layer::ColumnParallelLinear(context)); + enorm_ = register_module("enorm", layer::RmsNorm(context)); hnorm_ = register_module("hnorm", layer::RmsNorm(context)); final_norm_ = register_module("final_norm", layer::RmsNorm(context)); @@ -102,84 +102,63 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - torch::Tensor forward(std::vector tokens, - std::vector positions, + torch::Tensor forward(torch::Tensor tokens, + 
torch::Tensor positions, std::vector& kv_caches, - const std::vector& input_params) { - auto micro_batch_num = tokens.size(); - std::vector hs; - hs.reserve(micro_batch_num); - std::vector cos_poss; - cos_poss.reserve(micro_batch_num); - std::vector sin_poss; - sin_poss.reserve(micro_batch_num); - std::vector attn_masks; - attn_masks.reserve(micro_batch_num); - - for (auto i = 0; i < micro_batch_num; ++i) { - if (dp_size_ > 1) { - if (tokens[i].sizes() == 0) { - tokens[i] = torch::tensor({1}).to(torch::kInt32).to(device_); - positions[i] = torch::tensor({0}).to(torch::kInt32).to(device_); - } - } - - hs.push_back(std::move(embed_tokens_[i](tokens[i], 0))); - torch::Tensor enorm = enorm_(hs[i], 0); - const auto& res = input_params[i].mm_data.get("embedding"); - if (res) { - hs[i] = res.value(); - } else { - LOG(WARNING) << "hnorm use embedding from tokens."; + const ModelInputParams& input_params) { + if (dp_size_ > 1) { + if (tokens.sizes() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(device_); + positions = torch::tensor({0}).to(torch::kInt32).to(device_); } + } - torch::Tensor hnorm = hnorm_(hs[i], 0); - CHECK_EQ(enorm.dim(), hnorm.dim()); - CHECK_EQ(enorm.size(0), hnorm.size(0)); - hs[i] = torch::cat({enorm, hnorm}, /*dim=*/-1); - hs[i] = eh_projs_[i](hs[i], 0); - - auto cos_sin = - atb_pos_embs_[i](pos_embs_[i]->get_cos_sin_cache(), positions[i], 0); - auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = cos_sin_chunks[0].contiguous(); - auto sin_pos = cos_sin_chunks[1].contiguous(); - - auto attn_mask = attn_mask_.get_attn_mask( - 128, cos_pos.dtype().toScalarType(), cos_pos.device()); - cos_poss.push_back(std::move(cos_pos)); - sin_poss.push_back(std::move(sin_pos)); - attn_masks.push_back(std::move(attn_mask)); + torch::Tensor h = embed_tokens_(tokens, 0); + torch::Tensor enorm = enorm_(h, 0); + const auto& res = input_params.mm_data.get("embedding"); + if (res) { + h = res.value(); + } else { + LOG(WARNING) << "hnorm use embedding from tokens."; } + torch::Tensor hnorm = hnorm_(h, 0); + CHECK_EQ(enorm.dim(), hnorm.dim()); + CHECK_EQ(enorm.size(0), hnorm.size(0)); + h = torch::cat({enorm, hnorm}, /*dim=*/-1); + h = eh_proj_(h, 0); + + auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0); + auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = cos_sin_chunks[0].contiguous(); + auto sin_pos = cos_sin_chunks[1].contiguous(); + + auto attn_mask = attn_mask_.get_attn_mask( + 128, cos_pos.dtype().toScalarType(), cos_pos.device()); for (size_t i = 0; i < layers_.size(); i++) { - std::vector events(micro_batch_num, nullptr); - std::vector*> event_flags(micro_batch_num, nullptr); - for (auto j = 0; j < micro_batch_num; ++j) { - if (input_params[j].layer_synchronizer != nullptr) { - events[j] = input_params[j].layer_synchronizer->get_event(i); - event_flags[j] = - input_params[j].layer_synchronizer->get_event_flag(i); - } - if (input_params[j].layer_wise_load_synchronizer != nullptr) { - if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer( - i)) { - return torch::Tensor(); - } + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); } } + auto& 
layer = layers_[i]; - layer(hs, - cos_poss, - sin_poss, - attn_masks, + layer(h, + cos_pos, + sin_pos, + attn_mask, kv_caches[i], input_params, - events, - event_flags); + event, + event_flag); } - auto cancated_h = torch::cat(hs, 0); - return final_norm_(cancated_h, 0); + return final_norm_(h, 0); } // load the weight from the checkpoint @@ -189,10 +168,7 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { layers_[i]->load_state_dict( state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); } - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - eh_projs_[i]->load_state_dict( - state_dict.get_dict_with_prefix("eh_proj.")); - } + eh_proj_->load_state_dict(state_dict.get_dict_with_prefix("eh_proj.")); enorm_->load_state_dict(state_dict.get_dict_with_prefix("enorm.")); hnorm_->load_state_dict(state_dict.get_dict_with_prefix("hnorm.")); final_norm_->load_state_dict( @@ -204,9 +180,7 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + "."); } - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - eh_projs_[i]->verify_loaded_weights(prefix + "eh_proj."); - } + eh_proj_->verify_loaded_weights(prefix + "eh_proj."); enorm_->verify_loaded_weights(prefix + "enorm."); hnorm_->verify_loaded_weights(prefix + "hnorm."); final_norm_->verify_loaded_weights(prefix + "shared_head.norm."); @@ -216,19 +190,15 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { for (int i = 0; i < layers_.size(); i++) { layers_[i]->merge_loaded_weights(); } - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - eh_projs_[i]->merge_loaded_weights(); - } + eh_proj_->merge_loaded_weights(); enorm_->merge_loaded_weights(); hnorm_->merge_loaded_weights(); final_norm_->merge_loaded_weights(); } - std::vector get_word_embedding() { - return embed_tokens_; - } + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { embed_tokens_ = word_embedding; } @@ -243,11 +213,11 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { nlohmann::json mapping_data_; int32_t num_experts_per_tok_; at::Device device_; - std::vector embed_tokens_; - std::vector> pos_embs_; - std::vector atb_pos_embs_; + layer::WordEmbedding embed_tokens_{nullptr}; + std::shared_ptr pos_emb_{nullptr}; + layer::PosEmbedding atb_pos_emb_{nullptr}; layer::AttentionMask attn_mask_; - std::vector eh_projs_; + layer::ColumnParallelLinear eh_proj_{nullptr}; layer::RmsNorm enorm_{nullptr}; layer::RmsNorm hnorm_{nullptr}; layer::RmsNorm final_norm_{nullptr}; @@ -265,10 +235,10 @@ class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { + const ModelInputParams& input_params) { return model_(tokens, positions, kv_caches, input_params); } @@ -305,11 +275,11 @@ class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void 
set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } @@ -379,4 +349,4 @@ REGISTER_MODEL_ARGS(deepseek_v3_mtp, [&] { SET_ARG(stop_token_ids, std::unordered_set({1})); }); -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/models/llm/embedding_model_base.h b/xllm/models/llm/embedding_model_base.h index e62527e99..058ab2545 100644 --- a/xllm/models/llm/embedding_model_base.h +++ b/xllm/models/llm/embedding_model_base.h @@ -35,11 +35,10 @@ class LlmForEmbeddingImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - virtual torch::Tensor forward( - const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& input_params) { + virtual torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { return model_(tokens, positions, kv_caches, input_params); } @@ -78,12 +77,11 @@ class LlmForEmbeddingImplBase : public torch::nn::Module { virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - virtual std::vector get_word_embedding() { + virtual layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - virtual void set_word_embedding( - std::vector& word_embedding) { + virtual void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } @@ -96,4 +94,4 @@ class LlmForEmbeddingImplBase : public torch::nn::Module { layer::LmHead lm_head_{nullptr}; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/models/llm/glm4_moe.h b/xllm/models/llm/glm4_moe.h index 41a5cddc2..714ba79c3 100644 --- a/xllm/models/llm/glm4_moe.h +++ b/xllm/models/llm/glm4_moe.h @@ -42,8 +42,8 @@ class Glm4MoeDecoderLayerImpl : public torch::nn::Module { KVCache& kv_cache, const ModelInputParams& input_params, torch::Tensor expert_array, - std::vector event, - std::vector*> event_flag) { + aclrtEvent* event, + std::atomic* event_flag) { return decoder_layer_(x, cos_pos, sin_pos, @@ -171,11 +171,11 @@ class Glm4MoeModelImpl : public torch::nn::Module { } for (size_t i = 0; i < layers_.size(); i++) { - std::vector events(1, nullptr); - std::vector*> event_flags(1, nullptr); + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; if (input_params.layer_synchronizer != nullptr) { - events[0] = input_params.layer_synchronizer->get_event(i); - event_flags[0] = input_params.layer_synchronizer->get_event_flag(i); + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); } if (input_params.layer_wise_load_synchronizer != nullptr) { if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { @@ -191,8 +191,8 @@ class Glm4MoeModelImpl : public torch::nn::Module { kv_caches[i], input_params, expert_array, - events, - event_flags); + event, + event_flag); } return norm_(h, 0); } @@ -226,12 +226,10 @@ class Glm4MoeModelImpl : public torch::nn::Module { norm_->merge_loaded_weights(); } - std::vector get_word_embedding() { - return {embed_tokens_}; - } + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - void set_word_embedding(std::vector& word_embedding) { - embed_tokens_ = word_embedding[0]; + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; } 
private: @@ -265,11 +263,11 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { - return model_(tokens[0], positions[0], kv_caches, input_params[0]); + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); } // hidden_states: [num_tokens, hidden_size] @@ -306,11 +304,11 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/llm/glm4_moe_mtp.h b/xllm/models/llm/glm4_moe_mtp.h index 83ff17bdf..41ef78944 100644 --- a/xllm/models/llm/glm4_moe_mtp.h +++ b/xllm/models/llm/glm4_moe_mtp.h @@ -147,11 +147,11 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module { torch::TensorOptions().dtype(torch::kInt32).device(tokens.device())); for (size_t i = 0; i < layers_.size(); i++) { - std::vector events(1, nullptr); - std::vector*> event_flags(1, nullptr); + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; if (input_params.layer_synchronizer != nullptr) { - events[0] = input_params.layer_synchronizer->get_event(i); - event_flags[0] = input_params.layer_synchronizer->get_event_flag(i); + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); } // TODO(liangzhiwei20): MTP need more support for layer wise copy. 
if (input_params.layer_wise_load_synchronizer != nullptr) { @@ -168,8 +168,8 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module { kv_caches[i], input_params, expert_array, - events, - event_flags); + event, + event_flag); } return final_norm_(h, 0); } @@ -212,12 +212,10 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module { final_norm_->merge_loaded_weights(); } - std::vector get_word_embedding() { - return {embed_tokens_}; - } + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - void set_word_embedding(std::vector& word_embedding) { - embed_tokens_ = word_embedding[0]; + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; } private: @@ -254,11 +252,11 @@ class Glm4MoeMtpForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { - return model_(tokens[0], positions[0], kv_caches, input_params[0]); + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); } // hidden_states: [num_tokens, hidden_size] @@ -294,11 +292,11 @@ class Glm4MoeMtpForCausalLMImpl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } @@ -347,4 +345,4 @@ REGISTER_MODEL_ARGS(glm4_moe_mtp, [&] { std::unordered_set(args->eos_token_id_vec().begin(), args->eos_token_id_vec().end())); }); -} // namespace xllm::hf \ No newline at end of file +} // namespace xllm::hf diff --git a/xllm/models/llm/llama.h b/xllm/models/llm/llama.h index e85169427..da82b9693 100644 --- a/xllm/models/llm/llama.h +++ b/xllm/models/llm/llama.h @@ -215,12 +215,10 @@ class LlamaModelImpl : public torch::nn::Module { norm_->merge_loaded_weights(); } - std::vector get_word_embedding() { - return {embed_tokens_}; - } + layer::WordEmbedding get_word_embedding() { return {embed_tokens_}; } - void set_word_embedding(std::vector& word_embedding) { - embed_tokens_ = word_embedding[0]; + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; } private: @@ -251,11 +249,11 @@ class LlamaForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { - return model_(tokens[0], positions[0], kv_caches, input_params[0]); + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); } // hidden_states: [num_tokens, hidden_size] @@ -290,11 +288,11 @@ class LlamaForCausalLMImpl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void 
set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index 6000a2ac5..989945667 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -88,25 +88,22 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { } #if defined(USE_NPU) - virtual torch::Tensor forward(std::vector& x, - std::vector& cos_pos, - std::vector& sin_pos, - std::vector& attn_mask, + virtual torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, KVCache& kv_cache, - std::vector& input_params, + ModelInputParams& input_params, int node_id, - std::vector event, - std::vector*> event_flag) { - auto micro_batch_num = x.size(); - for (auto i = 0; i < micro_batch_num; ++i) { - if (input_params[i].src_block_indices.numel() > 0) { - block_copy_(kv_cache.get_k_cache(), - kv_cache.get_v_cache(), - input_params[i].src_block_indices, - input_params[i].dst_block_indices, - input_params[i].cum_sum, - 0); - } + aclrtEvent* event, + std::atomic* event_flag) { + if (input_params.src_block_indices.numel() > 0) { + block_copy_(kv_cache.get_k_cache(), + kv_cache.get_v_cache(), + input_params.src_block_indices, + input_params.dst_block_indices, + input_params.cum_sum, + 0); } return decoder_layer_(x, @@ -164,135 +161,113 @@ class LlmModelImplBase : public torch::nn::Module { torch::Tensor get_input_embeddings(torch::Tensor input_ids) { #if defined(USE_NPU) - return embed_tokens_[0](input_ids, 0); + return embed_tokens_(input_ids, 0); #else - return embed_tokens_[0](input_ids); + return embed_tokens_(input_ids); #endif } // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - virtual torch::Tensor forward( - std::vector tokens, - std::vector positions, - std::vector& kv_caches, - const std::vector& input_params) { - auto micro_batch_num = tokens.size(); - std::vector hs; - hs.reserve(micro_batch_num); - std::vector cos_poss; - cos_poss.reserve(micro_batch_num); - std::vector sin_poss; - sin_poss.reserve(micro_batch_num); - std::vector attn_masks; - attn_masks.reserve(micro_batch_num); - std::vector& input_params_news = - const_cast&>(input_params); - - for (auto i = 0; i < micro_batch_num; ++i) { - if (tokens[i].numel() == 0) { - tokens[i] = torch::tensor({1}).to(torch::kInt32).to(tokens[0].device()); - positions[i] = - torch::tensor({0}).to(torch::kInt32).to(tokens[0].device()); - } - auto inputs_embeds = input_params[i].input_embedding; - // test - torch::Tensor h; - if (inputs_embeds.defined()) { - h = inputs_embeds; - } else { + virtual torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (tokens.numel() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + } + auto inputs_embeds = input_params.input_embedding; + // test + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { #if defined(USE_NPU) - h = embed_tokens_[i](tokens[i], 0); + h = embed_tokens_(tokens, 0); #else - h = embed_tokens_[i](tokens[i]); + h = embed_tokens_(tokens); #endif - } - hs.push_back(std::move(h)); + } + #if defined(USE_NPU) - auto target_cos_sin = atb_pos_embeds_[i](cos_sin_, positions[i], 0); - auto target_cos_sin_chunks = - 
target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = target_cos_sin_chunks[0].contiguous(); - auto sin_pos = target_cos_sin_chunks[1].contiguous(); - - if (positions[i].dim() == 2) { // mrope - auto apply = [this](torch::Tensor x) { - auto sections = mrope_section_; - sections.insert(sections.end(), sections.begin(), sections.end()); - - auto vec = x.split(sections, -1); - std::vector selects; - selects.reserve(vec.size()); - - for (int64_t i = 0; i < vec.size(); ++i) { - auto m = vec[i]; - selects.push_back(m[i % mrope_section_.size()]); - } - return torch::cat(selects, -1); - }; - cos_pos = apply(cos_pos.reshape( - {positions[i].sizes().front(), -1, cos_pos.sizes().back()})); - sin_pos = apply(sin_pos.reshape( - {positions[i].sizes().front(), -1, sin_pos.sizes().back()})); - } + auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + auto sin_pos = target_cos_sin_chunks[1].contiguous(); + + if (positions.dim() == 2) { // mrope + auto apply = [this](torch::Tensor x) { + auto sections = mrope_section_; + sections.insert(sections.end(), sections.begin(), sections.end()); + + auto vec = x.split(sections, -1); + std::vector selects; + selects.reserve(vec.size()); + + for (int64_t i = 0; i < vec.size(); ++i) { + auto m = vec[i]; + selects.push_back(m[i % mrope_section_.size()]); + } + return torch::cat(selects, -1); + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } - torch::Tensor attn_mask; - if (model_type_ == "qwen2") { - max_seq_len_ = - FLAGS_enable_chunked_prefill - ? std::max(input_params[i].kv_max_seq_len, max_seq_len_) - : 128; - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - } else { - max_seq_len_ = - FLAGS_enable_chunked_prefill - ? std::max(input_params[i].kv_max_seq_len, max_seq_len_) - : 128; - if (FLAGS_enable_chunked_prefill) { - int num_sequences = input_params[i].num_sequences; - if (num_sequences > 0) { - std::vector req_mask_vec; - req_mask_vec.reserve(num_sequences); - - for (int j = 0; j < num_sequences; j++) { - auto mask = - attn_mask_.gen_append_mask(input_params[i].q_seq_lens_vec[j], - input_params[i].kv_seq_lens_vec[j], - max_seq_len_, - cos_pos.dtype().toScalarType(), - cos_pos.device()); - req_mask_vec.emplace_back(mask); - } - attn_mask = torch::cat(req_mask_vec, 0); + ModelInputParams& input_params_new = + const_cast(input_params); + torch::Tensor attn_mask; + if (model_type_ == "qwen2") { + max_seq_len_ = FLAGS_enable_chunked_prefill + ? std::max(input_params.kv_max_seq_len, max_seq_len_) + : 128; + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + } else { + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
std::max(input_params.kv_max_seq_len, max_seq_len_) + : 128; + if (FLAGS_enable_chunked_prefill) { + int num_sequences = input_params.num_sequences; + if (num_sequences > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(num_sequences); + + for (int j = 0; j < num_sequences; j++) { + auto mask = + attn_mask_.gen_append_mask(input_params.q_seq_lens_vec[j], + input_params.kv_seq_lens_vec[j], + max_seq_len_, + cos_pos.dtype().toScalarType(), + cos_pos.device()); + req_mask_vec.emplace_back(mask); } - } else { - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + attn_mask = torch::cat(req_mask_vec, 0); } + } else { + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); } - cos_poss.push_back(std::move(cos_pos)); - sin_poss.push_back(std::move(sin_pos)); - attn_masks.push_back(std::move(attn_mask)); -#endif } +#endif + #if defined(USE_NPU) for (size_t i = 0; i < layers_.size(); i++) { - std::vector events(micro_batch_num, nullptr); - std::vector*> event_flags(micro_batch_num, nullptr); - for (auto j = 0; j < micro_batch_num; ++j) { - if (input_params[j].layer_synchronizer != nullptr) { - events[j] = input_params[j].layer_synchronizer->get_event(i); - event_flags[j] = - input_params[j].layer_synchronizer->get_event_flag(i); - } - if (input_params[j].layer_wise_load_synchronizer != nullptr) { - if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer( - i)) { - return torch::Tensor(); - } + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); } } + auto& layer = layers_[i]; if (layer_forward_interrupted_) { @@ -300,21 +275,21 @@ class LlmModelImplBase : public torch::nn::Module { return torch::Tensor(); } - layer(hs, - cos_poss, - sin_poss, - attn_masks, + layer(h, + cos_pos, + sin_pos, + attn_mask, kv_caches[i], - input_params_news, + input_params_new, i, - events, - event_flags); + event, + event_flag); } - auto cancated_h = torch::cat(hs, 0); - return norm_(cancated_h, 0); + + return norm_(h, 0); #else - auto modified_input_params = input_params[0]; - auto position = positions[0]; + auto modified_input_params = input_params; + auto position = positions; layer::update_dummy_run_input(dp_rank_, position, modified_input_params); bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = @@ -324,7 +299,7 @@ class LlmModelImplBase : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer( - hs[0], position, attn_metadata, kv_caches[i], modified_input_params); + h, position, attn_metadata, kv_caches[i], modified_input_params); } return norm_(h); #endif @@ -332,10 +307,9 @@ class LlmModelImplBase : public torch::nn::Module { // load the weight from the checkpoint virtual void load_state_dict(const StateDict& state_dict) { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i]->load_state_dict( - state_dict.get_dict_with_prefix("embed_tokens.")); - } + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + // call each layer's load_state_dict function for (int i = 0; i < layers_.size(); i++) { 
layers_[i]->load_state_dict( @@ -346,9 +320,8 @@ class LlmModelImplBase : public torch::nn::Module { #if defined(USE_NPU) virtual void verify_loaded_weights(const std::string& prefix) const { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i]->verify_loaded_weights(prefix + "embed_tokens."); - } + embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + for (int i = 0; i < layers_.size(); i++) { layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + "."); @@ -357,9 +330,8 @@ class LlmModelImplBase : public torch::nn::Module { } virtual void merge_loaded_weights() { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i]->merge_loaded_weights(); - } + embed_tokens_->merge_loaded_weights(); + for (int i = 0; i < layers_.size(); i++) { layers_[i]->merge_loaded_weights(); } @@ -367,15 +339,10 @@ class LlmModelImplBase : public torch::nn::Module { } #endif - virtual std::vector get_word_embedding() { - return embed_tokens_; - } + virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - virtual void set_word_embedding( - std::vector& word_embedding) { - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_[i] = word_embedding[i]; - } + virtual void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; } protected: @@ -387,13 +354,13 @@ class LlmModelImplBase : public torch::nn::Module { layer::AttentionMask attn_mask_; int dp_rank_ = 0; #if defined(USE_NPU) - std::vector atb_pos_embeds_; + layer::PosEmbedding atb_pos_emb_{nullptr}; #endif std::vector mrope_section_; // test // ParallelEmbedding embed_tokens_{nullptr}; - std::vector embed_tokens_; + layer::WordEmbedding embed_tokens_{nullptr}; layer::RmsNorm norm_{nullptr}; torch::nn::ModuleList blocks_{nullptr}; @@ -424,11 +391,10 @@ class LlmForCausalLMImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - virtual torch::Tensor forward( - const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& input_params) { + virtual torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { return model_(tokens, positions, kv_caches, input_params); } @@ -482,12 +448,11 @@ class LlmForCausalLMImplBase : public torch::nn::Module { virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - virtual std::vector get_word_embedding() { + virtual layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - virtual void set_word_embedding( - std::vector& word_embedding) { + virtual void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h index b137e9939..af5bfacf4 100644 --- a/xllm/models/llm/qwen2.h +++ b/xllm/models/llm/qwen2.h @@ -47,12 +47,11 @@ class QWen2ModelImpl : public LlmModelImplBase { blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.n_layers()); norm_ = register_module("norm", layer::RmsNorm(context)); - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_.push_back(layer::WordEmbedding(context)); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); #if defined(USE_NPU) - atb_pos_embeds_.push_back(layer::PosEmbedding(context)); + atb_pos_emb_ = 
layer::PosEmbedding(context); #endif - } cos_sin_ = get_concat_rotary_embedding( model_args.hidden_size() / model_args.n_heads(), model_args.max_position_embeddings(), diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h index 2d2e90851..277167ddc 100644 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -43,12 +43,11 @@ class QWen3ModelImpl : public LlmModelImplBase { blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.n_layers()); norm_ = register_module("norm", layer::RmsNorm(context)); - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_.push_back(layer::WordEmbedding(context)); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); #if defined(USE_NPU) - atb_pos_embeds_.push_back(layer::PosEmbedding(context)); + atb_pos_emb_ = layer::PosEmbedding(context); #endif - } cos_sin_ = get_concat_rotary_embedding(128, model_args.max_position_embeddings(), model_args.rope_theta(), @@ -81,154 +80,128 @@ class QWen3ModelImpl : public LlmModelImplBase { return hidden_states; } - virtual torch::Tensor forward( - std::vector tokens, - std::vector positions, - std::vector& kv_caches, - const std::vector& input_params) { - auto micro_batch_num = tokens.size(); - std::vector hs; - hs.reserve(micro_batch_num); - std::vector> deep_stacks; - deep_stacks.reserve(micro_batch_num); - bool use_deepstack = input_params[0].deep_stacks.size() > 0; - std::vector cos_poss; - cos_poss.reserve(micro_batch_num); - std::vector sin_poss; - sin_poss.reserve(micro_batch_num); - std::vector attn_masks; - attn_masks.reserve(micro_batch_num); - std::vector& input_params_news = - const_cast&>(input_params); - - for (auto i = 0; i < micro_batch_num; ++i) { - if (tokens[i].numel() == 0) { - tokens[i] = torch::tensor({1}).to(torch::kInt32).to(tokens[0].device()); - positions[i] = - torch::tensor({0}).to(torch::kInt32).to(tokens[0].device()); - } - auto inputs_embeds = input_params[i].input_embedding; - torch::Tensor h; - if (inputs_embeds.defined()) { - h = inputs_embeds; - } else { + virtual torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + bool use_deepstack = input_params.deep_stacks.size() > 0; + ModelInputParams& input_params_new = + const_cast(input_params); + std::vector deep_stacks; + + if (tokens.numel() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + } + auto inputs_embeds = input_params.input_embedding; + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { #if defined(USE_NPU) - h = embed_tokens_[i](tokens[i], 0); + h = embed_tokens_(tokens, 0); #else - h = embed_tokens_[i](tokens[i]); + h = embed_tokens_(tokens); #endif - } - hs.push_back(std::move(h)); + } #if defined(USE_NPU) - if (use_deepstack) { - deep_stacks.push_back( - input_params[i].deep_stacks); // [num_deepstack, hidden_size] - } - auto target_cos_sin = atb_pos_embeds_[i](cos_sin_, positions[i], 0); - auto target_cos_sin_chunks = - target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = target_cos_sin_chunks[0].contiguous(); - auto sin_pos = target_cos_sin_chunks[1].contiguous(); - - if (positions[i].dim() == 2) { // mrope - auto apply = [this](torch::Tensor x) { - auto freqs_t = x[0].clone(); - for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { - int64_t offset = dim_idx; - int64_t section_len = mrope_section_[dim_idx]; - 
int64_t length = section_len * 3; - auto idx_first_half = - torch::arange(offset, length, 3, torch::kLong); - auto idx_second_half = - torch::arange(offset, length, 3, torch::kLong); - auto idx_tensor = - torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); - // freqs_t[..., idx] = freqs[dim_idx][..., idx] - auto src = x[dim_idx].index_select(-1, idx_tensor); - freqs_t.index_copy_(-1, idx_tensor, src); - } - return freqs_t; - }; - cos_pos = apply(cos_pos.reshape( - {positions[i].sizes().front(), -1, cos_pos.sizes().back()})); - sin_pos = apply(sin_pos.reshape( - {positions[i].sizes().front(), -1, sin_pos.sizes().back()})); - } + if (use_deepstack) { + deep_stacks = input_params.deep_stacks; // [num_deepstack, hidden_size] + } + auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + auto sin_pos = target_cos_sin_chunks[1].contiguous(); - torch::Tensor attn_mask; - - torch::Tensor max_of_seq = torch::max(input_params[i].kv_seq_lens); - max_seq_len_ = FLAGS_enable_chunked_prefill - ? std::max(max_of_seq.item(), max_seq_len_) - : 128; - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - - if (FLAGS_enable_chunked_prefill) { - int batch_size = input_params[i].q_seq_lens_vec.size(); - if (batch_size > 0) { - std::vector req_mask_vec; - req_mask_vec.reserve(batch_size); - - for (int j = 0; j < batch_size; j++) { - int start = input_params[i].kv_seq_lens_vec[j] - - input_params[i].q_seq_lens_vec[j]; - int end = input_params[i].kv_seq_lens_vec[j]; - - auto req_mask_slice = attn_mask.slice(0, start, end); - req_mask_vec.emplace_back(req_mask_slice); - } - attn_mask = torch::cat(req_mask_vec, 0); + if (positions.dim() == 2) { // mrope + auto apply = [this](torch::Tensor x) { + auto freqs_t = x[0].clone(); + for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { + int64_t offset = dim_idx; + int64_t section_len = mrope_section_[dim_idx]; + int64_t length = section_len * 3; + auto idx_first_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_second_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_tensor = + torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); + // freqs_t[..., idx] = freqs[dim_idx][..., idx] + auto src = x[dim_idx].index_select(-1, idx_tensor); + freqs_t.index_copy_(-1, idx_tensor, src); } + return freqs_t; + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } + + torch::Tensor attn_mask; + + torch::Tensor max_of_seq = torch::max(input_params.kv_seq_lens); + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
std::max(max_of_seq.item(), max_seq_len_) + : 128; + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + + if (FLAGS_enable_chunked_prefill) { + int batch_size = input_params.q_seq_lens_vec.size(); + if (batch_size > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(batch_size); + + for (int j = 0; j < batch_size; j++) { + int start = + input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j]; + int end = input_params.kv_seq_lens_vec[j]; + + auto req_mask_slice = attn_mask.slice(0, start, end); + req_mask_vec.emplace_back(req_mask_slice); + } + attn_mask = torch::cat(req_mask_vec, 0); } + } - cos_poss.push_back(std::move(cos_pos)); - sin_poss.push_back(std::move(sin_pos)); - attn_masks.push_back(std::move(attn_mask)); #endif - } + #if defined(USE_NPU) for (size_t i = 0; i < layers_.size(); i++) { - std::vector events(micro_batch_num, nullptr); - std::vector*> event_flags(micro_batch_num, nullptr); - for (auto j = 0; j < micro_batch_num; ++j) { - if (input_params[j].layer_synchronizer != nullptr) { - events[j] = input_params[j].layer_synchronizer->get_event(i); - event_flags[j] = - input_params[j].layer_synchronizer->get_event_flag(i); - } - if (input_params[j].layer_wise_load_synchronizer != nullptr) { - if (!input_params[j].layer_wise_load_synchronizer->synchronize_layer( - i)) { - return torch::Tensor(); - } + aclrtEvent* event{nullptr}; + std::atomic* event_flag{nullptr}; + + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); } } + auto& layer = layers_[i]; - layer(hs, - cos_poss, - sin_poss, - attn_masks, + layer(h, + cos_pos, + sin_pos, + attn_mask, kv_caches[i], - input_params_news, + input_params_new, i, - events, - event_flags); + event, + event_flag); if (use_deepstack) { - for (auto j = 0; j < micro_batch_num; ++j) { - if (deep_stacks[j].size() > 0 && i < deep_stacks[j].size()) { - hs[j] = deepstack_process( - hs[j], input_params[j].visual_pos_masks, deep_stacks[j][i]); - } + if (deep_stacks.size() > 0 && i < deep_stacks.size()) { + h = deepstack_process( + h, input_params.visual_pos_masks, deep_stacks[i]); } } } - auto cancated_h = torch::cat(hs, 0); - return norm_(cancated_h, 0); + return norm_(h, 0); #else - auto modified_input_params = input_params[0]; - auto position = positions[0]; + auto modified_input_params = input_params; + auto position = positions; layer::update_dummy_run_input(dp_rank_, position, modified_input_params); bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = @@ -238,7 +211,7 @@ class QWen3ModelImpl : public LlmModelImplBase { for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer( - hs[0], position, attn_metadata, kv_caches[i], modified_input_params); + h, positions, attn_metadata, kv_caches[i], modified_input_params); } return norm_(h); #endif diff --git a/xllm/models/llm/qwen3_embedding.h b/xllm/models/llm/qwen3_embedding.h index 9315590d5..57e1604a6 100644 --- a/xllm/models/llm/qwen3_embedding.h +++ b/xllm/models/llm/qwen3_embedding.h @@ -40,11 +40,10 @@ class EmbeddingLMImpl : public EmbeddingLM { const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - torch::Tensor forward( - const std::vector& tokens, - const 
std::vector& positions, - std::vector& kv_caches, - const std::vector& parameters) override { + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& parameters) override { return model_->forward(tokens, positions, kv_caches, parameters); } @@ -77,11 +76,10 @@ class EmbeddingLMImpl : public EmbeddingLM { // Delegate head/embedding accessors to underlying model implementation. layer::LmHead get_lm_head() override { return model_->get_lm_head(); } void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); } - std::vector get_word_embedding() override { + layer::WordEmbedding get_word_embedding() override { return model_->get_word_embedding(); } - void set_word_embedding( - std::vector& embedding) override { + void set_word_embedding(layer::WordEmbedding& embedding) override { model_->set_word_embedding(embedding); } diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 62de1214d..dec5d0159 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -331,12 +331,10 @@ class Qwen3MoeModelImpl : public torch::nn::Module { } #endif - std::vector get_word_embedding() { - return {embed_tokens_}; - } + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - void set_word_embedding(std::vector& word_embedding) { - embed_tokens_ = word_embedding[0]; + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { #if defined(USE_NPU) @@ -382,11 +380,11 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { - return model_(tokens[0], positions[0], kv_caches, input_params[0]); + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); } // hidden_states: [num_tokens, hidden_size] @@ -437,11 +435,11 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/vlm/minicpmv.h b/xllm/models/vlm/minicpmv.h index 2dd3a60ac..89920461b 100644 --- a/xllm/models/vlm/minicpmv.h +++ b/xllm/models/vlm/minicpmv.h @@ -975,12 +975,12 @@ class MiniCPMV2_6Impl : public torch::nn::Module { } } - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { + const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - const auto& mm_data = input_params[0].mm_data; + const auto& mm_data = input_params.mm_data; torch::Tensor image_embeds; if (const auto& res = mm_data.get("image_embeds")) @@ -995,11 +995,10 @@ class MiniCPMV2_6Impl : public torch::nn::Module { torch::Tensor image_embedding; std::optional image_inputs; if 
(image_embeds.defined()) { - image_inputs = - generate_image_inputs({}, tokens[0], image_embeds, tgt_sizes); + image_inputs = generate_image_inputs({}, tokens, image_embeds, tgt_sizes); } else if (pixel_values.size() > 0) { image_inputs = generate_image_inputs( - pixel_values, tokens[0], torch::Tensor(), tgt_sizes); + pixel_values, tokens, torch::Tensor(), tgt_sizes); } image_embedding = get_vision_embedding(image_inputs); @@ -1011,8 +1010,8 @@ class MiniCPMV2_6Impl : public torch::nn::Module { image_embedding = mlp_(image_embedding); } - input_params[0].input_embedding = - merge_text_vision_embeddings(tokens[0], image_inputs, image_embedding); + input_params.input_embedding = + merge_text_vision_embeddings(tokens, image_inputs, image_embedding); return language_model_(tokens, positions, kv_caches, input_params); } @@ -1254,11 +1253,11 @@ class MiniCPMV2_6Impl : public torch::nn::Module { void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return language_model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { language_model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/vlm/qwen2_5_vl.h b/xllm/models/vlm/qwen2_5_vl.h old mode 100755 new mode 100644 index fc373065c..ec6e6aa4a --- a/xllm/models/vlm/qwen2_5_vl.h +++ b/xllm/models/vlm/qwen2_5_vl.h @@ -700,12 +700,12 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { + const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - const auto& mm_data = input_params[0].mm_data; + const auto& mm_data = input_params.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) @@ -720,9 +720,9 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { if (pixel_values.defined() && image_grid_thw.defined()) image_inputs = Qwen2_5_VLImageInputs{pixel_values, image_grid_thw}; - auto inputs_embeds = get_input_embeddings( - tokens[0], image_inputs, video_inputs, input_params[0]); - input_params[0].input_embedding = inputs_embeds; + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; auto emb = language_model_(tokens, positions, kv_caches, input_params); @@ -749,11 +749,11 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return language_model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { language_model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h old mode 100755 new mode 100644 index dae43c2d4..de63fb427 --- a/xllm/models/vlm/qwen3_vl.h +++ b/xllm/models/vlm/qwen3_vl.h @@ -669,12 +669,12 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - 
torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { + const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - const auto& mm_data = input_params[0].mm_data; + const auto& mm_data = input_params.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -688,9 +688,9 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { if (pixel_values.defined() && image_grid_thw.defined()) image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; - auto inputs_embeds = get_input_embeddings( - tokens[0], image_inputs, video_inputs, input_params[0]); - input_params[0].input_embedding = inputs_embeds; + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; auto emb = language_model_(tokens, positions, kv_caches, input_params); return emb; @@ -717,11 +717,11 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return language_model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { language_model_->set_word_embedding(word_embedding); } diff --git a/xllm/models/vlm/qwen3_vl_moe.h b/xllm/models/vlm/qwen3_vl_moe.h index 6a94835a5..081711fa5 100644 --- a/xllm/models/vlm/qwen3_vl_moe.h +++ b/xllm/models/vlm/qwen3_vl_moe.h @@ -74,12 +74,12 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& input_params) { + const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - const auto& mm_data = input_params[0].mm_data; + const auto& mm_data = input_params.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -93,9 +93,9 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { if (pixel_values.defined() && image_grid_thw.defined()) image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; - auto inputs_embeds = get_input_embeddings( - tokens[0], image_inputs, video_inputs, input_params[0]); - input_params[0].input_embedding = inputs_embeds; + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; auto emb = language_model_(tokens, positions, kv_caches, input_params); return emb; @@ -122,11 +122,11 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return language_model_->get_word_embedding(); } - void set_word_embedding(std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) 
{ language_model_->set_word_embedding(word_embedding); } From d451bf5adfb0e351ea9e3ad3605319a969fa12b6 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Mon, 17 Nov 2025 20:57:18 +0800 Subject: [PATCH 2/2] feat: revert the original code before refactoring the multi-stream[2/2]. Signed-off-by: Tao Peng --- .../core/distributed_runtime/comm_channel.cpp | 20 +- xllm/core/distributed_runtime/comm_channel.h | 2 +- .../distributed_runtime/remote_worker.cpp | 11 +- xllm/core/distributed_runtime/remote_worker.h | 2 +- .../distributed_runtime/worker_service.cpp | 174 ++---- .../core/distributed_runtime/worker_service.h | 4 +- xllm/core/runtime/acl_graph_executor_impl.cpp | 26 +- xllm/core/runtime/acl_graph_executor_impl.h | 8 +- xllm/core/runtime/base_executor_impl.cpp | 11 +- xllm/core/runtime/base_executor_impl.h | 6 +- xllm/core/runtime/embed_vlm_worker_impl.cpp | 13 +- xllm/core/runtime/embed_vlm_worker_impl.h | 3 +- xllm/core/runtime/embed_worker_impl.cpp | 16 +- xllm/core/runtime/embed_worker_impl.h | 3 +- xllm/core/runtime/executor.cpp | 6 +- xllm/core/runtime/executor.h | 6 +- xllm/core/runtime/executor_impl.h | 6 +- xllm/core/runtime/llm_engine.cpp | 108 ++-- xllm/core/runtime/llm_engine.h | 3 +- xllm/core/runtime/llm_worker_impl.cpp | 91 ++- xllm/core/runtime/llm_worker_impl.h | 3 +- xllm/core/runtime/master.cpp | 4 + xllm/core/runtime/speculative_worker_impl.cpp | 584 ++++++++---------- xllm/core/runtime/speculative_worker_impl.h | 25 +- xllm/core/runtime/vlm_engine.cpp | 67 +- xllm/core/runtime/vlm_engine.h | 3 +- xllm/core/runtime/vlm_worker_impl.cpp | 12 +- xllm/core/runtime/vlm_worker_impl.h | 3 +- xllm/core/runtime/worker.cpp | 7 +- xllm/core/runtime/worker.h | 2 +- xllm/core/runtime/worker_client.cpp | 9 +- xllm/core/runtime/worker_client.h | 2 +- xllm/core/runtime/worker_impl.cpp | 150 +++-- xllm/core/runtime/worker_impl.h | 10 +- xllm/models/llm/llm_model_base.h | 6 +- xllm/models/llm/mlu/deepseek_v2.h | 21 +- xllm/models/llm/qwen3.h | 6 +- xllm/proto/worker.proto | 2 +- 38 files changed, 598 insertions(+), 837 deletions(-) mode change 100755 => 100644 xllm/core/distributed_runtime/worker_service.cpp mode change 100755 => 100644 xllm/core/runtime/llm_engine.cpp mode change 100755 => 100644 xllm/core/runtime/master.cpp mode change 100755 => 100644 xllm/core/runtime/vlm_engine.h diff --git a/xllm/core/distributed_runtime/comm_channel.cpp b/xllm/core/distributed_runtime/comm_channel.cpp index 08d1c3e00..c538ecd62 100644 --- a/xllm/core/distributed_runtime/comm_channel.cpp +++ b/xllm/core/distributed_runtime/comm_channel.cpp @@ -507,22 +507,14 @@ bool CommChannel::get_active_activation_memory_async( bool CommChannel::execute_model_with_brpc( const std::vector& inputs, folly::Promise>& promise) { - // convert to proto::BatchedForwardInputs - proto::BatchedForwardInputs pb_batched_fwd_inputs; - std::vector batched_fwd_inputs_vec; - batched_fwd_inputs_vec.reserve(inputs.size()); - for (auto i = 0; i < inputs.size(); ++i) { - proto::ForwardInput pb_fwd_input; - forward_input_to_proto(inputs[i], &pb_fwd_input); - batched_fwd_inputs_vec.push_back(std::move(pb_fwd_input)); - } - ADD_VECTOR_TO_PROTO(pb_batched_fwd_inputs.mutable_micro_inputs(), - batched_fwd_inputs_vec); + // convert to proto::ForwardInput + proto::ForwardInput pb_forward_input; + forward_input_to_proto(inputs[0], &pb_forward_input); + // call ExecuteModel with callback auto done = new ExecuteModelClosure(); done->promise = std::move(promise); - stub_->ExecuteModel( - &done->cntl, &pb_batched_fwd_inputs, &done->pb_output, done); 
+ stub_->ExecuteModel(&done->cntl, &pb_forward_input, &done->pb_output, done); return true; } @@ -567,4 +559,4 @@ void TransferBlocksClosure::Run() { return; } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/distributed_runtime/comm_channel.h b/xllm/core/distributed_runtime/comm_channel.h index 975f4bda2..e3d0824c9 100644 --- a/xllm/core/distributed_runtime/comm_channel.h +++ b/xllm/core/distributed_runtime/comm_channel.h @@ -145,4 +145,4 @@ class TransferBlocksClosure : public google::protobuf::Closure { brpc::Controller cntl; folly::Promise promise; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/distributed_runtime/remote_worker.cpp b/xllm/core/distributed_runtime/remote_worker.cpp index fc244cf89..48c36ca28 100644 --- a/xllm/core/distributed_runtime/remote_worker.cpp +++ b/xllm/core/distributed_runtime/remote_worker.cpp @@ -167,13 +167,14 @@ folly::SemiFuture> RemoteWorker::step_async( } folly::SemiFuture> RemoteWorker::step_async( - const std::vector& inputs) { + const RawForwardInput& inputs) { folly::Promise> promise; auto future = promise.getSemiFuture(); - threadpool_.schedule( - [this, inputs = inputs, promise = std::move(promise)]() mutable { - channel_->execute_model_async(inputs, promise); - }); + threadpool_.schedule([this, + inputs = std::move(inputs), + promise = std::move(promise)]() mutable { + channel_->execute_model_async({inputs}, promise); + }); return future; } diff --git a/xllm/core/distributed_runtime/remote_worker.h b/xllm/core/distributed_runtime/remote_worker.h index 478f9e73d..65fa3eef4 100644 --- a/xllm/core/distributed_runtime/remote_worker.h +++ b/xllm/core/distributed_runtime/remote_worker.h @@ -127,7 +127,7 @@ class RemoteWorker : public WorkerClient { const ForwardInput& inputs) override; virtual folly::SemiFuture> step_async( - const std::vector& inputs) override; + const RawForwardInput& inputs) override; virtual folly::SemiFuture process_group_test_async() override; diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp old mode 100755 new mode 100644 index 171cb07e4..ca8079b4a --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -66,7 +66,7 @@ void WorkerService::set_worker(std::unique_ptr worker) { initialized_ = true; } -void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs, +void WorkerService::step(ForwardInput& fwd_input, torch::Tensor& next_tokens, torch::Tensor& logprobs, torch::Tensor& top_tokens, @@ -78,7 +78,7 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs, torch::Tensor& out_tokens, torch::Tensor& out_logprobs) { // execute model - auto future = worker_->step_async(batched_fwd_inputs); + auto future = worker_->step_async(fwd_input); if (!options_.enable_schedule_overlap()) { auto forward_outputs = std::move(future).get(); @@ -142,10 +142,10 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); auto total_prefill_seq_len = 0; auto total_num_sequences = 0; - for (auto& input : batched_fwd_inputs.micro_inputs) { - total_num_sequences += input.input_params.num_sequences; - total_prefill_seq_len += input.input_params.prefill_seq_len; - } + + total_num_sequences += fwd_input.input_params.num_sequences; + total_prefill_seq_len += fwd_input.input_params.prefill_seq_len; + next_tokens = torch::arange(-1, -1 * (total_num_sequences 
- total_prefill_seq_len + 1), @@ -166,7 +166,7 @@ void WorkerService::create_polling_shm_thread( output_shm_manager = std::move(output_shm_manager)]() mutable { Timer timer; while (true) { - BatchedForwardInputs batched_fwd_inputs; + ForwardInput fwd_input; std::vector inputs; input_shm_manager->raw_input_read(inputs); timer.reset(); @@ -184,31 +184,9 @@ void WorkerService::create_polling_shm_thread( torch::Tensor out_tokens; torch::Tensor out_logprobs; - auto micro_batches_num = inputs.size(); - batched_fwd_inputs.micro_inputs = std::move(inputs); - batched_fwd_inputs.concated_sampling_params = - batched_fwd_inputs.micro_inputs[0].sampling_params; - for (auto i = 1; i < micro_batches_num; ++i) { - batched_fwd_inputs.concated_sampling_params.concat( - batched_fwd_inputs.micro_inputs[i].sampling_params); - } - - // concat acc_logprob here for beam search together - if (micro_batches_num > 1) { - std::vector acc_logprob_vec; - acc_logprob_vec.reserve(micro_batches_num); - for (auto i = 0; i < micro_batches_num; ++i) { - acc_logprob_vec.push_back( - batched_fwd_inputs.micro_inputs[i].acc_logprob); - } - batched_fwd_inputs.acc_logprob = - torch::cat(acc_logprob_vec, /*dim=*/-1); - } else { - batched_fwd_inputs.acc_logprob = - batched_fwd_inputs.micro_inputs[0].acc_logprob; - } + fwd_input = std::move(inputs[0]); - step(batched_fwd_inputs, + step(fwd_input, next_tokens, logprobs, top_tokens, @@ -598,90 +576,58 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller, return; } -void WorkerService::ExecuteModel( - ::google::protobuf::RpcController* controller, - const proto::BatchedForwardInputs* pb_batched_fwd_inputs, - proto::ForwardOutput* pb_forward_output, - ::google::protobuf::Closure* done) { - threadpool_->schedule([this, - controller, - pb_batched_fwd_inputs, - pb_forward_output, - done]() mutable { - brpc::ClosureGuard done_guard(done); - Timer timer; - // convert proto::BatchedForwardInputs to BatchedForwardInputs - auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs().size(); - BatchedForwardInputs batched_fwd_inputs; - batched_fwd_inputs.micro_inputs.reserve(micro_batches_num); - for (auto i = 0; i < micro_batches_num; ++i) { - ForwardInput forward_input; - proto_to_forward_input(&(pb_batched_fwd_inputs->micro_inputs()[i]), - forward_input, - options_.num_decoding_tokens()); - batched_fwd_inputs.micro_inputs.push_back(std::move(forward_input)); - } - - // concat sampling parameters - batched_fwd_inputs.concated_sampling_params = - batched_fwd_inputs.micro_inputs[0].sampling_params; - for (auto i = 1; i < micro_batches_num; ++i) { - batched_fwd_inputs.concated_sampling_params.concat( - batched_fwd_inputs.micro_inputs[i].sampling_params); - } - - // concat acc_logprob here for beam search together - if (micro_batches_num > 1) { - std::vector acc_logprob_vec; - acc_logprob_vec.reserve(micro_batches_num); - for (auto i = 0; i < micro_batches_num; ++i) { - acc_logprob_vec.push_back( - batched_fwd_inputs.micro_inputs[i].acc_logprob); - } - batched_fwd_inputs.acc_logprob = torch::cat(acc_logprob_vec, /*dim=*/-1); - } else { - batched_fwd_inputs.acc_logprob = - batched_fwd_inputs.micro_inputs[0].acc_logprob; - } +void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, + const proto::ForwardInput* pb_forward_input, + proto::ForwardOutput* pb_forward_output, + ::google::protobuf::Closure* done) { + threadpool_->schedule( + [this, controller, pb_forward_input, pb_forward_output, done]() mutable { + brpc::ClosureGuard done_guard(done); 
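        // Handler flow (descriptive note added for readability, matching the
        // code below): deserialize the single proto::ForwardInput with
        // proto_to_forward_input(), run step() to fill the sampled tokens,
        // logprobs, embeddings and beam-search tensors, then serialize them
        // back with forward_output_to_proto(); done_guard sends the response
        // when this lambda returns.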
+ // convert proto::ForwardInput to ForwardInput - // model output - torch::Tensor next_tokens; - torch::Tensor logprobs; - torch::Tensor top_tokens; - torch::Tensor top_logprobs; - torch::Tensor embeddings; - torch::Tensor expert_load_data; - int32_t prepared_layer_id = -1; - // beam search kernel output - torch::Tensor src_seq_idxes; - torch::Tensor out_tokens; - torch::Tensor out_logprobs; - - step(batched_fwd_inputs, - next_tokens, - logprobs, - top_tokens, - top_logprobs, - embeddings, - expert_load_data, - prepared_layer_id, - src_seq_idxes, - out_tokens, - out_logprobs); - // convert to proto output - forward_output_to_proto(next_tokens, - logprobs, - top_tokens, - top_logprobs, - embeddings, - expert_load_data, - prepared_layer_id, - src_seq_idxes, - out_tokens, - out_logprobs, - pb_forward_output); - COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds()); - }); + Timer timer; + ForwardInput forward_input; + proto_to_forward_input( + pb_forward_input, forward_input, options_.num_decoding_tokens()); + + // model output + torch::Tensor next_tokens; + torch::Tensor logprobs; + torch::Tensor top_tokens; + torch::Tensor top_logprobs; + torch::Tensor embeddings; + torch::Tensor expert_load_data; + int32_t prepared_layer_id = -1; + // beam search kernel output + torch::Tensor src_seq_idxes; + torch::Tensor out_tokens; + torch::Tensor out_logprobs; + + step(forward_input, + next_tokens, + logprobs, + top_tokens, + top_logprobs, + embeddings, + expert_load_data, + prepared_layer_id, + src_seq_idxes, + out_tokens, + out_logprobs); + // convert to proto output + forward_output_to_proto(next_tokens, + logprobs, + top_tokens, + top_logprobs, + embeddings, + expert_load_data, + prepared_layer_id, + src_seq_idxes, + out_tokens, + out_logprobs, + pb_forward_output); + COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds()); + }); } void WorkerService::GetLastStepResult( diff --git a/xllm/core/distributed_runtime/worker_service.h b/xllm/core/distributed_runtime/worker_service.h index 044960bfd..7c6e4a7a8 100644 --- a/xllm/core/distributed_runtime/worker_service.h +++ b/xllm/core/distributed_runtime/worker_service.h @@ -111,7 +111,7 @@ class WorkerService : public proto::DistributeWorker { ::google::protobuf::Closure* done) override; void ExecuteModel(::google::protobuf::RpcController* controller, - const proto::BatchedForwardInputs* pb_batched_fwd_inputs, + const proto::ForwardInput* pb_fwd_input, proto::ForwardOutput* pb_forward_output, ::google::protobuf::Closure* done) override; @@ -126,7 +126,7 @@ class WorkerService : public proto::DistributeWorker { ::google::protobuf::Closure* done) override; private: - void step(BatchedForwardInputs& batched_fwd_inputs, + void step(ForwardInput& fwd_input, torch::Tensor& next_tokens, torch::Tensor& logprobs, torch::Tensor& top_tokens, diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 1100a9bc2..ea7656f2b 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -187,15 +187,14 @@ ForwardInput AclGraphExecutorImpl::prepare_inputs(Batch& batch) { // tokens: [num_decode_tokens] // positions: [num_decode_tokens] token pos in the sequence // returns: [num_decode_tokens, hidden_size] -torch::Tensor AclGraphExecutorImpl::run( - const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& params) { +torch::Tensor AclGraphExecutorImpl::run(const torch::Tensor& tokens, + const 
torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& params) { // no mirco batch in decode phase - const torch::Tensor& tokens_tensor = tokens[0]; - const torch::Tensor& positions_tensor = positions[0]; - const ModelInputParams& params_single = params[0]; + const torch::Tensor& tokens_tensor = tokens; + const torch::Tensor& positions_tensor = positions; + const ModelInputParams& params_single = params; // Identify decode phase using q_max_seq_len for precise detection // Decode phase: all sequences have q_seq_len == 1 (generating one token at a // time) Prefill phase: sequences have q_seq_len > 1 (processing multiple @@ -207,7 +206,7 @@ torch::Tensor AclGraphExecutorImpl::run( // If not in decode phase, use eager mode directly without acl graph if (!in_decoding_phase) { COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens[0], positions[0], kv_caches, params[0]); + return model_->forward(tokens, positions, kv_caches, params); } // Only use acl graph in decode phase for performance optimization @@ -229,15 +228,12 @@ torch::Tensor AclGraphExecutorImpl::run( // Combined condition for graph capture support // ACL graph executor only supports single tensor inputs (no micro-batching) - const bool single_input = - (tokens.size() == 1) && (positions.size() == 1) && (params.size() == 1); - const bool capture_supported = - single_input && seq_len_supported && same_num_decoding_tokens; + const bool capture_supported = seq_len_supported && same_num_decoding_tokens; // Early return if conditions are not suitable for graph operations if (!capture_supported) { COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens[0], positions[0], kv_caches, params[0]); + return model_->forward(tokens, positions, kv_caches, params); } // Check if captured graph exists for this bucket size @@ -273,7 +269,7 @@ torch::Tensor AclGraphExecutorImpl::run( // Fallback to eager mode if capture fails LOG(ERROR) << "Failed to capture ACL graph for bucket size: " << bucket_size; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens[0], positions[0], kv_caches, params[0]); + return model_->forward(tokens, positions, kv_caches, params); } void AclGraph::copy_data_to_graph_buffer(const torch::Tensor& tokens, diff --git a/xllm/core/runtime/acl_graph_executor_impl.h b/xllm/core/runtime/acl_graph_executor_impl.h index 660a8b716..3c5b827a3 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.h +++ b/xllm/core/runtime/acl_graph_executor_impl.h @@ -101,10 +101,10 @@ class AclGraphExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; // Execute model with graph optimization for decode phase - torch::Tensor run(const std::vector& tokens, - const std::vector& positions, + torch::Tensor run(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& params) override; + const ModelInputParams& params) override; private: // not own @@ -123,4 +123,4 @@ class AclGraphExecutorImpl : public ExecutorImpl { uint32_t get_bucket_size(uint32_t batch_size) const; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/runtime/base_executor_impl.cpp b/xllm/core/runtime/base_executor_impl.cpp index 4216ed7bf..00f1902bf 100644 --- a/xllm/core/runtime/base_executor_impl.cpp +++ b/xllm/core/runtime/base_executor_impl.cpp @@ -31,12 +31,11 @@ ForwardInput BaseExecutorImpl::prepare_inputs(Batch& batch) { return 
batch.prepare_forward_input(options_.num_decoding_tokens(), 0, args_); } -torch::Tensor BaseExecutorImpl::run( - const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& params) { - return model_->forward(tokens[0], positions[0], kv_caches, params[0]); +torch::Tensor BaseExecutorImpl::run(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& params) { + return model_->forward(tokens, positions, kv_caches, params); } } // namespace xllm diff --git a/xllm/core/runtime/base_executor_impl.h b/xllm/core/runtime/base_executor_impl.h index bede1a974..3c5b04508 100644 --- a/xllm/core/runtime/base_executor_impl.h +++ b/xllm/core/runtime/base_executor_impl.h @@ -40,10 +40,10 @@ class BaseExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; - torch::Tensor run(const std::vector& tokens, - const std::vector& positions, + torch::Tensor run(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& params) override; + const ModelInputParams& params) override; private: // not own diff --git a/xllm/core/runtime/embed_vlm_worker_impl.cpp b/xllm/core/runtime/embed_vlm_worker_impl.cpp index bf39fc9e4..c70891921 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.cpp +++ b/xllm/core/runtime/embed_vlm_worker_impl.cpp @@ -53,7 +53,7 @@ bool EmbedVLMWorkerImpl::init_model(ModelContext& context) { } std::optional EmbedVLMWorkerImpl::step( - const BatchedForwardInputs& inputs) { + const ForwardInput& input) { torch::DeviceGuard device_guard(device_); auto ret = device_.synchronize_default_stream(); @@ -61,15 +61,14 @@ std::optional EmbedVLMWorkerImpl::step( // TODO to adapt multi stream parallel later, just use [0] temporarily // all tensors should be on the same device as model - auto flatten_tokens = inputs.micro_inputs[0].token_ids.to(device_); - auto flatten_positions = inputs.micro_inputs[0].positions.to(device_); - auto params = inputs.micro_inputs[0].input_params.to(device_); - auto sampling_params = - inputs.micro_inputs[0].sampling_params.to(device_, dtype_); + auto flatten_tokens = input.token_ids.to(device_); + auto flatten_positions = input.positions.to(device_); + auto params = input.input_params.to(device_); + auto sampling_params = input.sampling_params.to(device_, dtype_); // call model executor forward to get hidden states auto hidden_states = model_executor_->forward( - {flatten_tokens}, {flatten_positions}, kv_caches_, {params}); + flatten_tokens, flatten_positions, kv_caches_, params); ret = device_.synchronize_default_stream(); COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); diff --git a/xllm/core/runtime/embed_vlm_worker_impl.h b/xllm/core/runtime/embed_vlm_worker_impl.h index 059db25aa..79d100bc6 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.h +++ b/xllm/core/runtime/embed_vlm_worker_impl.h @@ -40,8 +40,7 @@ class EmbedVLMWorkerImpl : public WorkerImpl { bool init_model(ModelContext& context) override; - std::optional step( - const BatchedForwardInputs& inputs) override; + std::optional step(const ForwardInput& input) override; }; } // namespace xllm diff --git a/xllm/core/runtime/embed_worker_impl.cpp b/xllm/core/runtime/embed_worker_impl.cpp index 4a6eaa7cc..92be2cc88 100644 --- a/xllm/core/runtime/embed_worker_impl.cpp +++ b/xllm/core/runtime/embed_worker_impl.cpp @@ -53,23 +53,21 @@ bool EmbedWorkerImpl::init_model(ModelContext& context) { return true; } 
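// Illustrative sketch only (not part of this patch): the single-input call
// pattern the reverted workers follow. `Executor`, `ForwardInput`, `KVCache`
// and `ModelInputParams` are the types touched above; the free-standing
// function name and its parameter list are assumptions for the example.
static torch::Tensor run_single_step(Executor& executor,
                                     const ForwardInput& input,
                                     std::vector<KVCache>& kv_caches,
                                     const torch::Device& device) {
  // one ForwardInput per step, no micro-batch vectors
  auto tokens = input.token_ids.to(device);      // [num_tokens]
  auto positions = input.positions.to(device);   // [num_tokens]
  auto params = input.input_params.to(device);
  // the executor forward now takes plain tensors and a single ModelInputParams
  return executor.forward(tokens, positions, kv_caches, params);
}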
-std::optional EmbedWorkerImpl::step( - const BatchedForwardInputs& inputs) { +std::optional EmbedWorkerImpl::step(const ForwardInput& input) { torch::DeviceGuard device_guard(device_); Timer timer; // TODO to adapt multi stream parallel later, just use [0] temporarily // all tensors should be on the same device as model - auto flatten_tokens = inputs.micro_inputs[0].token_ids.to(device_); - auto flatten_positions = inputs.micro_inputs[0].positions.to(device_); - auto params = inputs.micro_inputs[0].input_params.to(device_); - auto sampling_params = - inputs.micro_inputs[0].sampling_params.to(device_, dtype_); + auto flatten_tokens = input.token_ids.to(device_); + auto flatten_positions = input.positions.to(device_); + auto params = input.input_params.to(device_); + auto sampling_params = input.sampling_params.to(device_, dtype_); // call model executor forward to get hidden states auto hidden_states = model_executor_->forward( - {flatten_tokens}, {flatten_positions}, kv_caches_, {params}); + flatten_tokens, flatten_positions, kv_caches_, params); COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); @@ -81,7 +79,7 @@ std::optional EmbedWorkerImpl::step( ForwardOutput output; SampleOutput sample_output; if (sampling_params.selected_token_idxes.defined() && - inputs.micro_inputs[0].sampling_params.is_embeddings) { + input.sampling_params.is_embeddings) { // create embeddings timer.reset(); // cast model_ from Causal model to Embedding model diff --git a/xllm/core/runtime/embed_worker_impl.h b/xllm/core/runtime/embed_worker_impl.h index 6542d84e4..efbb010f4 100644 --- a/xllm/core/runtime/embed_worker_impl.h +++ b/xllm/core/runtime/embed_worker_impl.h @@ -42,8 +42,7 @@ class EmbedWorkerImpl : public WorkerImpl { // initialize model, cache manager. 
blocking call bool init_model(ModelContext& context) override; - std::optional step( - const BatchedForwardInputs& inputs) override; + std::optional step(const ForwardInput& input) override; }; } // namespace xllm diff --git a/xllm/core/runtime/executor.cpp b/xllm/core/runtime/executor.cpp index bc9875145..2f4e652b9 100644 --- a/xllm/core/runtime/executor.cpp +++ b/xllm/core/runtime/executor.cpp @@ -47,10 +47,10 @@ ForwardInput Executor::prepare_inputs(Batch& batch) { return impl_->prepare_inputs(batch); } -torch::Tensor Executor::forward(const std::vector& tokens, - const std::vector& positions, +torch::Tensor Executor::forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& params) { + const ModelInputParams& params) { COUNTER_INC(num_model_execution_total_eager); return impl_->run(tokens, positions, kv_caches, params); } diff --git a/xllm/core/runtime/executor.h b/xllm/core/runtime/executor.h index 90859344a..77a99b42e 100644 --- a/xllm/core/runtime/executor.h +++ b/xllm/core/runtime/executor.h @@ -44,10 +44,10 @@ class Executor final { // tokens: vector size is dp_size, each element is [num_tokens/dp_size] // positions: vector size is dp_size, each element is [num_tokens/dp_size] // token pos in the sequence returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& params); + const ModelInputParams& params); private: std::unique_ptr impl_; diff --git a/xllm/core/runtime/executor_impl.h b/xllm/core/runtime/executor_impl.h index 80347844b..4576d5bd5 100644 --- a/xllm/core/runtime/executor_impl.h +++ b/xllm/core/runtime/executor_impl.h @@ -38,10 +38,10 @@ class ExecutorImpl { // tokens: vector size is dp_size, each element is [num_tokens/dp_size] // positions: vector size is dp_size, each element is [num_tokens/dp_size] // token pos in the sequence returns: [num_tokens, hidden_size] - virtual torch::Tensor run(const std::vector& tokens, - const std::vector& positions, + virtual torch::Tensor run(const torch::Tensor& tokens, + const torch::Tensor& positions, std::vector& kv_caches, - const std::vector& params) = 0; + const ModelInputParams& params) = 0; }; } // namespace xllm diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp old mode 100755 new mode 100644 index 31c54e1fc..f64b47e03 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -43,15 +43,20 @@ limitations under the License. namespace xllm { namespace { -uint32_t determine_micro_batches_num(const std::vector& batch) { - bool not_all_in_decode = - std::any_of(batch.begin(), batch.end(), [](const Batch& one_batch) { - return one_batch.get_batch_prefill_status(); - }); - if (not_all_in_decode && FLAGS_enable_multi_stream_parallel) { - return 2; - } - return 1; +void try_to_enable_mla(const std::vector& raw_forward_inputs) { + static bool set_enable_mla = FLAGS_enable_customize_mla_kernel; + // decode phase with tokens more than this limit will lead to error in + // customize mla kernel. once detect any input exceed the limit, fall back to + // default kernel. 
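  // Worked illustration (assumed values, not part of the original patch):
  // with dp_size == 3 and decode inputs of 120, 190 and 260 flattened tokens,
  // the 260-token input exceeds num_tokens_limit (230), std::all_of() below
  // returns false, and FLAGS_enable_customize_mla_kernel is cleared for this
  // step; with inputs of 120, 190 and 200 tokens all checks pass and the flag
  // stays enabled (set_enable_mla remembers whether it was on at startup).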
+ const int num_tokens_limit = 230; + if (set_enable_mla) { + FLAGS_enable_customize_mla_kernel = + std::all_of(raw_forward_inputs.begin(), + raw_forward_inputs.end(), + [](const RawForwardInput& input) { + return input.flatten_tokens_vec.size() < num_tokens_limit; + }); + } } } // namespace @@ -705,38 +710,21 @@ ForwardOutput LLMEngine::step(std::vector& batch) { << "Split DP batch failed with dp_size as " << dp_size_ << " and actual batch size as " << batch.size() << "."; - // prepare input with DP and multi-stream parallel, 2-D micro batches - // batched_raw_forward_inputs[dp_size][micro_batch_size] - // currently we use two batch overlap(TBO), each micro_batch_size is 2. - auto batched_raw_forward_inputs = prepare_inputs(batch); - DCHECK(dp_size_ == batched_raw_forward_inputs.size()) - << "The processed raw forward inputs size " - << batched_raw_forward_inputs.size() << " is not equal to dp size " - << dp_size_ << "."; - static bool set_enable_mla = FLAGS_enable_customize_mla_kernel; - // decode phase with tokens more than this limit will lead to error in - // customize mla kernel. once detect any input exceed the limit, fall back to - // default kernel. - const int num_tokens_limit = 230; - if (set_enable_mla) { - FLAGS_enable_customize_mla_kernel = std::all_of( - batched_raw_forward_inputs.begin(), - batched_raw_forward_inputs.end(), - [](const std::vector& inputs) { - return std::all_of( - inputs.begin(), inputs.end(), [](const RawForwardInput& input) { - return input.flatten_tokens_vec.size() < num_tokens_limit; - }); - }); - } + auto raw_forward_inputs = prepare_inputs(batch); + DCHECK(dp_size_ == raw_forward_inputs.size()) + << "The processed raw forward inputs size " << raw_forward_inputs.size() + << " is not equal to dp size " << dp_size_ << "."; + + try_to_enable_mla(raw_forward_inputs); + std::vector>> futures; futures.reserve(worker_clients_num_); // update dp related global paramters and then execute model for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) { auto dp_rank = worker_rank / dp_local_tp_size_; - futures.emplace_back(worker_clients_[worker_rank]->step_async( - batched_raw_forward_inputs[dp_rank])); + futures.emplace_back( + worker_clients_[worker_rank]->step_async(raw_forward_inputs[dp_rank])); } // wait for the all future to complete @@ -873,53 +861,37 @@ void LLMEngine::process_eplb_data( eplb_manager_->update_expert_load(tensors); } -std::vector> LLMEngine::prepare_inputs( +std::vector LLMEngine::prepare_inputs( std::vector& batch) { - // this is a nested 2-D inputs, with outer dimension indicates dp batches, - // inner dimension indicates multi-stream parallel micro batches - std::vector> batched_inputs(dp_size_); - // determine micro batches number with current batch prefill/decode status - auto micro_batches_num = determine_micro_batches_num(batch); - + std::vector batched_inputs; + batched_inputs.reserve(dp_size_); // some dp related variables - std::vector> dp_global_token_nums; - dp_global_token_nums.resize(micro_batches_num, - std::vector(dp_size_)); + std::vector dp_global_token_nums; + dp_global_token_nums.resize(dp_size_); bool global_empty_kv_cache = true; - // eplb related - EplbInfo eplb_info; - // build model input for every single micro batch for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - // calculate micro batch split indexes - auto split_seq_index = xllm::util::cal_vec_split_index( - batch[dp_rank].size(), micro_batches_num); - for (auto i = 0; i < micro_batches_num; ++i) { - 
batched_inputs[dp_rank].push_back( - std::move(batch[dp_rank].prepare_forward_input(split_seq_index[i], - split_seq_index[i + 1], - args_, - threadpool_.get()))); - dp_global_token_nums[i][dp_rank] = - batched_inputs[dp_rank][i].flatten_tokens_vec.size(); - global_empty_kv_cache = - batched_inputs[dp_rank][i].empty_kv_cache && global_empty_kv_cache; - } + batched_inputs.emplace_back(std::move(batch[dp_rank].prepare_forward_input( + 0, batch[dp_rank].size(), args_, threadpool_.get()))); + dp_global_token_nums[dp_rank] = + batched_inputs[dp_rank].flatten_tokens_vec.size(); + global_empty_kv_cache = + batched_inputs[dp_rank].empty_kv_cache && global_empty_kv_cache; } + // eplb related + EplbInfo eplb_info; if (FLAGS_enable_eplb) { eplb_info = eplb_manager_->get_eplb_info(); } // update dp_global_token_nums and global_empty_kv_cache for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - for (auto i = 0; i < micro_batches_num; ++i) { - batched_inputs[dp_rank][i].dp_global_token_nums = dp_global_token_nums[i]; - batched_inputs[dp_rank][i].global_empty_kv_cache = global_empty_kv_cache; - if (FLAGS_enable_eplb) { - batched_inputs[dp_rank][i].eplb_info = eplb_info; - } + batched_inputs[dp_rank].dp_global_token_nums = dp_global_token_nums; + batched_inputs[dp_rank].global_empty_kv_cache = global_empty_kv_cache; + if (FLAGS_enable_eplb) { + batched_inputs[dp_rank].eplb_info = eplb_info; } } diff --git a/xllm/core/runtime/llm_engine.h b/xllm/core/runtime/llm_engine.h index b6be9c3b5..10681d01f 100644 --- a/xllm/core/runtime/llm_engine.h +++ b/xllm/core/runtime/llm_engine.h @@ -118,8 +118,7 @@ class LLMEngine : public Engine { bool allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap); bool allocate_continuous_kv_cache( const Engine::KVCacheCapacity& kv_cache_cap); - std::vector> prepare_inputs( - std::vector& batch); + std::vector prepare_inputs(std::vector& batch); void process_group_test(); protected: diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 65318bba1..170e2e5f8 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -73,59 +73,45 @@ bool LLMWorkerImpl::init_model(ModelContext& context) { return true; } -std::optional LLMWorkerImpl::step( - const BatchedForwardInputs& inputs) { +std::optional LLMWorkerImpl::step(const ForwardInput& input) { Timer timer; - std::vector flatten_tokens_micro_batches; - std::vector flatten_positions_micro_batches; - std::vector input_params_micro_batches; - auto& concated_sampling_params = inputs.concated_sampling_params; + auto& sampling_params = input.sampling_params; std::vector> futures; - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - flatten_tokens_micro_batches.push_back( - std::move(inputs.micro_inputs[i].token_ids)); - flatten_positions_micro_batches.push_back( - std::move(inputs.micro_inputs[i].positions)); - input_params_micro_batches.push_back( - std::move(inputs.micro_inputs[i].input_params)); - - if (options_.kv_cache_transfer_mode() == "PUSH" && - !inputs.micro_inputs[i].transfer_kv_infos.empty()) { + if (options_.kv_cache_transfer_mode() == "PUSH" && + !input.transfer_kv_infos.empty()) { #if defined(USE_NPU) - std::shared_ptr layer_synchronizer = - std::make_shared( - context_.get_model_args().n_layers()); - const_cast(&(input_params_micro_batches[i])) - ->layer_synchronizer = layer_synchronizer; - - futures.emplace_back(kv_cache_transfer_->push_kv_blocks_async( - inputs.micro_inputs[i].transfer_kv_infos, - context_.get_parallel_args(), 
- layer_synchronizer, - is_spec_draft_)); + std::shared_ptr layer_synchronizer = + std::make_shared( + context_.get_model_args().n_layers()); + const_cast(&(input.input_params))->layer_synchronizer = + layer_synchronizer; + + futures.emplace_back( + kv_cache_transfer_->push_kv_blocks_async(input.transfer_kv_infos, + context_.get_parallel_args(), + layer_synchronizer, + is_spec_draft_)); #endif - } } + if (FLAGS_enable_eplb) { - eplb_executor_->eplb_execute(inputs.micro_inputs[0].eplb_info); + eplb_executor_->eplb_execute(input.eplb_info); } // temporarily use [0], will be adapted in next pr // call model executor forward to get hidden states - auto hidden_states = model_executor_->forward(flatten_tokens_micro_batches, - flatten_positions_micro_batches, - kv_caches_, - input_params_micro_batches); + auto hidden_states = model_executor_->forward( + input.token_ids, input.positions, kv_caches_, input.input_params); if (!hidden_states.defined()) { return std::nullopt; } torch::Tensor logits; - if (concated_sampling_params.selected_token_idxes.defined()) { - logits = model_->logits(hidden_states, - concated_sampling_params.selected_token_idxes); + if (sampling_params.selected_token_idxes.defined()) { + logits = + model_->logits(hidden_states, sampling_params.selected_token_idxes); } ForwardOutput output; @@ -142,9 +128,8 @@ std::optional LLMWorkerImpl::step( auto ret = device_.synchronize_default_stream(); // in p-d disaggregation scene, all micro batches should be in same // prefill/decode stage, so, to judge transfer_kv_infos.empty, - // just use micro inputs.micro_inputs[0] here if (options_.kv_cache_transfer_mode() == "PUSH" && - !inputs.micro_inputs[0].transfer_kv_infos.empty()) { + !input.transfer_kv_infos.empty()) { auto results = folly::collectAll(futures).within(std::chrono::seconds(60)).get(); for (const auto& result : results) { @@ -162,15 +147,15 @@ std::optional LLMWorkerImpl::step( // driver prepare model output SampleOutput sample_output; - if (concated_sampling_params.selected_token_idxes.defined()) { - sample_output = sampler_->forward(logits, concated_sampling_params); + if (sampling_params.selected_token_idxes.defined()) { + sample_output = sampler_->forward(logits, sampling_params); output.logits = logits; // beam search kernel BeamSearchOutput beam_search_output; - if (concated_sampling_params.use_beam_search && - inputs.acc_logprob.defined() && inputs.acc_logprob.numel() > 0) { - beam_search_output = beam_searcher_->forward(inputs.acc_logprob, + if (sampling_params.use_beam_search && input.acc_logprob.defined() && + input.acc_logprob.numel() > 0) { + beam_search_output = beam_searcher_->forward(input.acc_logprob, sample_output.top_tokens, sample_output.top_logprobs); } @@ -178,9 +163,9 @@ std::optional LLMWorkerImpl::step( // set sample output to output output.sample_output = sample_output; // carry over the sampling params - output.do_sample = concated_sampling_params.do_sample; - output.logprobs = concated_sampling_params.logprobs; - output.max_top_logprobs = concated_sampling_params.max_top_logprobs; + output.do_sample = sampling_params.do_sample; + output.logprobs = sampling_params.logprobs; + output.max_top_logprobs = sampling_params.max_top_logprobs; // set beam search output to output output.beam_search_output = beam_search_output; } @@ -189,14 +174,14 @@ std::optional LLMWorkerImpl::step( // should be in same prefill stage, so, to judge empty_kv_cache, // just use micro batch 0 here if (options_.enable_speculative_decode() && !is_spec_draft_) { - if 
(input_params_micro_batches[0].q_seq_lens_vec[0] > 1) { + if (input.input_params.q_seq_lens_vec[0] > 1) { output.sample_output.embeddings = hidden_states; - } else if (concated_sampling_params.sample_idxes.defined()) { + } else if (sampling_params.sample_idxes.defined()) { // auto sample_idxes = // concated_sampling_params.selected_token_idxes.index_select( // /*dim=*/0, concated_sampling_params.sample_idxes); auto embeddings = hidden_states.index_select( - /*dim=*/0, concated_sampling_params.sample_idxes); + /*dim=*/0, sampling_params.sample_idxes); output.sample_output.embeddings = embeddings; } } @@ -205,14 +190,14 @@ std::optional LLMWorkerImpl::step( // should be in same prefill stage, so, to judge empty_kv_cache, // just use micro batch 0 here if (options_.enable_speculative_decode() && !is_spec_draft_) { - if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) { + if (input.input_params.q_seq_lens_vec[0] > 1) { output.sample_output.embeddings = hidden_states; - } else if (concated_sampling_params.sample_idxes.defined()) { + } else if (sampling_params.sample_idxes.defined()) { // auto sample_idxes = // concated_sampling_params.selected_token_idxes.index_select( // /*dim=*/0, concated_sampling_params.sample_idxes); auto embeddings = hidden_states.index_select( - /*dim=*/0, concated_sampling_params.sample_idxes); + /*dim=*/0, sampling_params.sample_idxes); output.sample_output.embeddings = embeddings; } } @@ -220,7 +205,7 @@ std::optional LLMWorkerImpl::step( auto ret = device_.synchronize_default_stream(); if (options_.kv_cache_transfer_mode() == "PUSH" && - !inputs.micro_inputs[0].transfer_kv_infos.empty()) { + !input.transfer_kv_infos.empty()) { auto results = folly::collectAll(futures).within(std::chrono::seconds(60)).get(); for (const auto& result : results) { diff --git a/xllm/core/runtime/llm_worker_impl.h b/xllm/core/runtime/llm_worker_impl.h index 4387e5761..597f705c5 100644 --- a/xllm/core/runtime/llm_worker_impl.h +++ b/xllm/core/runtime/llm_worker_impl.h @@ -42,8 +42,7 @@ class LLMWorkerImpl : public WorkerImpl { // initialize model, cache manager. 
blocking call bool init_model(ModelContext& context) override; - std::optional step( - const BatchedForwardInputs& inputs) override; + std::optional step(const ForwardInput& input) override; layer::LmHead get_lm_head() { return model_->get_lm_head(); }; diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp old mode 100755 new mode 100644 index 8a46a7231..7871e1343 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -80,6 +80,10 @@ Master::Master(const Options& options, EngineType type) : options_(options) { #endif FLAGS_enable_multi_stream_parallel = options.enable_multi_stream_parallel() && (options.nnodes() > 1); + if (FLAGS_enable_multi_stream_parallel) { + LOG(FATAL) + << "Multi-stream parallel is refactoring now, will be supported later."; + } // construct engine const auto devices = diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp index 8e2c5a06f..8693338bc 100644 --- a/xllm/core/runtime/speculative_worker_impl.cpp +++ b/xllm/core/runtime/speculative_worker_impl.cpp @@ -165,56 +165,52 @@ bool SpeculativeWorkerImpl::allocate_kv_cache_with_transfer( #endif std::optional SpeculativeWorkerImpl::step( - const BatchedForwardInputs& inputs) { - // all micro batches in multi stream parallel share the same - // prefill/decode stage, use inputs[0] here - if (inputs.micro_inputs[0].token_ids.numel() == 0) { - return step_empty(inputs); + const ForwardInput& input) { + if (input.token_ids.numel() == 0) { + return step_empty(input); } // TODO: support data parallel case - if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) { - return step_prefill(inputs); + if (input.input_params.q_seq_lens_vec[0] > 1) { + return step_prefill(input); } else { - return step_decode(inputs); + return step_decode(input); } } std::optional SpeculativeWorkerImpl::step_empty( - const BatchedForwardInputs& inputs) { - if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) { - auto output = impl_->step(inputs); - auto draft_output = draft_impl_->step(inputs); + const ForwardInput& input) { + if (input.input_params.q_seq_lens_vec[0] > 1) { + auto output = impl_->step(input); + auto draft_output = draft_impl_->step(input); return output; } else { for (size_t i = 0; i < options_.num_speculative_tokens(); ++i) { - auto draft_future = draft_impl_->step_async(inputs); + auto draft_future = draft_impl_->step_async(input); ForwardOutput draft_output = std::move(draft_future).get().value(); } - BatchedForwardInputs new_inputs = inputs; - for (auto i = 0; i < new_inputs.micro_inputs.size(); ++i) { - for (auto& it : - new_inputs.micro_inputs[i].input_params.dp_global_token_nums) { - it *= options_.num_speculative_tokens() + 1; - } + ForwardInput new_input = input; + for (auto& it : new_input.input_params.dp_global_token_nums) { + it *= options_.num_speculative_tokens() + 1; } - auto future = impl_->step_async(new_inputs); + + auto future = impl_->step_async(new_input); ForwardOutput output = std::move(future).get().value(); return output; } } std::optional SpeculativeWorkerImpl::step_prefill( - const BatchedForwardInputs& inputs) { + const ForwardInput& input) { Timer timer; // run the target model to get first token and hidden states - auto future = impl_->step_async(inputs); + auto future = impl_->step_async(input); // MTP (Eagle Medusa) which depend on hidden states need this step // The others speculative model use inputs directly // ForwardInput prefill_inputs; - BatchedForwardInputs prefill_inputs; - 
prepare_prefill_inputs(inputs, prefill_inputs); + ForwardInput prefill_input; + prepare_prefill_inputs(input, prefill_input); ForwardOutput output = std::move(future).get().value(); COUNTER_ADD(speculative_execution_latency_seconds_target, timer.elapsed_seconds()); @@ -224,50 +220,40 @@ std::optional SpeculativeWorkerImpl::step_prefill( auto next_tokens = safe_to(output.sample_output.next_tokens, torch::kInt); auto start_idx = 0; auto token_start_idx = 0; - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - auto offset = inputs.micro_inputs[i].input_params.num_sequences; - auto token_offset = prefill_inputs.micro_inputs[i].token_ids.size(0); - if (token_offset > 0) { - prefill_inputs.micro_inputs[i].input_params.mm_data = MMData( - MMType::EMBEDDING, - {{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}}); - } - if (next_tokens.defined()) { - auto& token_ids = prefill_inputs.micro_inputs[i].token_ids; - auto mask = (token_ids == -1); - // TODO: support multi stream parallel case - // token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, - // offset)); - token_ids.masked_scatter_(mask, next_tokens); - } - start_idx += offset; - token_start_idx += token_offset; + + auto offset = input.input_params.num_sequences; + auto token_offset = prefill_input.token_ids.size(0); + if (token_offset > 0) { + prefill_input.input_params.mm_data = MMData( + MMType::EMBEDDING, + {{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}}); + } + if (next_tokens.defined()) { + auto& token_ids = prefill_input.token_ids; + auto mask = (token_ids == -1); + // TODO: support multi stream parallel case + // token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, + // offset)); + token_ids.masked_scatter_(mask, next_tokens); } + start_idx += offset; + token_start_idx += token_offset; // generate kv cache for draft model timer.reset(); - auto draft_future = draft_impl_->step_async(prefill_inputs); + auto draft_future = draft_impl_->step_async(prefill_input); ForwardOutput draft_output = std::move(draft_future).get().value(); COUNTER_ADD(speculative_execution_latency_seconds_draft, timer.elapsed_seconds()); - auto concated_embedding_ids = - inputs.micro_inputs[0].input_params.embedding_ids; - for (auto i = 1; i < inputs.micro_inputs.size(); ++i) { - concated_embedding_ids.insert( - concated_embedding_ids.end(), - inputs.micro_inputs[i].input_params.embedding_ids.begin(), - inputs.micro_inputs[i].input_params.embedding_ids.end()); - } - - if (inputs.concated_sampling_params.selected_token_idxes.defined()) { + if (input.sampling_params.selected_token_idxes.defined()) { embeddings = embeddings.index_select( - /*dim=*/0, inputs.concated_sampling_params.selected_token_idxes); + /*dim=*/0, input.sampling_params.selected_token_idxes); CHECK_EQ(embeddings.size(0), output.sample_output.next_tokens.size(0)); - embedding_allocator_->write(concated_embedding_ids, embeddings); + embedding_allocator_->write(input.input_params.embedding_ids, embeddings); #if defined(USE_NPU) if (kv_cache_transfer_) { - kv_cache_transfer_->copy_blocks(concated_embedding_ids, + kv_cache_transfer_->copy_blocks(input.input_params.embedding_ids, /*h2d*/ true); } #endif @@ -275,115 +261,96 @@ std::optional SpeculativeWorkerImpl::step_prefill( output.sample_output.embeddings = torch::Tensor(); #if defined(USE_NPU) - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - if (options_.kv_cache_transfer_mode() == "PUSH" && - !inputs.micro_inputs[i].transfer_kv_infos.empty()) { - auto future = 
kv_cache_transfer_->push_kv_blocks_async( - inputs.micro_inputs[i].transfer_kv_infos, - context_.get_parallel_args(), - nullptr, - true); - auto out = std::move(future).get(); - } + if (options_.kv_cache_transfer_mode() == "PUSH" && + !input.transfer_kv_infos.empty()) { + auto future = kv_cache_transfer_->push_kv_blocks_async( + input.transfer_kv_infos, context_.get_parallel_args(), nullptr, true); + auto out = std::move(future).get(); } #endif return output; } void SpeculativeWorkerImpl::prepare_prefill_inputs( - const BatchedForwardInputs& inputs, - BatchedForwardInputs& prefill_inputs) { - prefill_inputs.micro_inputs.reserve(inputs.micro_inputs.size()); - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - auto& input = inputs.micro_inputs[i]; - ForwardInput prefill_input; - prefill_input = input.to(device_, dtype_); - auto& input_params = prefill_input.input_params; - auto& extra_token_ids = input_params.extra_token_ids; - - torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); - Slice tokens_ids_slice = {token_ids.data_ptr(), - input.token_ids.numel()}; + const ForwardInput& input, + ForwardInput& prefill_input) { + prefill_input = input.to(device_, dtype_); + auto& input_params = prefill_input.input_params; + auto& extra_token_ids = input_params.extra_token_ids; - int32_t start_idx = 0; - std::vector new_token_ids; - new_token_ids.reserve(input.token_ids.numel()); - for (size_t i = 0; i < input_params.num_sequences; ++i) { - int32_t q_len = 0; - q_len = input_params.q_seq_lens_vec[i]; - Slice tokens_ids_slice_i = - tokens_ids_slice.slice(start_idx + 1, start_idx + q_len); - start_idx += q_len; - new_token_ids.insert(new_token_ids.end(), - tokens_ids_slice_i.begin(), - tokens_ids_slice_i.end()); - new_token_ids.emplace_back(extra_token_ids[i]); - } - prefill_input.token_ids = - torch::tensor(new_token_ids, prefill_input.positions.options()); - prefill_inputs.micro_inputs.push_back(std::move(prefill_input)); + torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); + Slice tokens_ids_slice = {token_ids.data_ptr(), + input.token_ids.numel()}; + + int32_t start_idx = 0; + std::vector new_token_ids; + new_token_ids.reserve(input.token_ids.numel()); + for (size_t i = 0; i < input_params.num_sequences; ++i) { + int32_t q_len = 0; + q_len = input_params.q_seq_lens_vec[i]; + Slice tokens_ids_slice_i = + tokens_ids_slice.slice(start_idx + 1, start_idx + q_len); + start_idx += q_len; + new_token_ids.insert(new_token_ids.end(), + tokens_ids_slice_i.begin(), + tokens_ids_slice_i.end()); + new_token_ids.emplace_back(extra_token_ids[i]); } - prefill_inputs.concated_sampling_params = inputs.concated_sampling_params; + prefill_input.token_ids = + torch::tensor(new_token_ids, prefill_input.positions.options()); } std::optional SpeculativeWorkerImpl::step_decode( - const BatchedForwardInputs& inputs) { + const ForwardInput& input) { // TODO : now only support Deepseek MTP // More work need to support n-gram and native speculative decoding. 
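  // Decode-phase walkthrough (illustrative values, added for readability):
  // with num_speculative_tokens == 2 and 2 sequences, draft_impl_ runs twice,
  // feeding each run's next tokens and hidden states back into the next run;
  // prepare_validate_inputs() then lays every sequence out as
  // [last_token, -1, -2], e.g.
  //   token_ids      = [ 57, -1, -2,  91, -1, -2 ]
  //   draft step 0   -> [101, 201] fills the -1 slots
  //   draft step 1   -> [102, 202] fills the -2 slots
  //   after scatter  = [ 57, 101, 102, 91, 201, 202 ]
  // impl_ (the target model) scores all positions in a single pass and
  // validate() keeps only the proposals the target agrees with.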
// ForwardInput draft_inputs = inputs; - BatchedForwardInputs draft_inputs = inputs; - for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) { - auto& input = inputs.micro_inputs[i]; - auto& draft_input = draft_inputs.micro_inputs[i]; - // get embedding cache + ForwardInput draft_input = input; + // get embedding cache #if defined(USE_NPU) - if (kv_cache_transfer_) { - kv_cache_transfer_->copy_blocks(input.input_params.embedding_ids, - /*h2d*/ false); - } -#endif - torch::Tensor embeddings = - embedding_allocator_->read(draft_input.input_params.embedding_ids); - draft_input.input_params.mm_data = - MMData(MMType::EMBEDDING, {{"embedding", embeddings.to(device_)}}); + if (kv_cache_transfer_) { + kv_cache_transfer_->copy_blocks(input.input_params.embedding_ids, + /*h2d*/ false); } +#endif + torch::Tensor embeddings = + embedding_allocator_->read(draft_input.input_params.embedding_ids); + draft_input.input_params.mm_data = + MMData(MMType::EMBEDDING, {{"embedding", embeddings.to(device_)}}); // run the draft model to get proposals std::vector draft_outputs; - BatchedForwardInputs validate_inputs, next_step_input; + ForwardInput validate_input, next_step_input; Timer timer; std::vector>> futures; for (size_t i = 0; i < options_.num_speculative_tokens(); ++i) { - auto future = draft_impl_->step_async(draft_inputs); + auto future = draft_impl_->step_async(draft_input); if (i == options_.num_speculative_tokens() - 1) { // final step - prepare_validate_inputs(inputs, validate_inputs, true); + prepare_validate_inputs(input, validate_input, true); } else { - prepare_draft_inputs(draft_inputs, next_step_input, 1, device_); + prepare_draft_inputs(draft_input, next_step_input, 1, device_); } draft_outputs.push_back(std::move(future).get().value()); // update input of next step if (i < options_.num_speculative_tokens() - 1) { - draft_inputs = next_step_input; + draft_input = next_step_input; auto last_output = draft_outputs.back().sample_output; auto start_idx = 0; auto token_start_idx = 0; - for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) { - auto& draft_input = draft_inputs.micro_inputs[i]; - auto offset = draft_input.input_params.num_sequences; - auto token_offset = draft_input.token_ids.size(0); - draft_input.token_ids = safe_to( - last_output.next_tokens.narrow(0, start_idx, offset), torch::kInt); - if (token_offset > 0) { - draft_input.input_params.mm_data = MMData( - MMType::EMBEDDING, - {{"embedding", - last_output.embeddings.narrow(0, token_start_idx, token_offset) - .to(device_)}}); - } - start_idx += offset; - token_start_idx += token_offset; + auto offset = draft_input.input_params.num_sequences; + auto token_offset = draft_input.token_ids.size(0); + draft_input.token_ids = safe_to( + last_output.next_tokens.narrow(0, start_idx, offset), torch::kInt); + if (token_offset > 0) { + draft_input.input_params.mm_data = MMData( + MMType::EMBEDDING, + {{"embedding", + last_output.embeddings.narrow(0, token_start_idx, token_offset) + .to(device_)}}); } + start_idx += offset; + token_start_idx += token_offset; } } COUNTER_ADD(speculative_execution_latency_seconds_draft, @@ -394,19 +361,16 @@ std::optional SpeculativeWorkerImpl::step_decode( auto next_tokens = safe_to(draft_output.sample_output.next_tokens, torch::kInt); int32_t start_idx = 0; - for (auto i = 0; i < validate_inputs.micro_inputs.size(); ++i) { - int32_t offset = draft_inputs.micro_inputs[i].input_params.num_sequences; - auto& validate_input = validate_inputs.micro_inputs[i]; - auto& token_ids = 
validate_input.token_ids; - auto mask = (token_ids == -1 * (i + 1)); - token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset)); - start_idx += offset; - } + int32_t offset = draft_input.input_params.num_sequences; + auto& token_ids = validate_input.token_ids; + auto mask = (token_ids == -1 * (i + 1)); + token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset)); + start_idx += offset; } // run the target model to get the verification scores timer.reset(); - auto future = impl_->step_async(validate_inputs); + auto future = impl_->step_async(validate_input); ForwardOutput target_output = std::move(future).get().value(); COUNTER_ADD(speculative_execution_latency_seconds_target, timer.elapsed_seconds()); @@ -414,23 +378,20 @@ std::optional SpeculativeWorkerImpl::step_decode( // verify the proposals with target and update the batch timer.reset(); SampleOutput val_output = - validate(inputs.concated_sampling_params, draft_outputs, target_output); + validate(input.sampling_params, draft_outputs, target_output); COUNTER_ADD(speculative_execution_latency_seconds_validation, timer.elapsed_seconds()); - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - auto& input = inputs.micro_inputs[i]; - // write the right cache and clear embeddings - embedding_allocator_->write_validate(input.input_params.embedding_ids, - val_output.next_tokens.to(torch::kCPU), - val_output.embeddings); + // write the right cache and clear embeddings + embedding_allocator_->write_validate(input.input_params.embedding_ids, + val_output.next_tokens.to(torch::kCPU), + val_output.embeddings); #if defined(USE_NPU) - if (kv_cache_transfer_) { - kv_cache_transfer_->copy_blocks(input.input_params.embedding_ids, - /*h2d*/ true); - } -#endif + if (kv_cache_transfer_) { + kv_cache_transfer_->copy_blocks(input.input_params.embedding_ids, + /*h2d*/ true); } +#endif val_output.embeddings = torch::Tensor(); @@ -441,192 +402,167 @@ std::optional SpeculativeWorkerImpl::step_decode( return target_output; } -void SpeculativeWorkerImpl::prepare_draft_inputs( - const BatchedForwardInputs& inputs, - BatchedForwardInputs& draft_inputs, - const int64_t offset, - const torch::Device device) { +void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input, + ForwardInput& draft_input, + const int64_t offset, + const torch::Device device) { // prepare input for MTP in decoding phase (Like Eagle). 
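  // Illustration (assumed numbers, not from this patch): with offset == 1, a
  // sequence whose current position is 41 and kv_seq_len is 42 is advanced so
  // the draft model runs at position 42 with kv_seq_len 43; kv_max_seq_len is
  // bumped by the same offset and get_new_token_slot_id() recomputes the cache
  // slot for the advanced token from the block table (crossing into the next
  // block when the current block of options_.block_size() slots is full).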
- draft_inputs.micro_inputs.reserve(inputs.micro_inputs.size()); - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - auto& input = inputs.micro_inputs[i]; - ForwardInput draft_input = input.to(device, dtype_); - - auto& input_params = draft_input.input_params; - const int32_t num_sequences = input_params.num_sequences; - torch::Tensor positions = safe_to(input.positions, torch::kCPU); - Slice positions_slice = {positions.data_ptr(), - positions.numel()}; - std::vector new_positions; - new_positions.reserve(num_sequences); - for (int32_t i = 0; i < num_sequences; ++i) { - new_positions.emplace_back(positions_slice[i] + offset); - } - torch::TensorOptions int_options = input.token_ids.options(); - draft_input.positions = torch::tensor(new_positions, int_options); - - std::vector kv_seq_lens_vec = {}; - // slot ids for new token - std::vector new_token_slot_ids; - - int32_t block_size = options_.block_size(); - torch::Tensor kv_seq_lens = safe_to(input_params.kv_seq_lens, torch::kCPU); - Slice kv_seq_lens_slice = {kv_seq_lens.data_ptr(), - kv_seq_lens.numel()}; - torch::Tensor block_tables = - safe_to(input_params.block_tables, torch::kCPU); - torch::Tensor new_cache_slots = - safe_to(input_params.new_cache_slots, torch::kCPU); - Slice new_cache_slots_slice = {new_cache_slots.data_ptr(), - new_cache_slots.numel()}; - for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) { - kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset); - torch::Tensor block_table = block_tables[seq_id]; - Slice block_table_slice = {block_table.data_ptr(), - block_table.numel()}; - int32_t new_token_slot_id = get_new_token_slot_id( - new_cache_slots_slice[seq_id], block_size, offset, block_table_slice); - new_token_slot_ids.emplace_back(new_token_slot_id); - } + draft_input = input.to(device, dtype_); - input_params.kv_max_seq_len = input_params.kv_max_seq_len + offset; - input_params.kv_seq_lens_vec = kv_seq_lens_vec; - input_params.kv_seq_lens = torch::tensor(kv_seq_lens_vec, int_options); - input_params.new_cache_slots = - torch::tensor(new_token_slot_ids, int_options); - draft_inputs.micro_inputs.push_back(std::move(draft_input)); + auto& input_params = draft_input.input_params; + const int32_t num_sequences = input_params.num_sequences; + torch::Tensor positions = safe_to(input.positions, torch::kCPU); + Slice positions_slice = {positions.data_ptr(), + positions.numel()}; + std::vector new_positions; + new_positions.reserve(num_sequences); + for (int32_t i = 0; i < num_sequences; ++i) { + new_positions.emplace_back(positions_slice[i] + offset); } - draft_inputs.concated_sampling_params = inputs.concated_sampling_params; + torch::TensorOptions int_options = input.token_ids.options(); + draft_input.positions = torch::tensor(new_positions, int_options); + + std::vector kv_seq_lens_vec = {}; + // slot ids for new token + std::vector new_token_slot_ids; + + int32_t block_size = options_.block_size(); + torch::Tensor kv_seq_lens = safe_to(input_params.kv_seq_lens, torch::kCPU); + Slice kv_seq_lens_slice = {kv_seq_lens.data_ptr(), + kv_seq_lens.numel()}; + torch::Tensor block_tables = safe_to(input_params.block_tables, torch::kCPU); + torch::Tensor new_cache_slots = + safe_to(input_params.new_cache_slots, torch::kCPU); + Slice new_cache_slots_slice = {new_cache_slots.data_ptr(), + new_cache_slots.numel()}; + for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) { + kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset); + torch::Tensor block_table = block_tables[seq_id]; + Slice 
block_table_slice = {block_table.data_ptr(), + block_table.numel()}; + int32_t new_token_slot_id = get_new_token_slot_id( + new_cache_slots_slice[seq_id], block_size, offset, block_table_slice); + new_token_slot_ids.emplace_back(new_token_slot_id); + } + + input_params.kv_max_seq_len = input_params.kv_max_seq_len + offset; + input_params.kv_seq_lens_vec = kv_seq_lens_vec; + input_params.kv_seq_lens = torch::tensor(kv_seq_lens_vec, int_options); + input_params.new_cache_slots = torch::tensor(new_token_slot_ids, int_options); } void SpeculativeWorkerImpl::prepare_validate_inputs( - const BatchedForwardInputs& inputs, - BatchedForwardInputs& validate_inputs, + const ForwardInput& input, + ForwardInput& validate_input, bool enable_schedule_overlap) { - validate_inputs.micro_inputs.reserve(inputs.micro_inputs.size()); - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - auto& input = inputs.micro_inputs[i]; - - ForwardInput validate_input = input.to(device_, dtype_); - auto& input_params = validate_input.input_params; - - const int32_t position_offset = enable_schedule_overlap ? 1 : 0; - const int32_t num_speculative_tokens = options_.num_speculative_tokens(); - const int32_t num_sequences = input_params.num_sequences; - const int32_t num_val_tokens = num_speculative_tokens + 1; - const int32_t total_num_val_tokens = num_sequences * num_val_tokens; - - std::vector> draft_tokens; - draft_tokens.reserve(num_speculative_tokens); - for (int i = 0; i < num_speculative_tokens; ++i) { - draft_tokens.emplace_back(std::vector(num_sequences, -1 * (i + 1))); - } + validate_input = input.to(device_, dtype_); + auto& input_params = validate_input.input_params; + + const int32_t position_offset = enable_schedule_overlap ? 1 : 0; + const int32_t num_speculative_tokens = options_.num_speculative_tokens(); + const int32_t num_sequences = input_params.num_sequences; + const int32_t num_val_tokens = num_speculative_tokens + 1; + const int32_t total_num_val_tokens = num_sequences * num_val_tokens; + + std::vector> draft_tokens; + draft_tokens.reserve(num_speculative_tokens); + for (int i = 0; i < num_speculative_tokens; ++i) { + draft_tokens.emplace_back(std::vector(num_sequences, -1 * (i + 1))); + } - torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); - Slice tokens_ids_slice = {token_ids.data_ptr(), - token_ids.numel()}; - torch::Tensor positions = safe_to(input.positions, torch::kCPU); - Slice positions_slice = {positions.data_ptr(), - positions.numel()}; - - std::vector new_token_ids; - std::vector new_positions; - new_token_ids.reserve(total_num_val_tokens); - new_positions.reserve(total_num_val_tokens); - for (int32_t i = 0; i < num_sequences; ++i) { - new_token_ids.emplace_back(tokens_ids_slice[i]); - new_positions.emplace_back(positions_slice[i] + position_offset); - for (int32_t j = 0; j < num_speculative_tokens; ++j) { - new_token_ids.emplace_back(draft_tokens[j][i]); - new_positions.emplace_back(positions_slice[i] + j + 1 + - position_offset); - } - } - torch::TensorOptions int_options = input.token_ids.options(); - validate_input.token_ids = torch::tensor(new_token_ids, int_options); - validate_input.positions = torch::tensor(new_positions, int_options); - - // update the input_params - input_params.num_sequences = total_num_val_tokens; - input_params.kv_max_seq_len = - input_params.kv_max_seq_len + num_speculative_tokens + position_offset; - for (auto& it : input_params.dp_global_token_nums) { - it *= num_val_tokens; + torch::Tensor token_ids = safe_to(input.token_ids, 
torch::kCPU); + Slice tokens_ids_slice = {token_ids.data_ptr(), + token_ids.numel()}; + torch::Tensor positions = safe_to(input.positions, torch::kCPU); + Slice positions_slice = {positions.data_ptr(), + positions.numel()}; + + std::vector new_token_ids; + std::vector new_positions; + new_token_ids.reserve(total_num_val_tokens); + new_positions.reserve(total_num_val_tokens); + for (int32_t i = 0; i < num_sequences; ++i) { + new_token_ids.emplace_back(tokens_ids_slice[i]); + new_positions.emplace_back(positions_slice[i] + position_offset); + for (int32_t j = 0; j < num_speculative_tokens; ++j) { + new_token_ids.emplace_back(draft_tokens[j][i]); + new_positions.emplace_back(positions_slice[i] + j + 1 + position_offset); } + } + torch::TensorOptions int_options = input.token_ids.options(); + validate_input.token_ids = torch::tensor(new_token_ids, int_options); + validate_input.positions = torch::tensor(new_positions, int_options); - std::vector kv_seq_lens_vec = {}; - std::vector q_seq_lens_vec = {}; - // slot ids for new token - std::vector new_token_slot_ids; - std::vector> block_tables_vec; - - int32_t block_size = options_.block_size(); - torch::Tensor kv_seq_lens = safe_to(input_params.kv_seq_lens, torch::kCPU); - Slice kv_seq_lens_slice = {kv_seq_lens.data_ptr(), - kv_seq_lens.numel()}; - torch::Tensor block_tables = - safe_to(input_params.block_tables, torch::kCPU); - torch::Tensor new_cache_slots = - safe_to(input_params.new_cache_slots, torch::kCPU); - Slice new_cache_slots_slice = {new_cache_slots.data_ptr(), - new_cache_slots.numel()}; - for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) { - int32_t cur_token_slot_id = new_cache_slots_slice[seq_id]; - torch::Tensor block_table = block_tables[seq_id]; - Slice block_table_slice = {block_table.data_ptr(), - block_table.numel()}; - - // process kv length and q length - if (FLAGS_enable_atb_spec_kernel) { - kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + - num_speculative_tokens + position_offset); - q_seq_lens_vec.emplace_back(num_val_tokens); - } else { - for (int32_t token_id = position_offset; - token_id < num_val_tokens + position_offset; - ++token_id) { - q_seq_lens_vec.emplace_back(1); - kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + token_id); - // repeat block table - block_tables_vec.emplace_back(block_table_slice); - } - } + // update the input_params + input_params.num_sequences = total_num_val_tokens; + input_params.kv_max_seq_len = + input_params.kv_max_seq_len + num_speculative_tokens + position_offset; + for (auto& it : input_params.dp_global_token_nums) { + it *= num_val_tokens; + } + + std::vector kv_seq_lens_vec = {}; + std::vector q_seq_lens_vec = {}; + // slot ids for new token + std::vector new_token_slot_ids; + std::vector> block_tables_vec; + + int32_t block_size = options_.block_size(); + torch::Tensor kv_seq_lens = safe_to(input_params.kv_seq_lens, torch::kCPU); + Slice kv_seq_lens_slice = {kv_seq_lens.data_ptr(), + kv_seq_lens.numel()}; + torch::Tensor block_tables = safe_to(input_params.block_tables, torch::kCPU); + torch::Tensor new_cache_slots = + safe_to(input_params.new_cache_slots, torch::kCPU); + Slice new_cache_slots_slice = {new_cache_slots.data_ptr(), + new_cache_slots.numel()}; + for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) { + int32_t cur_token_slot_id = new_cache_slots_slice[seq_id]; + torch::Tensor block_table = block_tables[seq_id]; + Slice block_table_slice = {block_table.data_ptr(), + block_table.numel()}; - // process position related params + // 
process kv length and q length + if (FLAGS_enable_atb_spec_kernel) { + kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + + num_speculative_tokens + position_offset); + q_seq_lens_vec.emplace_back(num_val_tokens); + } else { for (int32_t token_id = position_offset; token_id < num_val_tokens + position_offset; ++token_id) { - int32_t new_token_slot_id = get_new_token_slot_id( - cur_token_slot_id, block_size, token_id, block_table_slice); - new_token_slot_ids.emplace_back(new_token_slot_id); + q_seq_lens_vec.emplace_back(1); + kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + token_id); + // repeat block table + block_tables_vec.emplace_back(block_table_slice); } } - input_params.kv_seq_lens_vec = kv_seq_lens_vec; - input_params.kv_seq_lens = torch::tensor(kv_seq_lens_vec, int_options); - input_params.q_seq_lens_vec = q_seq_lens_vec; - input_params.q_seq_lens = torch::tensor(q_seq_lens_vec, int_options); - input_params.new_cache_slots = - torch::tensor(new_token_slot_ids, int_options); - if (!FLAGS_enable_atb_spec_kernel) { - util::pad_2d_vector(block_tables_vec, /*pad_value=*/0); - input_params.block_tables = - create_2d_tensor(block_tables_vec, torch::kInt).to(device_); + // process position related params + for (int32_t token_id = position_offset; + token_id < num_val_tokens + position_offset; + ++token_id) { + int32_t new_token_slot_id = get_new_token_slot_id( + cur_token_slot_id, block_size, token_id, block_table_slice); + new_token_slot_ids.emplace_back(new_token_slot_id); } - input_params.decode_seq_range.second = input_params.num_sequences - 1; - - // update the sampling_params - update_sampling_params( - validate_input.sampling_params, num_val_tokens, total_num_val_tokens); - validate_inputs.micro_inputs.push_back(std::move(validate_input)); } - validate_inputs.concated_sampling_params = - validate_inputs.micro_inputs[0].sampling_params; - for (auto i = 1; i < validate_inputs.micro_inputs.size(); ++i) { - validate_inputs.concated_sampling_params.concat( - validate_inputs.micro_inputs[i].sampling_params); + input_params.kv_seq_lens_vec = kv_seq_lens_vec; + input_params.kv_seq_lens = torch::tensor(kv_seq_lens_vec, int_options); + input_params.q_seq_lens_vec = q_seq_lens_vec; + input_params.q_seq_lens = torch::tensor(q_seq_lens_vec, int_options); + input_params.new_cache_slots = torch::tensor(new_token_slot_ids, int_options); + if (!FLAGS_enable_atb_spec_kernel) { + util::pad_2d_vector(block_tables_vec, /*pad_value=*/0); + input_params.block_tables = + create_2d_tensor(block_tables_vec, torch::kInt).to(device_); } + input_params.decode_seq_range.second = input_params.num_sequences - 1; + + // update the sampling_params + update_sampling_params( + validate_input.sampling_params, num_val_tokens, total_num_val_tokens); } SampleOutput SpeculativeWorkerImpl::validate( @@ -821,15 +757,15 @@ void SpeculativeWorkerImpl::update_sampling_params( } void SpeculativeWorkerImpl::prepare_work_before_execute( - const BatchedForwardInputs& inputs, - BatchedForwardInputs& processed_inputs) { - if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) { - WorkerImpl::prepare_work_before_execute(inputs, processed_inputs); + const ForwardInput& input, + ForwardInput& processed_input) { + if (input.input_params.q_seq_lens_vec[0] > 1) { + WorkerImpl::prepare_work_before_execute(input, processed_input); } else { if (enable_schedule_overlap()) { - prepare_draft_inputs(inputs, processed_inputs, -1, torch::kCPU); + prepare_draft_inputs(input, processed_input, -1, torch::kCPU); } else { - 
prepare_draft_inputs(inputs, processed_inputs, -1, device_);
+      prepare_draft_inputs(input, processed_input, -1, device_);
     }
   }
 }
diff --git a/xllm/core/runtime/speculative_worker_impl.h b/xllm/core/runtime/speculative_worker_impl.h
index c4da14c9e..ebdb80d53 100644
--- a/xllm/core/runtime/speculative_worker_impl.h
+++ b/xllm/core/runtime/speculative_worker_impl.h
@@ -90,11 +90,10 @@ class SpeculativeWorkerImpl : public WorkerImpl {
   };
 
   // prepare work before model execution
-  void prepare_work_before_execute(const BatchedForwardInputs& inputs,
-                                   BatchedForwardInputs& new_inputs) override;
+  void prepare_work_before_execute(const ForwardInput& input,
+                                   ForwardInput& new_input) override;
 
-  std::optional step(
-      const BatchedForwardInputs& inputs) override;
+  std::optional step(const ForwardInput& input) override;
 
   ForwardInput update_input_by_last_step_output(ForwardInput& inputs) override;
 
@@ -114,26 +113,26 @@ class SpeculativeWorkerImpl : public WorkerImpl {
   };
 
  private:
-  std::optional step_prefill(const BatchedForwardInputs& inputs);
+  std::optional step_prefill(const ForwardInput& input);
 
-  std::optional step_decode(const BatchedForwardInputs& inputs);
+  std::optional step_decode(const ForwardInput& inputs);
 
   // When DP is enabled, inputs may be empty but the model still must execute.
-  std::optional step_empty(const BatchedForwardInputs& inputs);
+  std::optional step_empty(const ForwardInput& inputs);
 
   // prepare inputs for draft model at Prefill phase.
-  void prepare_prefill_inputs(const BatchedForwardInputs& inputs,
-                              BatchedForwardInputs& prefill_inputs);
+  void prepare_prefill_inputs(const ForwardInput& inputs,
+                              ForwardInput& prefill_inputs);
 
   // prepare inputs for draft model at Decode phase.
-  void prepare_draft_inputs(const BatchedForwardInputs& inputs,
-                            BatchedForwardInputs& draft_inputs,
+  void prepare_draft_inputs(const ForwardInput& inputs,
+                            ForwardInput& draft_inputs,
                             const int64_t offset,
                             const torch::Device device);
 
   // prepare inputs for target model at Decode phase.
-  void prepare_validate_inputs(const BatchedForwardInputs& inputs,
-                               BatchedForwardInputs& validate_inputs,
+  void prepare_validate_inputs(const ForwardInput& inputs,
+                               ForwardInput& validate_inputs,
                                bool enable_schedule_overlap);
 
   SampleOutput validate(const SamplingParameters& sampling_params,
diff --git a/xllm/core/runtime/vlm_engine.cpp b/xllm/core/runtime/vlm_engine.cpp
index 8ec675098..6227a6eeb 100644
--- a/xllm/core/runtime/vlm_engine.cpp
+++ b/xllm/core/runtime/vlm_engine.cpp
@@ -43,20 +43,6 @@ limitations under the License.
 
 namespace xllm {
 
-namespace {
-uint32_t determine_micro_batches_num(const std::vector& batch) {
-  bool not_all_in_decode =
-      std::any_of(batch.begin(), batch.end(), [](const Batch& one_batch) {
-        return one_batch.get_batch_prefill_status();
-      });
-  // TODO:VLM support multi stream parallel later. 
- // if (not_all_in_decode && FLAGS_enable_multi_stream_parallel) { - // return 2; - // } - return 1; -} -} // namespace - VLMEngine::VLMEngine(const runtime::Options& options, std::shared_ptr dist_manager) : options_(options), dist_manager_(dist_manager) { @@ -323,12 +309,11 @@ ForwardOutput VLMEngine::step(std::vector& batch) { << "Split DP batch failed with dp_size as " << dp_size_ << " and actual batch size as " << batch.size() << "."; - auto batched_raw_forward_inputs = prepare_inputs(batch); + auto raw_forward_inputs = prepare_inputs(batch); - DCHECK(dp_size_ == batched_raw_forward_inputs.size()) - << "The processed raw forward inputs size " - << batched_raw_forward_inputs.size() << " is not equal to dp size " - << dp_size_ << "."; + DCHECK(dp_size_ == raw_forward_inputs.size()) + << "The processed raw forward inputs size " << raw_forward_inputs.size() + << " is not equal to dp size " << dp_size_ << "."; std::vector>> futures; futures.reserve(worker_clients_num_); @@ -336,8 +321,8 @@ ForwardOutput VLMEngine::step(std::vector& batch) { // update dp related global paramters and then execute model for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) { auto dp_rank = worker_rank / dp_local_tp_size_; - futures.emplace_back(worker_clients_[worker_rank]->step_async( - batched_raw_forward_inputs[dp_rank])); + futures.emplace_back( + worker_clients_[worker_rank]->step_async(raw_forward_inputs[dp_rank])); } // wait for the all future to complete @@ -435,42 +420,28 @@ std::vector VLMEngine::get_active_activation_memory() const { return active_activation_memories; } -std::vector> VLMEngine::prepare_inputs( +std::vector VLMEngine::prepare_inputs( std::vector& batch) { - std::vector> batched_inputs(dp_size_); - // determine micro batches number with current batch prefill/decode status - auto micro_batches_num = determine_micro_batches_num(batch); - + std::vector batched_inputs; + batched_inputs.reserve(dp_size_); // some dp related variables - std::vector> dp_global_token_nums; - dp_global_token_nums.resize(micro_batches_num, - std::vector(dp_size_)); + std::vector dp_global_token_nums; + dp_global_token_nums.resize(dp_size_); bool global_empty_kv_cache = true; - // build model input for every single micro batch for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - // calculate micro batch split indexes - auto split_seq_index = xllm::util::cal_vec_split_index( - batch[dp_rank].size(), micro_batches_num); - for (auto i = 0; i < micro_batches_num; ++i) { - batched_inputs[dp_rank].push_back( - std::move(batch[dp_rank].prepare_forward_input(split_seq_index[i], - split_seq_index[i + 1], - args_, - threadpool_.get()))); - dp_global_token_nums[i][dp_rank] = - batched_inputs[dp_rank][i].flatten_tokens_vec.size(); - global_empty_kv_cache = - batched_inputs[dp_rank][i].empty_kv_cache && global_empty_kv_cache; - } + batched_inputs.emplace_back(std::move(batch[dp_rank].prepare_forward_input( + 0, batch[dp_rank].size(), args_, threadpool_.get()))); + dp_global_token_nums[dp_rank] = + batched_inputs[dp_rank].flatten_tokens_vec.size(); + global_empty_kv_cache = + batched_inputs[dp_rank].empty_kv_cache && global_empty_kv_cache; } // update dp_global_token_nums and global_empty_kv_cache for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - for (auto i = 0; i < micro_batches_num; ++i) { - batched_inputs[dp_rank][i].dp_global_token_nums = dp_global_token_nums[i]; - batched_inputs[dp_rank][i].global_empty_kv_cache = global_empty_kv_cache; - } + 
batched_inputs[dp_rank].dp_global_token_nums = dp_global_token_nums; + batched_inputs[dp_rank].global_empty_kv_cache = global_empty_kv_cache; } return batched_inputs; diff --git a/xllm/core/runtime/vlm_engine.h b/xllm/core/runtime/vlm_engine.h old mode 100755 new mode 100644 index de0dabd85..910a9b6e8 --- a/xllm/core/runtime/vlm_engine.h +++ b/xllm/core/runtime/vlm_engine.h @@ -57,8 +57,7 @@ class VLMEngine : public Engine { bool init_model(); Engine::KVCacheCapacity estimate_kv_cache_capacity(); bool allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap); - std::vector> prepare_inputs( - std::vector& batch); + std::vector prepare_inputs(std::vector& batch); void setup_workers(const runtime::Options& options); void process_group_test(); diff --git a/xllm/core/runtime/vlm_worker_impl.cpp b/xllm/core/runtime/vlm_worker_impl.cpp index 03e40c0e4..7400fc41e 100644 --- a/xllm/core/runtime/vlm_worker_impl.cpp +++ b/xllm/core/runtime/vlm_worker_impl.cpp @@ -54,17 +54,13 @@ bool VLMWorkerImpl::init_model(ModelContext& context) { return true; } -std::optional VLMWorkerImpl::step( - const BatchedForwardInputs& inputs) { +std::optional VLMWorkerImpl::step(const ForwardInput& input) { Timer timer; // TODO guojinrong, to adapt multi stream parallel later // call model executor forward to get hidden states - auto hidden_states = - model_executor_->forward({inputs.micro_inputs[0].token_ids}, - {inputs.micro_inputs[0].positions}, - kv_caches_, - {inputs.micro_inputs[0].input_params}); - auto& sampling_params = inputs.micro_inputs[0].sampling_params; + auto hidden_states = model_executor_->forward( + input.token_ids, input.positions, kv_caches_, input.input_params); + auto& sampling_params = input.sampling_params; torch::Tensor logits; if (sampling_params.selected_token_idxes.defined()) { logits = diff --git a/xllm/core/runtime/vlm_worker_impl.h b/xllm/core/runtime/vlm_worker_impl.h index 06663f1d0..4aaacfef2 100644 --- a/xllm/core/runtime/vlm_worker_impl.h +++ b/xllm/core/runtime/vlm_worker_impl.h @@ -41,8 +41,7 @@ class VLMWorkerImpl : public WorkerImpl { // initialize model, cache manager. blocking call bool init_model(ModelContext& context) override; - std::optional step( - const BatchedForwardInputs& inputs) override; + std::optional step(const ForwardInput& input) override; }; } // namespace xllm diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index 271df4a0a..f1bd4fd29 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -106,10 +106,7 @@ ForwardInput Worker::prepare_inputs(Batch& batch) { } std::optional Worker::step(const ForwardInput& inputs) { - // TODO to adapt multi stream parallel later - BatchedForwardInputs batched_inputs; - batched_inputs.micro_inputs = {std::move(inputs)}; - return impl_->step(batched_inputs); + return impl_->step(inputs); } const bool Worker::is_driver() { return impl_->is_driver(); } @@ -120,7 +117,7 @@ Worker::estimate_kv_cache_capacity_async() { } folly::SemiFuture> Worker::step_async( - const BatchedForwardInputs& inputs) { + const ForwardInput& inputs) { return impl_->step_async(inputs); } diff --git a/xllm/core/runtime/worker.h b/xllm/core/runtime/worker.h index eed1c69ca..362075d8c 100644 --- a/xllm/core/runtime/worker.h +++ b/xllm/core/runtime/worker.h @@ -118,7 +118,7 @@ class Worker { // Run the model on the given input. 
async call // the future returns a successfull status with no meaningful value folly::SemiFuture> step_async( - const BatchedForwardInputs& inputs); + const ForwardInput& inputs); folly::SemiFuture process_group_test_async(); diff --git a/xllm/core/runtime/worker_client.cpp b/xllm/core/runtime/worker_client.cpp index fe6bb24f7..d9fab3913 100644 --- a/xllm/core/runtime/worker_client.cpp +++ b/xllm/core/runtime/worker_client.cpp @@ -104,15 +104,12 @@ WorkerClient::estimate_kv_cache_capacity_async() { } folly::SemiFuture> WorkerClient::step_async( - const ForwardInput& inputs) { - // TODO to adapt multi stream parallel later - BatchedForwardInputs batched_fwd_inputs; - batched_fwd_inputs.micro_inputs = {std::move(inputs)}; - return worker_->step_async(batched_fwd_inputs); + const ForwardInput& input) { + return worker_->step_async(input); } folly::SemiFuture> WorkerClient::step_async( - const std::vector& inputs) { + const RawForwardInput& inputs) { LOG(ERROR) << "Worker Method step_async with RawForwardInput param is " "UnImplemented."; } diff --git a/xllm/core/runtime/worker_client.h b/xllm/core/runtime/worker_client.h index d6294ca37..66a3ee5ea 100644 --- a/xllm/core/runtime/worker_client.h +++ b/xllm/core/runtime/worker_client.h @@ -126,7 +126,7 @@ class WorkerClient { // for multi-node serving, we pass an non-tensor params to remote workers. virtual folly::SemiFuture> step_async( - const std::vector& inputs); + const RawForwardInput& inputs); virtual folly::SemiFuture process_group_test_async(); diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 4f1c32224..e63fc04e2 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -418,114 +418,102 @@ ForwardInput WorkerImpl::update_input_by_last_step_output( return inputs; } -void WorkerImpl::prepare_work_before_execute( - const BatchedForwardInputs& inputs, - BatchedForwardInputs& processed_inputs) { +void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, + ForwardInput& processed_input) { c10::StreamGuard streamGuard = prepare_stream_->set_stream_guard(); - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - ForwardInput fwd_inputs_on_device; - fwd_inputs_on_device = inputs.micro_inputs[i].to(device_, dtype_); - auto& input_params = fwd_inputs_on_device.input_params; + processed_input = input.to(device_, dtype_); + auto& input_params = processed_input.input_params; #if defined(USE_NPU) - if (input_params.swap_blocks.size() > 0 && - !FLAGS_enable_block_copy_kernel) { - auto& swap_blocks = input_params.swap_blocks; - - // collect src and dst indices - std::vector src_indices, dst_indices; - src_indices.reserve(swap_blocks.size()); - dst_indices.reserve(swap_blocks.size()); - - for (const auto& block : swap_blocks) { - src_indices.push_back(block.src_block_id); - dst_indices.push_back(block.dst_block_id); - } + if (input_params.swap_blocks.size() > 0 && !FLAGS_enable_block_copy_kernel) { + auto& swap_blocks = input_params.swap_blocks; - // batch select keys and values - auto src_tensor = torch::tensor( - src_indices, torch::dtype(torch::kLong).device(device_)); - auto dst_tensor = torch::tensor( - dst_indices, torch::dtype(torch::kLong).device(device_)); - const int64_t num_layers = context_.get_model_args().n_layers(); - for (int layer_id = 0; layer_id < num_layers; layer_id++) { - kv_caches_[layer_id].swap_blocks(src_tensor, dst_tensor); - } + // collect src and dst indices + std::vector src_indices, dst_indices; + 
src_indices.reserve(swap_blocks.size()); + dst_indices.reserve(swap_blocks.size()); + + for (const auto& block : swap_blocks) { + src_indices.push_back(block.src_block_id); + dst_indices.push_back(block.dst_block_id); } - if (!context_.get_parallel_args().mapping_data().empty()) { - torch::Tensor token_size_per_dp_group = - torch::tensor(fwd_inputs_on_device.input_params.dp_global_token_nums, - torch::TensorOptions() - .device(torch::kCPU) - .dtype(torch::kInt32) - .pinned_memory(true)); - bool is_prefill = fwd_inputs_on_device.input_params.global_empty_kv_cache - ? true - : false; - DpEpPadding dp_ep_padding(token_size_per_dp_group, - context_.get_model_args().num_experts_per_tok(), - context_.get_parallel_args().mapping_data(), - device_, - dtype_, - is_prefill); - fwd_inputs_on_device.input_params.dp_ep_padding_data = - dp_ep_padding.build(); - if (FLAGS_enable_eplb) { - // expert_load_data_.fill_(0); - fwd_inputs_on_device.input_params.expert_load_data = expert_load_data_; - } + + // batch select keys and values + auto src_tensor = + torch::tensor(src_indices, torch::dtype(torch::kLong).device(device_)); + auto dst_tensor = + torch::tensor(dst_indices, torch::dtype(torch::kLong).device(device_)); + const int64_t num_layers = context_.get_model_args().n_layers(); + for (int layer_id = 0; layer_id < num_layers; layer_id++) { + kv_caches_[layer_id].swap_blocks(src_tensor, dst_tensor); + } + } + + if (!context_.get_parallel_args().mapping_data().empty()) { + torch::Tensor token_size_per_dp_group = + torch::tensor(processed_input.input_params.dp_global_token_nums, + torch::TensorOptions() + .device(torch::kCPU) + .dtype(torch::kInt32) + .pinned_memory(true)); + bool is_prefill = + processed_input.input_params.global_empty_kv_cache ? true : false; + DpEpPadding dp_ep_padding(token_size_per_dp_group, + context_.get_model_args().num_experts_per_tok(), + context_.get_parallel_args().mapping_data(), + device_, + dtype_, + is_prefill); + processed_input.input_params.dp_ep_padding_data = dp_ep_padding.build(); + if (FLAGS_enable_eplb) { + // expert_load_data_.fill_(0); + processed_input.input_params.expert_load_data = expert_load_data_; } -#endif - processed_inputs.micro_inputs.push_back(std::move(fwd_inputs_on_device)); } - processed_inputs.concated_sampling_params = - inputs.concated_sampling_params.to(device_, dtype_); - if (inputs.acc_logprob.defined()) { - processed_inputs.acc_logprob = - inputs.acc_logprob.to(torch::kFloat32).to(device_); +#endif + + processed_input.sampling_params = input.sampling_params.to(device_, dtype_); + if (input.acc_logprob.defined()) { + processed_input.acc_logprob = + input.acc_logprob.to(torch::kFloat32).to(device_); } auto ret = prepare_stream_->synchronize(); } folly::SemiFuture> WorkerImpl::step_async( - const BatchedForwardInputs& inputs) { - BatchedForwardInputs batched_inputs_on_device; - batched_inputs_on_device.micro_inputs.reserve(inputs.micro_inputs.size()); + const ForwardInput& input) { + ForwardInput input_on_device; - prepare_work_before_execute(inputs, batched_inputs_on_device); + prepare_work_before_execute(input, input_on_device); folly::Promise> promise; auto future = promise.getSemiFuture(); threadpool_.schedule([this, - inputs = std::move(batched_inputs_on_device), + input = std::move(input_on_device), promise = std::move(promise)]() mutable { #if defined(USE_NPU) - for (auto& input : inputs.micro_inputs) { - { - std::lock_guard lock(mutex_); - if (layer_wise_load_synchronizer_.count(input.input_params.batch_id) != - 0) { - 
input.input_params.layer_wise_load_synchronizer = std::move( - layer_wise_load_synchronizer_[input.input_params.batch_id]); - layer_wise_load_synchronizer_.erase(input.input_params.batch_id); - } + { + std::lock_guard lock(mutex_); + if (layer_wise_load_synchronizer_.count(input.input_params.batch_id) != + 0) { + input.input_params.layer_wise_load_synchronizer = std::move( + layer_wise_load_synchronizer_[input.input_params.batch_id]); + layer_wise_load_synchronizer_.erase(input.input_params.batch_id); } } + #endif // run the model on the given input in working thread if (!enable_schedule_overlap()) { - const auto output = this->step(inputs); + const auto output = this->step(input); promise.setValue(output); } else { - for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { - if (last_step_output_valid_ && - !inputs.micro_inputs[i].input_params.empty_kv_cache) { - // replace step i model input with true output of step i-1 - inputs.micro_inputs[i] = - update_input_by_last_step_output(inputs.micro_inputs[i]); - } + if (last_step_output_valid_ && !input.input_params.empty_kv_cache) { + // replace step i model input with true output of step i-1 + input = update_input_by_last_step_output(input); } - const auto output = this->step(inputs); + + const auto output = this->step(input); if (output.has_value()) { if (is_driver() || FLAGS_enable_eplb) { std::unique_lock lock(mtx_); diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h index 888377c54..6c7ca5a74 100644 --- a/xllm/core/runtime/worker_impl.h +++ b/xllm/core/runtime/worker_impl.h @@ -112,12 +112,10 @@ class WorkerImpl { virtual ForwardInput prepare_inputs(Batch& batch); // prepare work before model execution - virtual void prepare_work_before_execute( - const BatchedForwardInputs& inputs, - BatchedForwardInputs& processed_inputs); + virtual void prepare_work_before_execute(const ForwardInput& inputs, + ForwardInput& processed_inputs); - virtual std::optional step( - const BatchedForwardInputs& inputs) = 0; + virtual std::optional step(const ForwardInput& inputs) = 0; virtual void process_group_test(); @@ -159,7 +157,7 @@ class WorkerImpl { // Run the model on the given input. 
async call
   // the future returns a successful status with no meaningful value
   virtual folly::SemiFuture> step_async(
-      const BatchedForwardInputs& inputs);
+      const ForwardInput& inputs);
 
   virtual folly::SemiFuture process_group_test_async();
 
diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h
index 989945667..981a6e43d 100644
--- a/xllm/models/llm/llm_model_base.h
+++ b/xllm/models/llm/llm_model_base.h
@@ -295,13 +295,13 @@ class LlmModelImplBase : public torch::nn::Module {
 
     auto attn_metadata =
         layer::AttentionMetadata::build(modified_input_params, is_prefill);
-    torch::Tensor h;
+    torch::Tensor h_ret;
     for (size_t i = 0; i < layers_.size(); i++) {
       auto& layer = layers_[i];
-      h = layer(
+      h_ret = layer(
           h, position, attn_metadata, kv_caches[i], modified_input_params);
     }
-    return norm_(h);
+    return norm_(h_ret);
 #endif
   }
 
diff --git a/xllm/models/llm/mlu/deepseek_v2.h b/xllm/models/llm/mlu/deepseek_v2.h
index 34bccdb86..733d3e312 100644
--- a/xllm/models/llm/mlu/deepseek_v2.h
+++ b/xllm/models/llm/mlu/deepseek_v2.h
@@ -124,16 +124,11 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
   }
 
   // Provide batched signature to satisfy callers that pass vectors
-  torch::Tensor forward(const std::vector& tokens,
-                        const std::vector& positions,
+  torch::Tensor forward(const torch::Tensor& tokens,
+                        const torch::Tensor& positions,
                         std::vector& kv_caches,
-                        const std::vector& input_params) {
-    if (!(tokens.size() == 1 && positions.size() == 1 &&
-          input_params.size() == 1)) {
-      LOG(FATAL)
-          << "DeepseekV2ModelImpl only supports micro-batch size == 1 for now";
-    }
-    return forward_native(tokens[0], positions[0], kv_caches, input_params[0]);
+                        const ModelInputParams& input_params) {
+    return forward_native(tokens, positions, kv_caches, input_params);
   }
 
   // load the weight from the checkpoint
@@ -148,12 +143,10 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
     norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
   }
 
-  std::vector get_word_embedding() {
-    return {embed_tokens_};
-  }
+  layer::WordEmbedding get_word_embedding() { return embed_tokens_; }
 
-  void set_word_embedding(std::vector& word_embedding) {
-    embed_tokens_ = word_embedding[0];
+  void set_word_embedding(layer::WordEmbedding& word_embedding) {
+    embed_tokens_ = word_embedding;
   }
 
  private:
diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h
index 277167ddc..f6e5aa65b 100644
--- a/xllm/models/llm/qwen3.h
+++ b/xllm/models/llm/qwen3.h
@@ -207,13 +207,13 @@ class QWen3ModelImpl : public LlmModelImplBase {
 
     auto attn_metadata =
         layer::AttentionMetadata::build(modified_input_params, is_prefill);
-    torch::Tensor h;
+    torch::Tensor h_ret;
    for (size_t i = 0; i < layers_.size(); i++) {
       auto& layer = layers_[i];
-      h = layer(
+      h_ret = layer(
           h, positions, attn_metadata, kv_caches[i], modified_input_params);
     }
-    return norm_(h);
+    return norm_(h_ret);
 #endif
   }
 
diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto
index b635d3227..0c6075fdc 100644
--- a/xllm/proto/worker.proto
+++ b/xllm/proto/worker.proto
@@ -254,7 +254,7 @@ service DistributeWorker {
   rpc GetCacheInfo(Empty) returns (CacheInfo) {}
   rpc LinkCluster(ClusterInfo) returns (Status) {}
   rpc UnlinkCluster(ClusterInfo) returns (Status) {}
-  rpc ExecuteModel (BatchedForwardInputs) returns (ForwardOutput);
+  rpc ExecuteModel (ForwardInput) returns (ForwardOutput);
   rpc GetLastStepResult (Empty) returns (ForwardOutput);
   rpc GetActiveActivationMemory (Empty) returns (ActivationMemory);
   rpc TransferBlocks(BlockTransferInfos) returns 
(TransferStatus) {}