From 3eb71100aecc08af47bdacedbe9faa74798f7a93 Mon Sep 17 00:00:00 2001 From: liangzhiwei20 Date: Mon, 17 Nov 2025 21:49:27 +0800 Subject: [PATCH] feat: remove redundant input parameters by add batch forward type. --- .../framework/batch/batch_input_builder.cpp | 78 +++++++++++++++-- .../framework/batch/batch_input_builder.h | 4 +- xllm/core/framework/batch/batch_test.cpp | 2 +- xllm/core/framework/model/CMakeLists.txt | 1 + .../core/framework/model/batch_forward_type.h | 83 +++++++++++++++++++ .../core/framework/model/model_input_params.h | 28 ++----- .../npu_deepseek_v2_decoder_layer_impl.cpp | 5 +- .../layers/npu/npu_glm4_moe_decoder_layer.cpp | 3 +- .../npu/npu_llama_decoder_layer_impl.cpp | 3 +- .../npu/npu_qwen2_decoder_layer_impl.cpp | 3 +- .../npu/npu_qwen3_decoder_layer_impl.cpp | 6 +- .../npu/npu_qwen3_moe_decoder_layer_impl.cpp | 2 +- xllm/core/runtime/acl_graph_executor_impl.cpp | 9 +- xllm/core/runtime/forward_params.h | 3 +- .../runtime/forward_shared_memory_manager.cpp | 21 ++--- xllm/core/runtime/llm_engine.cpp | 17 ++-- xllm/core/runtime/llm_worker_impl.cpp | 20 +---- xllm/core/runtime/params_utils.cpp | 16 +--- xllm/core/runtime/speculative_worker_impl.cpp | 2 +- xllm/core/runtime/vlm_engine.cpp | 6 +- xllm/core/runtime/worker_impl.cpp | 8 +- xllm/models/llm/deepseek_v2.h | 2 +- xllm/models/llm/glm4_moe.h | 3 +- xllm/models/llm/glm4_moe_mtp.h | 3 +- xllm/models/llm/qwen3_moe.h | 3 +- xllm/proto/worker.proto | 4 +- 26 files changed, 214 insertions(+), 121 deletions(-) create mode 100644 xllm/core/framework/model/batch_forward_type.h diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 754321644..07e751ebd 100755 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -71,6 +71,7 @@ ForwardInput BatchInputBuilder::build_forward_input( uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size) { process_sequences(0, static_cast(num_sequences_)); + process_batch_forward_type(); padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size); return state_to_forward_input(); @@ -84,6 +85,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx, } else { process_sequences_multithreaded(start_idx, end_idx); } + process_batch_forward_type(); return state_to_raw_forward_input(); } @@ -189,7 +191,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, state_.unique_token_lens_vec.insert(state_.unique_token_lens_vec.end(), state.unique_token_lens_vec.begin(), state.unique_token_lens_vec.end()); - state_.empty_kv_cache = state_.empty_kv_cache && state.empty_kv_cache; state_.max_seq_len = std::max(state_.max_seq_len, state.max_seq_len); state_.q_max_seq_len = std::max(state_.q_max_seq_len, state.q_max_seq_len); #if defined(USE_NPU) @@ -278,7 +279,6 @@ void BatchInputBuilder::process_single_sequence( << allowed_max_tokens_[seq_index]; // Update state - state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0); state.max_seq_len = std::max(state.max_seq_len, seq_len); state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len); #if defined(USE_NPU) @@ -498,7 +498,7 @@ void BatchInputBuilder::padding_decode_batch_size( if (num_sequences_ < min_decoding_batch_size) { const uint32_t n_tokens = state_.flatten_tokens_vec.size(); // kv_cache is not empty in decoding phase - const bool in_decoding_phase = !state_.empty_kv_cache; + const bool in_decoding_phase = 
!state_.batch_forward_type.is_prefill(); const bool same_num_decoding_tokens = state_.q_max_seq_len == num_decoding_tokens && n_tokens == num_sequences_ * num_decoding_tokens; @@ -551,7 +551,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { } auto& input_params = forward_input.input_params; - input_params.empty_kv_cache = state_.empty_kv_cache; + input_params.batch_forward_type = state_.batch_forward_type; input_params.num_sequences = state_.block_tables_vec.size(); input_params.kv_max_seq_len = state_.max_seq_len; input_params.q_max_seq_len = state_.q_max_seq_len; @@ -561,8 +561,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { input_params.q_seq_lens_vec = std::move(state_.q_seq_lens); input_params.new_cache_slots = torch::tensor(state_.new_token_slot_ids, torch::kInt); - input_params.decode_seq_range = - util::find_ones_indices(input_params.q_seq_lens_vec); // for flashinfer input_params.paged_kv_indptr = @@ -644,8 +642,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { std::move(state_.unique_token_counts_vec); raw_forward_input.unique_token_lens_vec = std::move(state_.unique_token_lens_vec); - raw_forward_input.empty_kv_cache = state_.empty_kv_cache; - // raw_forward_input.global_empty_kv_cache = ; + raw_forward_input.batch_forward_type = state_.batch_forward_type; raw_forward_input.max_seq_len = state_.max_seq_len; raw_forward_input.q_max_seq_len = state_.q_max_seq_len; raw_forward_input.seq_lens = std::move(state_.seq_lens); @@ -727,4 +724,69 @@ void BatchInputBuilder::process_swap_block_infos( swap_block_transfer_infos_->end()); } } + +void BatchInputBuilder::process_batch_forward_type() { + CHECK_EQ(state_.seq_lens.size(), state_.q_seq_lens.size()) + << "seq_lens size must be equal to q_seq_lens size"; + + if (state_.q_max_seq_len == 1) { + state_.batch_forward_type = BatchForwardType::DECODE; + return; + } + + bool empty_kv_cache = true; + bool all_decode = true; + bool all_prefill = true; + +#if defined(USE_NPU) + if (state_.seq_lens.size() == 0) { + state_.batch_forward_type = BatchForwardType::EMPTY; + return; + } + for (size_t i = 0; i < state_.seq_lens.size(); ++i) { + auto q_len = state_.q_seq_lens[i]; + auto kv_len = state_.seq_lens[i]; + auto cache_len = kv_len - q_len; + if (cache_len > 0) { + empty_kv_cache = false; + } + if (q_len > 1) { + all_decode = false; + } + if (q_len == 1) { + all_prefill = false; + } + } +#elif defined(USE_MLU) + if (state_.seq_lens.size() == 1) { + state_.batch_forward_type = BatchForwardType::EMPTY; + return; + } + for (size_t i = 1; i < state_.seq_lens.size(); ++i) { + auto q_len = state_.q_seq_lens[i] - state_.q_seq_lens[i - 1]; + auto kv_len = state_.seq_lens[i] - state_.seq_lens[i - 1]; + auto cache_len = kv_len - q_len; + if (cache_len > 0) { + empty_kv_cache = false; + } + if (q_len > 1) { + all_decode = false; + } + if (q_len == 1) { + all_prefill = false; + } + } +#endif + if (empty_kv_cache) { + state_.batch_forward_type = BatchForwardType::PREFILL; + } else { + if (all_prefill) { + state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL; + } else if (all_decode) { + state_.batch_forward_type = BatchForwardType::DECODE; + } else { + state_.batch_forward_type = BatchForwardType::MIXED; + } + } +} } // namespace xllm diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index 0c2a4cee8..df2f62e41 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -59,6 +59,8 
@@ class BatchInputBuilder { void process_swap_block_infos(RawForwardInput& raw_forward_input); + void process_batch_forward_type(); + // State management struct BuilderState { // Token and position data @@ -77,7 +79,7 @@ class BatchInputBuilder { std::vector unique_token_lens_vec; // Sequence metadata - bool empty_kv_cache = true; + BatchForwardType batch_forward_type; uint32_t max_seq_len = 0; uint32_t q_max_seq_len = 0; #if defined(USE_NPU) diff --git a/xllm/core/framework/batch/batch_test.cpp b/xllm/core/framework/batch/batch_test.cpp index 2645fe564..ea319fcf9 100644 --- a/xllm/core/framework/batch/batch_test.cpp +++ b/xllm/core/framework/batch/batch_test.cpp @@ -145,7 +145,7 @@ TEST(BatchTest, Basic) { // check the input parameters const ModelInputParams& input_params = forward_input.input_params; - EXPECT_FALSE(input_params.empty_kv_cache); + EXPECT_TRUE(input_params.batch_forward_type.is_mixed()); EXPECT_EQ(input_params.num_sequences, 4); EXPECT_EQ(input_params.q_max_seq_len, 9); EXPECT_EQ(input_params.kv_max_seq_len, 16); diff --git a/xllm/core/framework/model/CMakeLists.txt b/xllm/core/framework/model/CMakeLists.txt index 2dcf74114..4f3fb4495 100644 --- a/xllm/core/framework/model/CMakeLists.txt +++ b/xllm/core/framework/model/CMakeLists.txt @@ -34,6 +34,7 @@ cc_library( embedding_lm.h model_args.h npu_dp_ep_padding.h + batch_forward_type.h model_input_params.h SRCS npu_dp_ep_padding.cpp diff --git a/xllm/core/framework/model/batch_forward_type.h b/xllm/core/framework/model/batch_forward_type.h new file mode 100644 index 000000000..a852334f3 --- /dev/null +++ b/xllm/core/framework/model/batch_forward_type.h @@ -0,0 +1,83 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +namespace xllm { + +class BatchForwardType { + public: + enum Value : int32_t { + // Prefill without using kv cache. + PREFILL = 0, + // Chunked prefill using kv cache. + // No decode sequence in this type. + CHUNKED_PREFILL = 1, + // Decode one token. + // No prefill sequence in this type. + DECODE = 2, + // Mixed prefill and decode in one batch when doing chunked prefill. + MIXED = 3, + // No sequence to forward. 
+ EMPTY = 4, + }; + + BatchForwardType() : value_(EMPTY) {} + + BatchForwardType(int32_t v) : value_(static_cast(v)) {} + + constexpr BatchForwardType(Value v) : value_(v) {} + + BatchForwardType& operator=(Value v) { + value_ = v; + return *this; + } + + int32_t value() const { return value_; } + + bool is_prefill() const { return (value_ == PREFILL); } + + bool is_chunked_prefill() const { return (value_ == CHUNKED_PREFILL); } + + bool has_decode() const { return (value_ == DECODE || value_ == MIXED); } + + bool is_decode() const { return (value_ == DECODE); } + + bool is_mixed() const { return (value_ == MIXED); } + + bool is_empty() const { return (value_ == EMPTY); } + + const char* to_string() const { + switch (value_) { + case PREFILL: + return "PREFILL"; + case CHUNKED_PREFILL: + return "CHUNKED_PREFILL"; + case DECODE: + return "DECODE"; + case MIXED: + return "MIXED"; + case EMPTY: + return "EMPTY"; + default: + return "UNKNOWN"; + } + } + + private: + Value value_; +}; +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index fa69dd800..82cec82e2 100755 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -21,6 +21,7 @@ limitations under the License. #if defined(USE_NPU) #include "platform/npu/npu_layer_synchronizer.h" #endif +#include "framework/model/batch_forward_type.h" #include "framework/request/mm_data.h" #include "npu_dp_ep_padding.h" #include "util/tensor_helper.h" @@ -86,8 +87,7 @@ struct BlockTransferInfo { struct ModelInputParams { ModelInputParams to(const torch::Device& device) const { ModelInputParams params; - params.empty_kv_cache = empty_kv_cache; - params.global_empty_kv_cache = global_empty_kv_cache; + params.batch_forward_type = batch_forward_type; params.num_sequences = num_sequences; params.kv_max_seq_len = kv_max_seq_len; params.q_max_seq_len = q_max_seq_len; @@ -99,7 +99,6 @@ struct ModelInputParams { params.block_tables = safe_to(block_tables, device, true); params.kv_seq_lens_vec = kv_seq_lens_vec; params.q_seq_lens_vec = q_seq_lens_vec; - params.decode_seq_range = decode_seq_range; params.input_embedding = safe_to(input_embedding, device); @@ -141,15 +140,13 @@ struct ModelInputParams { } void print() const { - LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache - << " , global_empty_kv_cache is " << global_empty_kv_cache - << " , num_sequences is " << num_sequences - << " , kv_max_seq_len is " << kv_max_seq_len + LOG(INFO) << "ModelInputParams: batch_forward_type is " + << batch_forward_type.to_string() << " , num_sequences is " + << num_sequences << " , kv_max_seq_len is " << kv_max_seq_len << " , q_max_seq_len is " << q_max_seq_len << " , prefill_seq_len is " << prefill_seq_len; LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec; LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec; - LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range; print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4); print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4); print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4); @@ -157,8 +154,8 @@ struct ModelInputParams { LOG(INFO) << "ModelInputParams: dp_global_token_nums is " << dp_global_token_nums; } - // whether the kv-cache is empty for all sequences. - bool empty_kv_cache = true; + // forward type of the batch, used by worker/kernel. 
+ BatchForwardType batch_forward_type; // total number of sequences in the batch int32_t num_sequences = 0; @@ -167,15 +164,6 @@ struct ModelInputParams { torch::Tensor kv_seq_lens; std::vector kv_seq_lens_vec; std::vector q_seq_lens_vec; - // Range of decode sequence indices in the batch [start, end]. - // Decode sequences are identified by q_seq_lens == 1, - // prefill sequences by q_seq_lens > 1 . - // Used to determine whether to use prefill_node_ or - // decode_node_ in NPU layers - // Values: {-1, -1} if no decode requests (all prefill), - // {0, batch_size-1} if all decode requests, - // {start_idx, end_idx} if mixed prefill/decode requests - std::pair decode_seq_range; // max length for qkv. int32_t kv_max_seq_len = 0; int32_t q_max_seq_len = 0; @@ -199,8 +187,6 @@ struct ModelInputParams { // num tokens of all workers,mainly used for dp case std::vector dp_global_token_nums; - // whether the kv-cache is empty for all sequences,mainly used for dp case - bool global_empty_kv_cache = true; // num of prefill sequence in chunked prefill case uint32_t prefill_seq_len = 0; diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index f8417d6e8..2e312cf36 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -1523,9 +1523,8 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( std::vector*> event_flag, int node_id) { atb::Status st; - // all micro batches are in same prefill/decode stage, - // so, to judge empty_kv_cache, use input_params[0] here - if (input_params[0].global_empty_kv_cache) { + // deepseek dont support chunked prefill, so only check is_prefill. + if (input_params[0].batch_forward_type.is_prefill()) { build_node_variant_pack(prefill_node_, x, cos_pos, diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp index 927b98063..20429a83a 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp @@ -1085,8 +1085,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward( std::vector*> event_flag, int node_id) { atb::Status st; - if (input_params.decode_seq_range.second != - input_params.q_seq_lens.size(0) - 1) { + if (!input_params.batch_forward_type.is_decode()) { build_node_variant_pack(prefill_node_, x, cos_pos, diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp index 9696353ce..eb7794c2c 100644 --- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp @@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x, int node_id) { atb::Status st; - if (input_params.decode_seq_range.second != - input_params.q_seq_lens.size(0) - 1) { + if (!input_params.batch_forward_type.is_decode()) { build_node_variant_pack(prefill_node_, x, cos_pos, diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp index 0f788904e..c147fc5ca 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp @@ -405,8 +405,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward( std::vector*> event_flag, int node_id) { atb::Status st; - if (input_params[0].decode_seq_range.second != - input_params[0].q_seq_lens.size(0) - 1) { + if 
(!input_params[0].batch_forward_type.is_decode()) { // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); build_node_variant_pack(prefill_node_, x[0], diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index d48379fb4..a69f76647 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -526,10 +526,7 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward( std::vector*> event_flag, int node_id) { atb::Status st; - if (input_params[0].decode_seq_range.second != - input_params[0].q_seq_lens.size(0) - 1) { - // if (input_params.empty_kv_cache) { - // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); + if (!input_params[0].batch_forward_type.is_decode()) { build_node_variant_pack(prefill_node_, x, cos_pos, @@ -538,7 +535,6 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward( kv_cache, input_params, true); - // mstxRangeEnd(id); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "excute prefill layer fail, error code: " << st; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index f7d74347d..6913a6ddf 100755 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -865,7 +865,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( std::atomic* event_flag, int node_id) { atb::Status st; - if (input_params.global_empty_kv_cache) { + if (input_params.batch_forward_type.is_prefill()) { build_node_variant_pack(prefill_node_, x, cos_pos, diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 4a603945a..23988fdae 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -196,16 +196,9 @@ torch::Tensor AclGraphExecutorImpl::run( const torch::Tensor& tokens_tensor = tokens[0]; const torch::Tensor& positions_tensor = positions[0]; const ModelInputParams& params_single = params[0]; - // Identify decode phase using q_max_seq_len for precise detection - // Decode phase: all sequences have q_seq_len == 1 (generating one token at a - // time) Prefill phase: sequences have q_seq_len > 1 (processing multiple - // prompt tokens) We also check empty_kv_cache to ensure KV cache is not empty - // (not first forward pass) - const bool in_decoding_phase = - (params_single.q_max_seq_len == 1) && !params_single.empty_kv_cache; // If not in decode phase, use eager mode directly without acl graph - if (!in_decoding_phase) { + if (!params_single.batch_forward_type.is_decode()) { COUNTER_INC(num_model_execution_total_eager); return model_->forward(tokens, positions, kv_caches, params); } diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 4ddad7b97..54a29f7fa 100755 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -146,8 +146,7 @@ struct RawForwardInput { std::vector> unique_token_ids_vec; std::vector> unique_token_counts_vec; std::vector unique_token_lens_vec; - bool empty_kv_cache = true; - bool global_empty_kv_cache = true; + BatchForwardType batch_forward_type; uint32_t max_seq_len; uint32_t q_max_seq_len; std::vector seq_lens; diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 
5ddcc4ff0..769b58181 100755 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -149,7 +149,7 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) { total += type_size + input.swap_blocks.size() * swap_block_info_fixed_size(); - total += type_size * 2 // empty_kv_cache + global_empty_kv_cache + total += type_size // batch_forward_type + type_size * 3 // max_seq_len + q_max_seq_len + prefill_seq_len + type_size // num_sequences @@ -567,8 +567,9 @@ INLINE void deserialize_raw_forward_input( read_swap_blocks(buffer, input.swap_blocks); read_data(buffer, input.batch_id); - read_data(buffer, input.empty_kv_cache); - read_data(buffer, input.global_empty_kv_cache); + int32_t batch_forward_type; + read_data(buffer, batch_forward_type); + input.batch_forward_type = BatchForwardType(batch_forward_type); read_data(buffer, input.max_seq_len); read_data(buffer, input.q_max_seq_len); read_data(buffer, input.num_sequences); @@ -619,8 +620,7 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input, write_swap_blocks(buffer, input.swap_blocks); write_data(buffer, input.batch_id); - write_data(buffer, input.empty_kv_cache); - write_data(buffer, input.global_empty_kv_cache); + write_data(buffer, input.batch_forward_type.value()); write_data(buffer, input.max_seq_len); write_data(buffer, input.q_max_seq_len); write_data(buffer, input.num_sequences); @@ -814,15 +814,9 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, forward_input.positions = create_2d_tensor(std::move(raw_input.m_positions_vec), torch::kInt); } - std::pair decode_seq_range{0, 0}; -#if defined(USE_NPU) - if (raw_input.q_seq_lens.size() >= 1) { - decode_seq_range = util::find_ones_indices(raw_input.q_seq_lens); - } -#endif + auto& input_params = forward_input.input_params; - input_params.empty_kv_cache = raw_input.empty_kv_cache; - input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache; + input_params.batch_forward_type = raw_input.batch_forward_type; input_params.num_sequences = raw_input.num_sequences; input_params.kv_max_seq_len = raw_input.max_seq_len; input_params.q_max_seq_len = raw_input.q_max_seq_len; @@ -839,7 +833,6 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, input_params.new_cache_slots = torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options); - input_params.decode_seq_range = decode_seq_range; util::pad_2d_vector(raw_input.block_tables_vec, 0); input_params.block_tables = create_2d_tensor(std::move(raw_input.block_tables_vec), torch::kInt); diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index 31c54e1fc..694dc88a3 100755 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -885,7 +885,8 @@ std::vector> LLMEngine::prepare_inputs( std::vector> dp_global_token_nums; dp_global_token_nums.resize(micro_batches_num, std::vector(dp_size_)); - bool global_empty_kv_cache = true; + // All empty batches use the first non-empty batch's forward type. 
+ BatchForwardType batch_forward_type; // eplb related EplbInfo eplb_info; @@ -903,8 +904,12 @@ std::vector> LLMEngine::prepare_inputs( threadpool_.get()))); dp_global_token_nums[i][dp_rank] = batched_inputs[dp_rank][i].flatten_tokens_vec.size(); - global_empty_kv_cache = - batched_inputs[dp_rank][i].empty_kv_cache && global_empty_kv_cache; + if (batched_inputs[dp_rank][i].batch_forward_type.is_empty()) { + continue; + } + if (batch_forward_type.is_empty() || batch_forward_type.is_prefill()) { + batch_forward_type = batched_inputs[dp_rank][i].batch_forward_type; + } } } @@ -912,11 +917,13 @@ std::vector> LLMEngine::prepare_inputs( eplb_info = eplb_manager_->get_eplb_info(); } - // update dp_global_token_nums and global_empty_kv_cache + // update dp_global_token_nums and batch_forward_type for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { for (auto i = 0; i < micro_batches_num; ++i) { batched_inputs[dp_rank][i].dp_global_token_nums = dp_global_token_nums[i]; - batched_inputs[dp_rank][i].global_empty_kv_cache = global_empty_kv_cache; + if (batched_inputs[dp_rank][i].batch_forward_type.is_empty()) { + batched_inputs[dp_rank][i].batch_forward_type = batch_forward_type; + } if (FLAGS_enable_eplb) { batched_inputs[dp_rank][i].eplb_info = eplb_info; } diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 65318bba1..7d279c910 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -186,26 +186,10 @@ std::optional LLMWorkerImpl::step( } // if running in multi_stream_parallel step, all micro batches - // should be in same prefill stage, so, to judge empty_kv_cache, + // should be in same prefill stage, so, to judge batch_forward_type, // just use micro batch 0 here if (options_.enable_speculative_decode() && !is_spec_draft_) { - if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) { - output.sample_output.embeddings = hidden_states; - } else if (concated_sampling_params.sample_idxes.defined()) { - // auto sample_idxes = - // concated_sampling_params.selected_token_idxes.index_select( - // /*dim=*/0, concated_sampling_params.sample_idxes); - auto embeddings = hidden_states.index_select( - /*dim=*/0, concated_sampling_params.sample_idxes); - output.sample_output.embeddings = embeddings; - } - } - - // if running in multi_stream_parallel step, all micro batches - // should be in same prefill stage, so, to judge empty_kv_cache, - // just use micro batch 0 here - if (options_.enable_speculative_decode() && !is_spec_draft_) { - if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) { + if (input_params_micro_batches[0].batch_forward_type.is_decode()) { output.sample_output.embeddings = hidden_states; } else if (concated_sampling_params.sample_idxes.defined()) { // auto sample_idxes = diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 909d4e7ca..96731b906 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -177,16 +177,10 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, forward_inputs.acc_logprob = torch::tensor( acc_logprob_vec, torch::dtype(torch::kFloat32).device(torch::kCPU).pinned_memory(true)); - std::pair decode_seq_range{0, 0}; -#if defined(USE_NPU) - if (q_seq_lens.size() >= 1) { - decode_seq_range = util::find_ones_indices(q_seq_lens); - } -#endif + auto& input_params = forward_inputs.input_params; - input_params.empty_kv_cache = pb_forward_input->empty_kv_cache(); - 
input_params.global_empty_kv_cache = - pb_forward_input->global_empty_kv_cache(); + input_params.batch_forward_type = + BatchForwardType(pb_forward_input->batch_forward_type()); input_params.num_sequences = block_tables_vec.size(); assert(input_params.num_sequences == pb_forward_input->num_sequences()); input_params.prefill_seq_len = pb_forward_input->prefill_seq_len(); @@ -205,7 +199,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.new_cache_slots = torch::tensor(new_token_slot_ids, tensor_options); - input_params.decode_seq_range = decode_seq_range; util::pad_2d_vector(block_tables_vec, /*pad_value=*/0); input_params.block_tables = @@ -378,8 +371,7 @@ void forward_input_to_proto(const RawForwardInput& inputs, } ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_unique_token_lens_vec(), inputs.unique_token_lens_vec); - pb_forward_input->set_empty_kv_cache(inputs.empty_kv_cache); - pb_forward_input->set_global_empty_kv_cache(inputs.global_empty_kv_cache); + pb_forward_input->set_batch_forward_type(inputs.batch_forward_type.value()); pb_forward_input->set_max_seq_len(inputs.max_seq_len); pb_forward_input->set_q_max_seq_len(inputs.q_max_seq_len); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_seq_lens(), inputs.seq_lens); diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp index 8e2c5a06f..9770a7c9a 100644 --- a/xllm/core/runtime/speculative_worker_impl.cpp +++ b/xllm/core/runtime/speculative_worker_impl.cpp @@ -613,7 +613,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs( input_params.block_tables = create_2d_tensor(block_tables_vec, torch::kInt).to(device_); } - input_params.decode_seq_range.second = input_params.num_sequences - 1; + input_params.batch_forward_type = BatchForwardType::DECODE; // update the sampling_params update_sampling_params( diff --git a/xllm/core/runtime/vlm_engine.cpp b/xllm/core/runtime/vlm_engine.cpp index 8ec675098..39522ee62 100644 --- a/xllm/core/runtime/vlm_engine.cpp +++ b/xllm/core/runtime/vlm_engine.cpp @@ -445,7 +445,6 @@ std::vector> VLMEngine::prepare_inputs( std::vector> dp_global_token_nums; dp_global_token_nums.resize(micro_batches_num, std::vector(dp_size_)); - bool global_empty_kv_cache = true; // build model input for every single micro batch for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { @@ -460,16 +459,13 @@ std::vector> VLMEngine::prepare_inputs( threadpool_.get()))); dp_global_token_nums[i][dp_rank] = batched_inputs[dp_rank][i].flatten_tokens_vec.size(); - global_empty_kv_cache = - batched_inputs[dp_rank][i].empty_kv_cache && global_empty_kv_cache; } } - // update dp_global_token_nums and global_empty_kv_cache + // update dp_global_token_nums for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { for (auto i = 0; i < micro_batches_num; ++i) { batched_inputs[dp_rank][i].dp_global_token_nums = dp_global_token_nums[i]; - batched_inputs[dp_rank][i].global_empty_kv_cache = global_empty_kv_cache; } } diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index c859ede6f..f20277ccb 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -461,9 +461,8 @@ void WorkerImpl::prepare_work_before_execute( .device(torch::kCPU) .dtype(torch::kInt32) .pinned_memory(true)); - bool is_prefill = fwd_inputs_on_device.input_params.global_empty_kv_cache - ? 
true - : false; + bool is_prefill = + fwd_inputs_on_device.input_params.batch_forward_type.is_prefill(); DpEpPadding dp_ep_padding(token_size_per_dp_group, context_.get_model_args().num_experts_per_tok(), context_.get_parallel_args().mapping_data(), @@ -521,7 +520,8 @@ folly::SemiFuture> WorkerImpl::step_async( } else { for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { if (last_step_output_valid_ && - !inputs.micro_inputs[i].input_params.empty_kv_cache) { + inputs.micro_inputs[i] + .input_params.batch_forward_type.has_decode()) { // replace step i model input with true output of step i-1 inputs.micro_inputs[i] = update_input_by_last_step_output(inputs.micro_inputs[i]); diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h index 174dde774..4cf3ecac6 100644 --- a/xllm/models/llm/deepseek_v2.h +++ b/xllm/models/llm/deepseek_v2.h @@ -185,7 +185,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module { torch::Tensor attn_mask; if (num_speculative_tokens_ == 0 || - input_params[i].global_empty_kv_cache) { + input_params[i].batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else { attn_mask = attn_mask_.gen_free_mask( diff --git a/xllm/models/llm/glm4_moe.h b/xllm/models/llm/glm4_moe.h index 41a5cddc2..03ea890fb 100644 --- a/xllm/models/llm/glm4_moe.h +++ b/xllm/models/llm/glm4_moe.h @@ -162,7 +162,8 @@ class Glm4MoeModelImpl : public torch::nn::Module { attn_mask = torch::cat(req_mask_vec, 0); } } else { - if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { + if (num_speculative_tokens_ == 0 || + input_params.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else { attn_mask = attn_mask_.gen_free_mask( diff --git a/xllm/models/llm/glm4_moe_mtp.h b/xllm/models/llm/glm4_moe_mtp.h index 83ff17bdf..5d671832e 100644 --- a/xllm/models/llm/glm4_moe_mtp.h +++ b/xllm/models/llm/glm4_moe_mtp.h @@ -132,7 +132,8 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module { attn_mask = torch::cat(req_mask_vec, 0); } } else { - if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { + if (num_speculative_tokens_ == 0 || + input_params.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else { attn_mask = attn_mask_.gen_free_mask( diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 62de1214d..3ac5005d2 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -248,7 +248,8 @@ class Qwen3MoeModelImpl : public torch::nn::Module { } torch::Tensor attn_mask; - if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { + if (num_speculative_tokens_ == 0 || + input_params.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else { attn_mask = attn_mask_.gen_free_mask( diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index fa728e319..ef8bec1d1 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -168,7 +168,7 @@ message ForwardInput { repeated UniqueTokenIds unique_token_ids_vec = 6; repeated UniqueTokenCounts unique_token_counts_vec = 7; repeated int32 unique_token_lens_vec = 8; - bool empty_kv_cache = 9; + int32 batch_forward_type = 9; uint32 max_seq_len = 10; uint32 q_max_seq_len = 11; repeated int32 seq_lens = 12; @@ -181,7 +181,7 @@ message ForwardInput { repeated BlockTables block_tables_vec = 18; int32 num_sequences = 19; repeated int32 dp_global_token_nums = 20; - bool 
global_empty_kv_cache = 21; + // bool global_empty_kv_cache = 21; repeated TransferKVInfo transfer_kv_infos = 22; repeated Embeddings embeds = 23; uint32 prefill_seq_len = 24;
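
Note (illustrative, not part of the patch): the new process_batch_forward_type() in batch_input_builder.cpp collapses the old empty_kv_cache / decode_seq_range bookkeeping into a single per-batch label derived from the query and kv lengths. A minimal standalone sketch of the same decision rule, assuming the NPU layout where q_seq_lens/seq_lens hold per-sequence lengths (the MLU path applies the same rule to adjacent differences of cumulative lengths); the enum and classify() here are stand-ins, not the patch's types:

    #include <cstdint>
    #include <vector>

    enum class Type { PREFILL, CHUNKED_PREFILL, DECODE, MIXED, EMPTY };

    Type classify(const std::vector<int32_t>& q_seq_lens,
                  const std::vector<int32_t>& seq_lens) {
      if (seq_lens.empty()) return Type::EMPTY;
      bool empty_kv_cache = true;  // no sequence has cached tokens yet
      bool all_decode = true;      // every sequence queries exactly one token
      bool all_prefill = true;     // every sequence queries more than one token
      for (size_t i = 0; i < seq_lens.size(); ++i) {
        const int32_t q_len = q_seq_lens[i];
        const int32_t cache_len = seq_lens[i] - q_len;  // tokens already in kv cache
        if (cache_len > 0) empty_kv_cache = false;
        if (q_len > 1) all_decode = false;
        if (q_len == 1) all_prefill = false;
      }
      if (empty_kv_cache) return Type::PREFILL;
      if (all_prefill) return Type::CHUNKED_PREFILL;
      if (all_decode) return Type::DECODE;
      return Type::MIXED;
    }

For example, classify({4, 6}, {4, 6}) yields PREFILL, classify({1, 1}, {8, 5}) yields DECODE, classify({3, 5}, {7, 9}) yields CHUNKED_PREFILL, and classify({3, 1}, {7, 9}) yields MIXED.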
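
Note (illustrative, not part of the patch): consumers branch on the predicates rather than on the raw value. The distinction the call sites rely on is is_decode() (pure decode batch, used by the ACL graph executor and by the NPU layers to pick decode_node_) versus has_decode() (DECODE or MIXED, used by worker_impl when reusing the previous step's output). A small usage sketch; log_forward_type() is a hypothetical helper:

    #include <glog/logging.h>
    #include "framework/model/batch_forward_type.h"

    void log_forward_type(const xllm::BatchForwardType& t) {
      // has_decode() is true for DECODE and MIXED; is_decode() only for DECODE.
      if (t.has_decode() && !t.is_decode()) {
        LOG(INFO) << "mixed chunked-prefill/decode batch";
      }
      LOG(INFO) << "batch_forward_type is " << t.to_string();
    }

Calling log_forward_type(xllm::BatchForwardType::MIXED) would log both lines, while a pure decode batch only logs "batch_forward_type is DECODE".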
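
Note (illustrative, not part of the patch): in LLMEngine::prepare_inputs the shared global_empty_kv_cache flag is replaced by propagating one BatchForwardType across data-parallel ranks: EMPTY batches are skipped, a PREFILL result can still be overridden by a later non-prefill batch, and EMPTY ranks finally inherit the merged type so padded ranks take the same branch as the real ones. A hedged sketch of that merge as a hypothetical free function (merge_forward_type is not in the patch):

    #include <vector>
    #include "framework/model/batch_forward_type.h"

    // Mirrors the accumulation loop in LLMEngine::prepare_inputs.
    xllm::BatchForwardType merge_forward_type(
        const std::vector<xllm::BatchForwardType>& per_rank) {
      xllm::BatchForwardType merged;  // default-constructed as EMPTY
      for (const auto& t : per_rank) {
        if (t.is_empty()) {
          continue;  // padded/idle rank, assigned afterwards
        }
        if (merged.is_empty() || merged.is_prefill()) {
          merged = t;  // a later non-prefill batch overrides a PREFILL result
        }
      }
      return merged;  // EMPTY ranks are then given this merged value
    }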
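
Note (illustrative, not part of the patch): on every transport the type travels as a single int32 — proto field 9 (previously bool empty_kv_cache) and one slot in the shared-memory buffer (previously two bools) — and the receiver rebuilds the wrapper from that integer via the converting constructor. A minimal round-trip sketch under those assumptions; round_trip_example() is a hypothetical function:

    #include <cstdint>
    #include <glog/logging.h>
    #include "framework/model/batch_forward_type.h"

    void round_trip_example(const xllm::BatchForwardType& sent) {
      // Sender: only the integer value crosses the wire.
      const int32_t wire = sent.value();
      // Receiver: reconstruct the enum wrapper from the raw integer.
      const xllm::BatchForwardType received(wire);
      CHECK_EQ(received.value(), sent.value());
      LOG(INFO) << "received batch_forward_type " << received.to_string();
    }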