Skip to content

Commit 6c367fa

Browse files
RobbieLeungyq33victor
authored andcommitted
feat: add batch forward type.
1 parent 53a6858 commit 6c367fa

12 files changed

+192
-27
lines changed

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ ForwardInput BatchInputBuilder::build_forward_input(
9090
uint32_t min_decoding_batch_size) {
9191
process_sequences(0, static_cast<uint32_t>(num_sequences_));
9292
padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);
93-
93+
process_batch_forward_type();
9494
return state_to_forward_input();
9595
}
9696

@@ -102,6 +102,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
102102
} else {
103103
process_sequences_multithreaded(start_idx, end_idx);
104104
}
105+
process_batch_forward_type();
105106
return state_to_raw_forward_input();
106107
}
107108

@@ -548,6 +549,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
548549

549550
auto& input_params = forward_input.input_params;
550551
input_params.empty_kv_cache = state_.empty_kv_cache;
552+
input_params.batch_forward_type = state_.batch_forward_type;
551553
input_params.num_sequences = state_.block_tables_vec.size();
552554
input_params.kv_max_seq_len = state_.max_seq_len;
553555
input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -633,7 +635,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
633635
raw_forward_input.unique_token_lens_vec =
634636
std::move(state_.unique_token_lens_vec);
635637
raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
636-
// raw_forward_input.global_empty_kv_cache = ;
638+
raw_forward_input.batch_forward_type = state_.batch_forward_type;
637639
raw_forward_input.max_seq_len = state_.max_seq_len;
638640
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
639641
raw_forward_input.seq_lens = std::move(state_.seq_lens);
@@ -723,4 +725,69 @@ void BatchInputBuilder::process_swap_block_infos(
723725
swap_cache_block_infos_->end());
724726
}
725727
}
728+
729+
void BatchInputBuilder::process_batch_forward_type() {
730+
CHECK_EQ(state_.seq_lens.size(), state_.q_seq_lens.size())
731+
<< "seq_lens size must be equal to q_seq_lens size";
732+
733+
if (state_.q_max_seq_len == 1) {
734+
state_.batch_forward_type = BatchForwardType::DECODE;
735+
return;
736+
}
737+
738+
bool empty_kv_cache = true;
739+
bool all_decode = true;
740+
bool all_prefill = true;
741+
742+
#if defined(USE_NPU)
743+
if (state_.seq_lens.size() == 0) {
744+
state_.batch_forward_type = BatchForwardType::EMPTY;
745+
return;
746+
}
747+
for (size_t i = 0; i < state_.seq_lens.size(); ++i) {
748+
auto q_len = state_.q_seq_lens[i];
749+
auto kv_len = state_.seq_lens[i];
750+
auto cache_len = kv_len - q_len;
751+
if (cache_len > 0) {
752+
empty_kv_cache = false;
753+
}
754+
if (q_len > 1) {
755+
all_decode = false;
756+
}
757+
if (q_len == 1) {
758+
all_prefill = false;
759+
}
760+
}
761+
#elif defined(USE_MLU)
762+
if (state_.seq_lens.size() == 1) {
763+
state_.batch_forward_type = BatchForwardType::EMPTY;
764+
return;
765+
}
766+
for (size_t i = 1; i < state_.seq_lens.size(); ++i) {
767+
auto q_len = state_.q_seq_lens[i] - state_.q_seq_lens[i - 1];
768+
auto kv_len = state_.seq_lens[i] - state_.seq_lens[i - 1];
769+
auto cache_len = kv_len - q_len;
770+
if (cache_len > 0) {
771+
empty_kv_cache = false;
772+
}
773+
if (q_len > 1) {
774+
all_decode = false;
775+
}
776+
if (q_len == 1) {
777+
all_prefill = false;
778+
}
779+
}
780+
#endif
781+
if (empty_kv_cache) {
782+
state_.batch_forward_type = BatchForwardType::PREFILL;
783+
} else {
784+
if (all_prefill) {
785+
state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
786+
} else if (all_decode) {
787+
state_.batch_forward_type = BatchForwardType::DECODE;
788+
} else {
789+
state_.batch_forward_type = BatchForwardType::MIXED;
790+
}
791+
}
792+
}
726793
} // namespace xllm

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class BatchInputBuilder {
6262

6363
void process_swap_block_infos(RawForwardInput& raw_forward_input);
6464

65+
void process_batch_forward_type();
66+
6567
// State management
6668
struct BuilderState {
6769
// Token and position data
@@ -81,6 +83,7 @@ class BatchInputBuilder {
8183

8284
// Sequence metadata
8385
bool empty_kv_cache = true;
86+
BatchForwardType batch_forward_type;
8487
uint32_t max_seq_len = 0;
8588
uint32_t q_max_seq_len = 0;
8689
#if defined(USE_NPU)

xllm/core/framework/model/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ cc_library(
3434
embedding_lm.h
3535
model_args.h
3636
npu_dp_ep_padding.h
37+
batch_forward_type.h
3738
model_input_params.h
3839
SRCS
3940
npu_dp_ep_padding.cpp
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
==============================================================================*/
13+
14+
#pragma once
15+
16+
namespace xllm {
17+
18+
class BatchForwardType {
19+
public:
20+
enum Value : int32_t {
21+
// Prefill without using kv cache.
22+
PREFILL = 0,
23+
// Chunked prefill using kv cache.
24+
// No decode sequence in this type.
25+
CHUNKED_PREFILL = 1,
26+
// Decode one token.
27+
// No prefill sequence in this type.
28+
DECODE = 2,
29+
// Mixed prefill and decode in one batch when doing chunked prefill.
30+
MIXED = 3,
31+
// No sequence to forward.
32+
EMPTY = 4,
33+
};
34+
35+
BatchForwardType() : value_(EMPTY) {}
36+
37+
BatchForwardType(int32_t v) : value_(static_cast<Value>(v)) {}
38+
39+
constexpr BatchForwardType(Value v) : value_(v) {}
40+
41+
BatchForwardType& operator=(Value v) {
42+
value_ = v;
43+
return *this;
44+
}
45+
46+
int32_t value() const { return value_; }
47+
48+
bool is_prefill() const { return (value_ == PREFILL); }
49+
50+
bool is_chunked_prefill() const { return (value_ == CHUNKED_PREFILL); }
51+
52+
bool has_decode() const { return (value_ == DECODE || value_ == MIXED); }
53+
54+
bool no_decode() const {
55+
return (value_ == PREFILL || value_ == CHUNKED_PREFILL);
56+
}
57+
58+
bool is_decode() const { return (value_ == DECODE); }
59+
60+
bool is_mixed() const { return (value_ == MIXED); }
61+
62+
bool is_empty() const { return (value_ == EMPTY); }
63+
64+
const char* to_string() const {
65+
switch (value_) {
66+
case PREFILL:
67+
return "PREFILL";
68+
case CHUNKED_PREFILL:
69+
return "CHUNKED_PREFILL";
70+
case DECODE:
71+
return "DECODE";
72+
case MIXED:
73+
return "MIXED";
74+
case EMPTY:
75+
return "EMPTY";
76+
default:
77+
return "UNKNOWN";
78+
}
79+
}
80+
81+
private:
82+
Value value_;
83+
};
84+
} // namespace xllm

xllm/core/framework/model/model_input_params.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#if defined(USE_NPU)
2222
#include "platform/npu/npu_layer_synchronizer.h"
2323
#endif
24+
#include "framework/model/batch_forward_type.h"
2425
#include "framework/request/mm_data.h"
2526
#include "npu_dp_ep_padding.h"
2627
#include "util/tensor_helper.h"
@@ -52,6 +53,7 @@ struct ModelInputParams {
5253
ModelInputParams params;
5354
params.empty_kv_cache = empty_kv_cache;
5455
params.global_empty_kv_cache = global_empty_kv_cache;
56+
params.batch_forward_type = batch_forward_type;
5557
params.num_sequences = num_sequences;
5658
params.kv_max_seq_len = kv_max_seq_len;
5759
params.q_max_seq_len = q_max_seq_len;
@@ -103,6 +105,7 @@ struct ModelInputParams {
103105
void print() const {
104106
LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache
105107
<< " , global_empty_kv_cache is " << global_empty_kv_cache
108+
<< " , batch_forward_type is " << batch_forward_type.to_string()
106109
<< " , num_sequences is " << num_sequences
107110
<< " , kv_max_seq_len is " << kv_max_seq_len
108111
<< " , q_max_seq_len is " << q_max_seq_len
@@ -120,6 +123,9 @@ struct ModelInputParams {
120123
// whether the kv-cache is empty for all sequences.
121124
bool empty_kv_cache = true;
122125

126+
// forward type of the batch, used by worker/kernel.
127+
BatchForwardType batch_forward_type;
128+
123129
// total number of sequences in the batch
124130
int32_t num_sequences = 0;
125131

xllm/core/runtime/forward_params.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct RawForwardInput {
148148
std::vector<int32_t> unique_token_lens_vec;
149149
bool empty_kv_cache = true;
150150
bool global_empty_kv_cache = true;
151+
BatchForwardType batch_forward_type;
151152
uint32_t max_seq_len;
152153
uint32_t q_max_seq_len;
153154
std::vector<int32_t> seq_lens;

xllm/core/runtime/forward_shared_memory_manager.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
153153
total += type_size<uint64_t> * 4 +
154154
cache_block_size * cache_block_info_fixed_size();
155155

156-
total += type_size<bool> * 2 // empty_kv_cache + global_empty_kv_cache
156+
total += type_size<bool> * 2 // empty_kv_cache + global_empty_kv_cache
157+
+ type_size<int32_t> // batch_forward_type
157158
+ type_size<uint32_t> *
158159
3 // max_seq_len + q_max_seq_len + prefill_seq_len
159160
+ type_size<int32_t> // num_sequences
@@ -599,6 +600,9 @@ INLINE void deserialize_raw_forward_input(
599600

600601
read_data(buffer, input.empty_kv_cache);
601602
read_data(buffer, input.global_empty_kv_cache);
603+
int32_t batch_forward_type;
604+
read_data(buffer, batch_forward_type);
605+
input.batch_forward_type = BatchForwardType(batch_forward_type);
602606
read_data(buffer, input.max_seq_len);
603607
read_data(buffer, input.q_max_seq_len);
604608
read_data(buffer, input.num_sequences);
@@ -653,6 +657,7 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input,
653657

654658
write_data(buffer, input.empty_kv_cache);
655659
write_data(buffer, input.global_empty_kv_cache);
660+
write_data(buffer, input.batch_forward_type.value());
656661
write_data(buffer, input.max_seq_len);
657662
write_data(buffer, input.q_max_seq_len);
658663
write_data(buffer, input.num_sequences);
@@ -855,6 +860,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
855860
auto& input_params = forward_input.input_params;
856861
input_params.empty_kv_cache = raw_input.empty_kv_cache;
857862
input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache;
863+
input_params.batch_forward_type = raw_input.batch_forward_type;
858864
input_params.num_sequences = raw_input.num_sequences;
859865
input_params.kv_max_seq_len = raw_input.max_seq_len;
860866
input_params.q_max_seq_len = raw_input.q_max_seq_len;

xllm/core/runtime/llm_engine.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,8 @@ std::vector<std::vector<RawForwardInput>> LLMEngine::prepare_inputs(
822822
dp_global_token_nums.resize(micro_batches_num,
823823
std::vector<int32_t>(dp_size_));
824824
bool global_empty_kv_cache = true;
825+
// All empty batches use the first non-empty batch's forward type.
826+
BatchForwardType batch_forward_type;
825827

826828
// eplb related
827829
EplbInfo eplb_info;
@@ -841,6 +843,12 @@ std::vector<std::vector<RawForwardInput>> LLMEngine::prepare_inputs(
841843
batched_inputs[dp_rank][i].flatten_tokens_vec.size();
842844
global_empty_kv_cache =
843845
batched_inputs[dp_rank][i].empty_kv_cache && global_empty_kv_cache;
846+
if (batched_inputs[dp_rank][i].batch_forward_type.is_empty()) {
847+
continue;
848+
}
849+
if (batch_forward_type.is_empty() || batch_forward_type.is_prefill()) {
850+
batch_forward_type = batched_inputs[dp_rank][i].batch_forward_type;
851+
}
844852
}
845853
}
846854

@@ -853,6 +861,9 @@ std::vector<std::vector<RawForwardInput>> LLMEngine::prepare_inputs(
853861
for (auto i = 0; i < micro_batches_num; ++i) {
854862
batched_inputs[dp_rank][i].dp_global_token_nums = dp_global_token_nums[i];
855863
batched_inputs[dp_rank][i].global_empty_kv_cache = global_empty_kv_cache;
864+
if (batched_inputs[dp_rank][i].batch_forward_type.is_empty()) {
865+
batched_inputs[dp_rank][i].batch_forward_type = batch_forward_type;
866+
}
856867
if (FLAGS_enable_eplb) {
857868
batched_inputs[dp_rank][i].eplb_info = eplb_info;
858869
}

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,11 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(
182182
// should be in same prefill stage, so, to judge empty_kv_cache,
183183
// just use micro batch 0 here
184184
if (options_.enable_speculative_decode() && !is_spec_draft_) {
185-
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
185+
if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
186186
output.sample_output.embeddings = hidden_states;
187-
} else if (concated_sampling_params.sample_idxes.defined()) {
188-
// auto sample_idxes =
189-
// concated_sampling_params.selected_token_idxes.index_select(
190-
// /*dim=*/0, concated_sampling_params.sample_idxes);
187+
} else if (concated_sampling_params.selected_token_idxes.defined()) {
191188
auto embeddings = hidden_states.index_select(
192-
/*dim=*/0, concated_sampling_params.sample_idxes);
189+
/*dim=*/0, concated_sampling_params.selected_token_idxes);
193190
output.sample_output.embeddings = embeddings;
194191
}
195192
}

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
173173
}
174174

175175
// TODO: support data parallel case
176-
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
176+
if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
177177
return step_prefill(inputs);
178178
} else {
179179
return step_decode(inputs);
@@ -182,7 +182,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
182182

183183
std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty(
184184
const BatchedForwardInputs& inputs) {
185-
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
185+
if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
186186
auto output = impl_->step(inputs);
187187
auto draft_output = draft_impl_->step(inputs);
188188
return output;
@@ -833,7 +833,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
833833
void SpeculativeWorkerImpl::prepare_work_before_execute(
834834
const BatchedForwardInputs& inputs,
835835
BatchedForwardInputs& processed_inputs) {
836-
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
836+
if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
837837
WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
838838
} else {
839839
if (enable_schedule_overlap()) {

0 commit comments

Comments
 (0)