@@ -90,7 +90,7 @@ ForwardInput BatchInputBuilder::build_forward_input(
9090 uint32_t min_decoding_batch_size) {
9191 process_sequences (0 , static_cast <uint32_t >(num_sequences_));
9292 padding_decode_batch_size (num_decoding_tokens, min_decoding_batch_size);
93-
93+ process_batch_forward_type ();
9494 return state_to_forward_input ();
9595}
9696
@@ -102,6 +102,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
102102 } else {
103103 process_sequences_multithreaded (start_idx, end_idx);
104104 }
105+ process_batch_forward_type ();
105106 return state_to_raw_forward_input ();
106107}
107108
@@ -548,6 +549,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
548549
549550 auto & input_params = forward_input.input_params ;
550551 input_params.empty_kv_cache = state_.empty_kv_cache ;
552+ input_params.batch_forward_type = state_.batch_forward_type ;
551553 input_params.num_sequences = state_.block_tables_vec .size ();
552554 input_params.kv_max_seq_len = state_.max_seq_len ;
553555 input_params.q_max_seq_len = state_.q_max_seq_len ;
@@ -633,7 +635,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
633635 raw_forward_input.unique_token_lens_vec =
634636 std::move (state_.unique_token_lens_vec );
635637 raw_forward_input.empty_kv_cache = state_.empty_kv_cache ;
636- // raw_forward_input.global_empty_kv_cache = ;
638+ raw_forward_input.batch_forward_type = state_. batch_forward_type ;
637639 raw_forward_input.max_seq_len = state_.max_seq_len ;
638640 raw_forward_input.q_max_seq_len = state_.q_max_seq_len ;
639641 raw_forward_input.seq_lens = std::move (state_.seq_lens );
@@ -723,4 +725,69 @@ void BatchInputBuilder::process_swap_block_infos(
723725 swap_cache_block_infos_->end ());
724726 }
725727}
728+
729+ void BatchInputBuilder::process_batch_forward_type () {
730+ CHECK_EQ (state_.seq_lens .size (), state_.q_seq_lens .size ())
731+ << " seq_lens size must be equal to q_seq_lens size" ;
732+
733+ if (state_.q_max_seq_len == 1 ) {
734+ state_.batch_forward_type = BatchForwardType::DECODE;
735+ return ;
736+ }
737+
738+ bool empty_kv_cache = true ;
739+ bool all_decode = true ;
740+ bool all_prefill = true ;
741+
742+ #if defined(USE_NPU)
743+ if (state_.seq_lens .size () == 0 ) {
744+ state_.batch_forward_type = BatchForwardType::EMPTY;
745+ return ;
746+ }
747+ for (size_t i = 0 ; i < state_.seq_lens .size (); ++i) {
748+ auto q_len = state_.q_seq_lens [i];
749+ auto kv_len = state_.seq_lens [i];
750+ auto cache_len = kv_len - q_len;
751+ if (cache_len > 0 ) {
752+ empty_kv_cache = false ;
753+ }
754+ if (q_len > 1 ) {
755+ all_decode = false ;
756+ }
757+ if (q_len == 1 ) {
758+ all_prefill = false ;
759+ }
760+ }
761+ #elif defined(USE_MLU)
762+ if (state_.seq_lens .size () == 1 ) {
763+ state_.batch_forward_type = BatchForwardType::EMPTY;
764+ return ;
765+ }
766+ for (size_t i = 1 ; i < state_.seq_lens .size (); ++i) {
767+ auto q_len = state_.q_seq_lens [i] - state_.q_seq_lens [i - 1 ];
768+ auto kv_len = state_.seq_lens [i] - state_.seq_lens [i - 1 ];
769+ auto cache_len = kv_len - q_len;
770+ if (cache_len > 0 ) {
771+ empty_kv_cache = false ;
772+ }
773+ if (q_len > 1 ) {
774+ all_decode = false ;
775+ }
776+ if (q_len == 1 ) {
777+ all_prefill = false ;
778+ }
779+ }
780+ #endif
781+ if (empty_kv_cache) {
782+ state_.batch_forward_type = BatchForwardType::PREFILL;
783+ } else {
784+ if (all_prefill) {
785+ state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
786+ } else if (all_decode) {
787+ state_.batch_forward_type = BatchForwardType::DECODE;
788+ } else {
789+ state_.batch_forward_type = BatchForwardType::MIXED;
790+ }
791+ }
792+ }
726793} // namespace xllm
0 commit comments