Skip to content

Commit 074d07f

Browse files
committed
feat: remove redundant input parameters by adding a batch forward type.
1 parent 9bbd770 commit 074d07f

26 files changed

+214
-121
lines changed

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ ForwardInput BatchInputBuilder::build_forward_input(
7171
uint32_t num_decoding_tokens,
7272
uint32_t min_decoding_batch_size) {
7373
process_sequences(0, static_cast<uint32_t>(num_sequences_));
74+
process_batch_forward_type();
7475
padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);
7576

7677
return state_to_forward_input();
@@ -84,6 +85,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
8485
} else {
8586
process_sequences_multithreaded(start_idx, end_idx);
8687
}
88+
process_batch_forward_type();
8789
return state_to_raw_forward_input();
8890
}
8991

@@ -189,7 +191,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
189191
state_.unique_token_lens_vec.insert(state_.unique_token_lens_vec.end(),
190192
state.unique_token_lens_vec.begin(),
191193
state.unique_token_lens_vec.end());
192-
state_.empty_kv_cache = state_.empty_kv_cache && state.empty_kv_cache;
193194
state_.max_seq_len = std::max(state_.max_seq_len, state.max_seq_len);
194195
state_.q_max_seq_len = std::max(state_.q_max_seq_len, state.q_max_seq_len);
195196
#if defined(USE_NPU)
@@ -278,7 +279,6 @@ void BatchInputBuilder::process_single_sequence(
278279
<< allowed_max_tokens_[seq_index];
279280

280281
// Update state
281-
state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
282282
state.max_seq_len = std::max(state.max_seq_len, seq_len);
283283
state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
284284
#if defined(USE_NPU)
@@ -498,7 +498,7 @@ void BatchInputBuilder::padding_decode_batch_size(
498498
if (num_sequences_ < min_decoding_batch_size) {
499499
const uint32_t n_tokens = state_.flatten_tokens_vec.size();
500500
// kv_cache is not empty in decoding phase
501-
const bool in_decoding_phase = !state_.empty_kv_cache;
501+
const bool in_decoding_phase = !state_.batch_forward_type.is_prefill();
502502
const bool same_num_decoding_tokens =
503503
state_.q_max_seq_len == num_decoding_tokens &&
504504
n_tokens == num_sequences_ * num_decoding_tokens;
@@ -551,7 +551,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
551551
}
552552

553553
auto& input_params = forward_input.input_params;
554-
input_params.empty_kv_cache = state_.empty_kv_cache;
554+
input_params.batch_forward_type = state_.batch_forward_type;
555555
input_params.num_sequences = state_.block_tables_vec.size();
556556
input_params.kv_max_seq_len = state_.max_seq_len;
557557
input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -561,8 +561,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
561561
input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
562562
input_params.new_cache_slots =
563563
torch::tensor(state_.new_token_slot_ids, torch::kInt);
564-
input_params.decode_seq_range =
565-
util::find_ones_indices(input_params.q_seq_lens_vec);
566564

567565
// for flashinfer
568566
input_params.paged_kv_indptr =
@@ -644,8 +642,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
644642
std::move(state_.unique_token_counts_vec);
645643
raw_forward_input.unique_token_lens_vec =
646644
std::move(state_.unique_token_lens_vec);
647-
raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
648-
// raw_forward_input.global_empty_kv_cache = ;
645+
raw_forward_input.batch_forward_type = state_.batch_forward_type;
649646
raw_forward_input.max_seq_len = state_.max_seq_len;
650647
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
651648
raw_forward_input.seq_lens = std::move(state_.seq_lens);
@@ -727,4 +724,69 @@ void BatchInputBuilder::process_swap_block_infos(
727724
swap_block_transfer_infos_->end());
728725
}
729726
}
727+
728+
// Classifies the whole batch (PREFILL / CHUNKED_PREFILL / DECODE / MIXED /
// EMPTY) from the per-sequence query and kv lengths gathered in state_,
// and stores the result in state_.batch_forward_type.
void BatchInputBuilder::process_batch_forward_type() {
  CHECK_EQ(state_.seq_lens.size(), state_.q_seq_lens.size())
      << "seq_lens size must be equal to q_seq_lens size";

  // Fast path: if the longest query in the batch is one token, every
  // sequence is generating a single token, i.e. pure decode.
  // NOTE(review): a batch of single-token prompts with an empty kv cache
  // would also land here and be labeled DECODE — confirm upstream callers
  // cannot produce that case.
  if (state_.q_max_seq_len == 1) {
    state_.batch_forward_type = BatchForwardType::DECODE;
    return;
  }

  bool has_cached_tokens = false;  // some sequence reuses existing kv cache
  bool has_prefill_seq = false;    // some sequence has q_len > 1
  bool has_decode_seq = false;     // some sequence has q_len == 1

  // Folds one sequence's (q_len, kv_len) pair into the batch-level flags.
  auto accumulate = [&](auto q_len, auto kv_len) {
    if (kv_len - q_len > 0) {  // tokens already present in the kv cache
      has_cached_tokens = true;
    }
    if (q_len > 1) {
      has_prefill_seq = true;
    }
    if (q_len == 1) {
      has_decode_seq = true;
    }
  };

#if defined(USE_NPU)
  // On NPU, seq_lens / q_seq_lens hold one entry per sequence.
  if (state_.seq_lens.empty()) {
    state_.batch_forward_type = BatchForwardType::EMPTY;
    return;
  }
  for (size_t i = 0; i < state_.seq_lens.size(); ++i) {
    accumulate(state_.q_seq_lens[i], state_.seq_lens[i]);
  }
#elif defined(USE_MLU)
  // On MLU, seq_lens / q_seq_lens are cumulative (prefix sums with a
  // leading entry), so a size of 1 means no sequences, and per-sequence
  // lengths are adjacent differences.
  if (state_.seq_lens.size() == 1) {
    state_.batch_forward_type = BatchForwardType::EMPTY;
    return;
  }
  for (size_t i = 1; i < state_.seq_lens.size(); ++i) {
    accumulate(state_.q_seq_lens[i] - state_.q_seq_lens[i - 1],
               state_.seq_lens[i] - state_.seq_lens[i - 1]);
  }
#endif

  // Derive the batch-level type from the per-sequence flags.
  // NOTE(review): the final chunk of a chunked prefill can have q_len == 1
  // and would count as a decode sequence here — verify this matches the
  // intended kernel selection.
  if (!has_cached_tokens) {
    state_.batch_forward_type = BatchForwardType::PREFILL;
  } else if (!has_decode_seq) {
    state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
  } else if (!has_prefill_seq) {
    state_.batch_forward_type = BatchForwardType::DECODE;
  } else {
    state_.batch_forward_type = BatchForwardType::MIXED;
  }
}
730792
} // namespace xllm

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class BatchInputBuilder {
5959

6060
void process_swap_block_infos(RawForwardInput& raw_forward_input);
6161

62+
void process_batch_forward_type();
63+
6264
// State management
6365
struct BuilderState {
6466
// Token and position data
@@ -77,7 +79,7 @@ class BatchInputBuilder {
7779
std::vector<int32_t> unique_token_lens_vec;
7880

7981
// Sequence metadata
80-
bool empty_kv_cache = true;
82+
BatchForwardType batch_forward_type;
8183
uint32_t max_seq_len = 0;
8284
uint32_t q_max_seq_len = 0;
8385
#if defined(USE_NPU)

xllm/core/framework/batch/batch_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ TEST(BatchTest, Basic) {
145145

146146
// check the input parameters
147147
const ModelInputParams& input_params = forward_input.input_params;
148-
EXPECT_FALSE(input_params.empty_kv_cache);
148+
EXPECT_FALSE(input_params.batch_forward_type.is_mixed());
149149
EXPECT_EQ(input_params.num_sequences, 4);
150150
EXPECT_EQ(input_params.q_max_seq_len, 9);
151151
EXPECT_EQ(input_params.kv_max_seq_len, 16);

xllm/core/framework/model/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ cc_library(
3434
embedding_lm.h
3535
model_args.h
3636
npu_dp_ep_padding.h
37+
batch_forward_type.h
3738
model_input_params.h
3839
SRCS
3940
npu_dp_ep_padding.cpp
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================*/
16+
17+
#pragma once
18+
19+
namespace xllm {
20+
21+
// Describes what kind of forward pass a batch performs. Replaces the old
// empty_kv_cache / global_empty_kv_cache flags with a single type that
// workers and kernels can branch on directly.
class BatchForwardType {
 public:
  enum Value : int32_t {
    // Prefill without using kv cache.
    PREFILL = 0,
    // Chunked prefill using kv cache.
    // No decode sequence in this type.
    CHUNKED_PREFILL = 1,
    // Decode one token.
    // No prefill sequence in this type.
    DECODE = 2,
    // Mixed prefill and decode in one batch when doing chunked prefill.
    MIXED = 3,
    // No sequence to forward.
    EMPTY = 4,
  };

  // Default: no sequences to forward. constexpr so globals/members get
  // constant initialization.
  constexpr BatchForwardType() : value_(EMPTY) {}

  // Intentionally implicit: callers round-trip the type through a raw
  // int32_t (e.g. when (de)serializing forward inputs).
  constexpr BatchForwardType(int32_t v) : value_(static_cast<Value>(v)) {}

  constexpr BatchForwardType(Value v) : value_(v) {}

  constexpr BatchForwardType& operator=(Value v) {
    value_ = v;
    return *this;
  }

  // Raw enum value, e.g. for serialization.
  [[nodiscard]] constexpr int32_t value() const { return value_; }

  [[nodiscard]] constexpr bool is_prefill() const {
    return (value_ == PREFILL);
  }

  [[nodiscard]] constexpr bool is_chunked_prefill() const {
    return (value_ == CHUNKED_PREFILL);
  }

  // True when the batch contains at least one decode sequence.
  [[nodiscard]] constexpr bool has_decode() const {
    return (value_ == DECODE || value_ == MIXED);
  }

  [[nodiscard]] constexpr bool is_decode() const { return (value_ == DECODE); }

  [[nodiscard]] constexpr bool is_mixed() const { return (value_ == MIXED); }

  [[nodiscard]] constexpr bool is_empty() const { return (value_ == EMPTY); }

  // Human-readable name, for logging.
  [[nodiscard]] const char* to_string() const {
    switch (value_) {
      case PREFILL:
        return "PREFILL";
      case CHUNKED_PREFILL:
        return "CHUNKED_PREFILL";
      case DECODE:
        return "DECODE";
      case MIXED:
        return "MIXED";
      case EMPTY:
        return "EMPTY";
      default:
        return "UNKNOWN";
    }
  }

 private:
  Value value_;
};
83+
} // namespace xllm

xllm/core/framework/model/model_input_params.h

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#if defined(USE_NPU)
2222
#include "platform/npu/npu_layer_synchronizer.h"
2323
#endif
24+
#include "framework/model/batch_forward_type.h"
2425
#include "framework/request/mm_data.h"
2526
#include "npu_dp_ep_padding.h"
2627
#include "util/tensor_helper.h"
@@ -86,8 +87,7 @@ struct BlockTransferInfo {
8687
struct ModelInputParams {
8788
ModelInputParams to(const torch::Device& device) const {
8889
ModelInputParams params;
89-
params.empty_kv_cache = empty_kv_cache;
90-
params.global_empty_kv_cache = global_empty_kv_cache;
90+
params.batch_forward_type = batch_forward_type;
9191
params.num_sequences = num_sequences;
9292
params.kv_max_seq_len = kv_max_seq_len;
9393
params.q_max_seq_len = q_max_seq_len;
@@ -99,7 +99,6 @@ struct ModelInputParams {
9999
params.block_tables = safe_to(block_tables, device, true);
100100
params.kv_seq_lens_vec = kv_seq_lens_vec;
101101
params.q_seq_lens_vec = q_seq_lens_vec;
102-
params.decode_seq_range = decode_seq_range;
103102

104103
params.input_embedding = safe_to(input_embedding, device);
105104

@@ -141,24 +140,22 @@ struct ModelInputParams {
141140
}
142141

143142
void print() const {
144-
LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache
145-
<< " , global_empty_kv_cache is " << global_empty_kv_cache
146-
<< " , num_sequences is " << num_sequences
147-
<< " , kv_max_seq_len is " << kv_max_seq_len
143+
LOG(INFO) << "ModelInputParams: batch_forward_type is "
144+
<< batch_forward_type.to_string() << " , num_sequences is "
145+
<< num_sequences << " , kv_max_seq_len is " << kv_max_seq_len
148146
<< " , q_max_seq_len is " << q_max_seq_len
149147
<< " , prefill_seq_len is " << prefill_seq_len;
150148
LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
151149
LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
152-
LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
153150
print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4);
154151
print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4);
155152
print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
156153
print_tensor(block_tables, "ModelInputParams: block_tables", 4);
157154
LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
158155
<< dp_global_token_nums;
159156
}
160-
// whether the kv-cache is empty for all sequences.
161-
bool empty_kv_cache = true;
157+
// forward type of the batch, used by worker/kernel.
158+
BatchForwardType batch_forward_type;
162159

163160
// total number of sequences in the batch
164161
int32_t num_sequences = 0;
@@ -167,15 +164,6 @@ struct ModelInputParams {
167164
torch::Tensor kv_seq_lens;
168165
std::vector<int> kv_seq_lens_vec;
169166
std::vector<int> q_seq_lens_vec;
170-
// Range of decode sequence indices in the batch [start, end].
171-
// Decode sequences are identified by q_seq_lens == 1,
172-
// prefill sequences by q_seq_lens > 1 .
173-
// Used to determine whether to use prefill_node_ or
174-
// decode_node_ in NPU layers
175-
// Values: {-1, -1} if no decode requests (all prefill),
176-
// {0, batch_size-1} if all decode requests,
177-
// {start_idx, end_idx} if mixed prefill/decode requests
178-
std::pair<int, int> decode_seq_range;
179167
// max length for qkv.
180168
int32_t kv_max_seq_len = 0;
181169
int32_t q_max_seq_len = 0;
@@ -199,8 +187,6 @@ struct ModelInputParams {
199187

200188
// num tokens of all workers,mainly used for dp case
201189
std::vector<int32_t> dp_global_token_nums;
202-
// whether the kv-cache is empty for all sequences,mainly used for dp case
203-
bool global_empty_kv_cache = true;
204190

205191
// num of prefill sequence in chunked prefill case
206192
uint32_t prefill_seq_len = 0;

xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,9 +1523,8 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
15231523
std::vector<std::atomic<bool>*> event_flag,
15241524
int node_id) {
15251525
atb::Status st;
1526-
// all micro batches are in same prefill/decode stage,
1527-
// so, to judge empty_kv_cache, use input_params[0] here
1528-
if (input_params[0].global_empty_kv_cache) {
1526+
// DeepSeek doesn't support chunked prefill, so only check is_prefill.
1527+
if (input_params[0].batch_forward_type.is_prefill()) {
15291528
build_node_variant_pack(prefill_node_,
15301529
x,
15311530
cos_pos,

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,8 +1085,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward(
10851085
std::vector<std::atomic<bool>*> event_flag,
10861086
int node_id) {
10871087
atb::Status st;
1088-
if (input_params.decode_seq_range.second !=
1089-
input_params.q_seq_lens.size(0) - 1) {
1088+
if (!input_params.batch_forward_type.is_decode()) {
10901089
build_node_variant_pack(prefill_node_,
10911090
x,
10921091
cos_pos,

xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x,
277277
int node_id) {
278278
atb::Status st;
279279

280-
if (input_params.decode_seq_range.second !=
281-
input_params.q_seq_lens.size(0) - 1) {
280+
if (!input_params.batch_forward_type.is_decode()) {
282281
build_node_variant_pack(prefill_node_,
283282
x,
284283
cos_pos,

xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(
405405
std::vector<std::atomic<bool>*> event_flag,
406406
int node_id) {
407407
atb::Status st;
408-
if (input_params[0].decode_seq_range.second !=
409-
input_params[0].q_seq_lens.size(0) - 1) {
408+
if (!input_params[0].batch_forward_type.is_decode()) {
410409
// mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
411410
build_node_variant_pack(prefill_node_,
412411
x[0],

0 commit comments

Comments
 (0)