@@ -71,6 +71,7 @@ ForwardInput BatchInputBuilder::build_forward_input(
     uint32_t num_decoding_tokens,
     uint32_t min_decoding_batch_size) {
   process_sequences(0, static_cast<uint32_t>(num_sequences_));
+  process_batch_forward_type();
   padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);
 
   return state_to_forward_input();
@@ -84,6 +85,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
   } else {
     process_sequences_multithreaded(start_idx, end_idx);
   }
+  process_batch_forward_type();
   return state_to_raw_forward_input();
 }
 
@@ -189,7 +191,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
     state_.unique_token_lens_vec.insert(state_.unique_token_lens_vec.end(),
                                         state.unique_token_lens_vec.begin(),
                                         state.unique_token_lens_vec.end());
-    state_.empty_kv_cache = state_.empty_kv_cache && state.empty_kv_cache;
     state_.max_seq_len = std::max(state_.max_seq_len, state.max_seq_len);
     state_.q_max_seq_len = std::max(state_.q_max_seq_len, state.q_max_seq_len);
 #if defined(USE_NPU)
@@ -278,7 +279,6 @@ void BatchInputBuilder::process_single_sequence(
       << allowed_max_tokens_[seq_index];
 
   // Update state
-  state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
   state.max_seq_len = std::max(state.max_seq_len, seq_len);
   state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
 #if defined(USE_NPU)
@@ -498,7 +498,7 @@ void BatchInputBuilder::padding_decode_batch_size(
   if (num_sequences_ < min_decoding_batch_size) {
     const uint32_t n_tokens = state_.flatten_tokens_vec.size();
     // kv_cache is not empty in decoding phase
-    const bool in_decoding_phase = !state_.empty_kv_cache;
+    const bool in_decoding_phase = !state_.batch_forward_type.is_prefill();
     const bool same_num_decoding_tokens =
         state_.q_max_seq_len == num_decoding_tokens &&
         n_tokens == num_sequences_ * num_decoding_tokens;
@@ -551,7 +551,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   }
 
   auto& input_params = forward_input.input_params;
-  input_params.empty_kv_cache = state_.empty_kv_cache;
+  input_params.batch_forward_type = state_.batch_forward_type;
   input_params.num_sequences = state_.block_tables_vec.size();
   input_params.kv_max_seq_len = state_.max_seq_len;
   input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -561,8 +561,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
   input_params.new_cache_slots =
       torch::tensor(state_.new_token_slot_ids, torch::kInt);
-  input_params.decode_seq_range =
-      util::find_ones_indices(input_params.q_seq_lens_vec);
 
   // for flashinfer
   input_params.paged_kv_indptr =
@@ -644,8 +642,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
       std::move(state_.unique_token_counts_vec);
   raw_forward_input.unique_token_lens_vec =
       std::move(state_.unique_token_lens_vec);
-  raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
-  // raw_forward_input.global_empty_kv_cache = ;
+  raw_forward_input.batch_forward_type = state_.batch_forward_type;
   raw_forward_input.max_seq_len = state_.max_seq_len;
   raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
   raw_forward_input.seq_lens = std::move(state_.seq_lens);
@@ -727,4 +724,69 @@ void BatchInputBuilder::process_swap_block_infos(
                                       swap_block_transfer_infos_->end());
   }
 }
+
+void BatchInputBuilder::process_batch_forward_type() {
+  CHECK_EQ(state_.seq_lens.size(), state_.q_seq_lens.size())
+      << "seq_lens size must be equal to q_seq_lens size";
+
+  if (state_.q_max_seq_len == 1) {
+    state_.batch_forward_type = BatchForwardType::DECODE;
+    return;
+  }
+
+  bool empty_kv_cache = true;
+  bool all_decode = true;
+  bool all_prefill = true;
+
+#if defined(USE_NPU)
+  if (state_.seq_lens.size() == 0) {
+    state_.batch_forward_type = BatchForwardType::EMPTY;
+    return;
+  }
+  for (size_t i = 0; i < state_.seq_lens.size(); ++i) {
+    auto q_len = state_.q_seq_lens[i];
+    auto kv_len = state_.seq_lens[i];
+    auto cache_len = kv_len - q_len;
+    if (cache_len > 0) {
+      empty_kv_cache = false;
+    }
+    if (q_len > 1) {
+      all_decode = false;
+    }
+    if (q_len == 1) {
+      all_prefill = false;
+    }
+  }
+#elif defined(USE_MLU)
+  if (state_.seq_lens.size() == 1) {
+    state_.batch_forward_type = BatchForwardType::EMPTY;
+    return;
+  }
+  for (size_t i = 1; i < state_.seq_lens.size(); ++i) {
+    auto q_len = state_.q_seq_lens[i] - state_.q_seq_lens[i - 1];
+    auto kv_len = state_.seq_lens[i] - state_.seq_lens[i - 1];
+    auto cache_len = kv_len - q_len;
+    if (cache_len > 0) {
+      empty_kv_cache = false;
+    }
+    if (q_len > 1) {
+      all_decode = false;
+    }
+    if (q_len == 1) {
+      all_prefill = false;
+    }
+  }
+#endif
+  if (empty_kv_cache) {
+    state_.batch_forward_type = BatchForwardType::PREFILL;
+  } else {
+    if (all_prefill) {
+      state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
+    } else if (all_decode) {
+      state_.batch_forward_type = BatchForwardType::DECODE;
+    } else {
+      state_.batch_forward_type = BatchForwardType::MIXED;
+    }
+  }
+}
 } // namespace xllm