@@ -66,7 +66,7 @@ void WorkerService::set_worker(std::unique_ptr<Worker> worker) {
6666 initialized_ = true ;
6767}
6868
69- void WorkerService::step (BatchedForwardInputs& batched_fwd_inputs ,
69+ void WorkerService::step (ForwardInput& fwd_input ,
7070 torch::Tensor& next_tokens,
7171 torch::Tensor& logprobs,
7272 torch::Tensor& top_tokens,
@@ -78,7 +78,7 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
7878 torch::Tensor& out_tokens,
7979 torch::Tensor& out_logprobs) {
8080 // execute model
81- auto future = worker_->step_async (batched_fwd_inputs );
81+ auto future = worker_->step_async (fwd_input );
8282
8383 if (!options_.enable_schedule_overlap ()) {
8484 auto forward_outputs = std::move (future).get ();
@@ -142,10 +142,10 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
142142 torch::TensorOptions ().dtype (torch::kInt32 ).device (torch::kCPU );
143143 auto total_prefill_seq_len = 0 ;
144144 auto total_num_sequences = 0 ;
145- for ( auto & input : batched_fwd_inputs. micro_inputs ) {
146- total_num_sequences += input .input_params .num_sequences ;
147- total_prefill_seq_len += input .input_params .prefill_seq_len ;
148- }
145+
146+ total_num_sequences += fwd_input .input_params .num_sequences ;
147+ total_prefill_seq_len += fwd_input .input_params .prefill_seq_len ;
148+
149149 next_tokens =
150150 torch::arange (-1 ,
151151 -1 * (total_num_sequences - total_prefill_seq_len + 1 ),
@@ -166,7 +166,7 @@ void WorkerService::create_polling_shm_thread(
166166 output_shm_manager = std::move (output_shm_manager)]() mutable {
167167 Timer timer;
168168 while (true ) {
169- BatchedForwardInputs batched_fwd_inputs ;
169+ ForwardInput fwd_input ;
170170 std::vector<ForwardInput> inputs;
171171 input_shm_manager->raw_input_read (inputs);
172172 timer.reset ();
@@ -184,31 +184,9 @@ void WorkerService::create_polling_shm_thread(
184184 torch::Tensor out_tokens;
185185 torch::Tensor out_logprobs;
186186
187- auto micro_batches_num = inputs.size ();
188- batched_fwd_inputs.micro_inputs = std::move (inputs);
189- batched_fwd_inputs.concated_sampling_params =
190- batched_fwd_inputs.micro_inputs [0 ].sampling_params ;
191- for (auto i = 1 ; i < micro_batches_num; ++i) {
192- batched_fwd_inputs.concated_sampling_params .concat (
193- batched_fwd_inputs.micro_inputs [i].sampling_params );
194- }
195-
196- // concat acc_logprob here for beam search together
197- if (micro_batches_num > 1 ) {
198- std::vector<torch::Tensor> acc_logprob_vec;
199- acc_logprob_vec.reserve (micro_batches_num);
200- for (auto i = 0 ; i < micro_batches_num; ++i) {
201- acc_logprob_vec.push_back (
202- batched_fwd_inputs.micro_inputs [i].acc_logprob );
203- }
204- batched_fwd_inputs.acc_logprob =
205- torch::cat (acc_logprob_vec, /* dim=*/ -1 );
206- } else {
207- batched_fwd_inputs.acc_logprob =
208- batched_fwd_inputs.micro_inputs [0 ].acc_logprob ;
209- }
187+ fwd_input = std::move (inputs[0 ]);
210188
211- step (batched_fwd_inputs ,
189+ step (fwd_input ,
212190 next_tokens,
213191 logprobs,
214192 top_tokens,
@@ -598,90 +576,58 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
598576 return ;
599577}
600578
601- void WorkerService::ExecuteModel (
602- ::google::protobuf::RpcController* controller,
603- const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
604- proto::ForwardOutput* pb_forward_output,
605- ::google::protobuf::Closure* done) {
606- threadpool_->schedule ([this ,
607- controller,
608- pb_batched_fwd_inputs,
609- pb_forward_output,
610- done]() mutable {
611- brpc::ClosureGuard done_guard (done);
612- Timer timer;
613- // convert proto::BatchedForwardInputs to BatchedForwardInputs
614- auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs ().size ();
615- BatchedForwardInputs batched_fwd_inputs;
616- batched_fwd_inputs.micro_inputs .reserve (micro_batches_num);
617- for (auto i = 0 ; i < micro_batches_num; ++i) {
618- ForwardInput forward_input;
619- proto_to_forward_input (&(pb_batched_fwd_inputs->micro_inputs ()[i]),
620- forward_input,
621- options_.num_decoding_tokens ());
622- batched_fwd_inputs.micro_inputs .push_back (std::move (forward_input));
623- }
624-
625- // concat sampling parameters
626- batched_fwd_inputs.concated_sampling_params =
627- batched_fwd_inputs.micro_inputs [0 ].sampling_params ;
628- for (auto i = 1 ; i < micro_batches_num; ++i) {
629- batched_fwd_inputs.concated_sampling_params .concat (
630- batched_fwd_inputs.micro_inputs [i].sampling_params );
631- }
632-
633- // concat acc_logprob here for beam search together
634- if (micro_batches_num > 1 ) {
635- std::vector<torch::Tensor> acc_logprob_vec;
636- acc_logprob_vec.reserve (micro_batches_num);
637- for (auto i = 0 ; i < micro_batches_num; ++i) {
638- acc_logprob_vec.push_back (
639- batched_fwd_inputs.micro_inputs [i].acc_logprob );
640- }
641- batched_fwd_inputs.acc_logprob = torch::cat (acc_logprob_vec, /* dim=*/ -1 );
642- } else {
643- batched_fwd_inputs.acc_logprob =
644- batched_fwd_inputs.micro_inputs [0 ].acc_logprob ;
645- }
579+ void WorkerService::ExecuteModel (::google::protobuf::RpcController* controller,
580+ const proto::ForwardInput* pb_forward_input,
581+ proto::ForwardOutput* pb_forward_output,
582+ ::google::protobuf::Closure* done) {
583+ threadpool_->schedule (
584+ [this , controller, pb_forward_input, pb_forward_output, done]() mutable {
585+ brpc::ClosureGuard done_guard (done);
586+ // convert proto::ForwardInput to ForwardInput
646587
647- // model output
648- torch::Tensor next_tokens;
649- torch::Tensor logprobs;
650- torch::Tensor top_tokens;
651- torch::Tensor top_logprobs;
652- torch::Tensor embeddings;
653- torch::Tensor expert_load_data;
654- int32_t prepared_layer_id = -1 ;
655- // beam search kernel output
656- torch::Tensor src_seq_idxes;
657- torch::Tensor out_tokens;
658- torch::Tensor out_logprobs;
659-
660- step (batched_fwd_inputs,
661- next_tokens,
662- logprobs,
663- top_tokens,
664- top_logprobs,
665- embeddings,
666- expert_load_data,
667- prepared_layer_id,
668- src_seq_idxes,
669- out_tokens,
670- out_logprobs);
671- // convert to proto output
672- forward_output_to_proto (next_tokens,
673- logprobs,
674- top_tokens,
675- top_logprobs,
676- embeddings,
677- expert_load_data,
678- prepared_layer_id,
679- src_seq_idxes,
680- out_tokens,
681- out_logprobs,
682- pb_forward_output);
683- COUNTER_ADD (worker_service_latency_seconds, timer.elapsed_seconds ());
684- });
588+ Timer timer;
589+ ForwardInput forward_input;
590+ proto_to_forward_input (
591+ pb_forward_input, forward_input, options_.num_decoding_tokens ());
592+
593+ // model output
594+ torch::Tensor next_tokens;
595+ torch::Tensor logprobs;
596+ torch::Tensor top_tokens;
597+ torch::Tensor top_logprobs;
598+ torch::Tensor embeddings;
599+ torch::Tensor expert_load_data;
600+ int32_t prepared_layer_id = -1 ;
601+ // beam search kernel output
602+ torch::Tensor src_seq_idxes;
603+ torch::Tensor out_tokens;
604+ torch::Tensor out_logprobs;
605+
606+ step (forward_input,
607+ next_tokens,
608+ logprobs,
609+ top_tokens,
610+ top_logprobs,
611+ embeddings,
612+ expert_load_data,
613+ prepared_layer_id,
614+ src_seq_idxes,
615+ out_tokens,
616+ out_logprobs);
617+ // convert to proto output
618+ forward_output_to_proto (next_tokens,
619+ logprobs,
620+ top_tokens,
621+ top_logprobs,
622+ embeddings,
623+ expert_load_data,
624+ prepared_layer_id,
625+ src_seq_idxes,
626+ out_tokens,
627+ out_logprobs,
628+ pb_forward_output);
629+ COUNTER_ADD (worker_service_latency_seconds, timer.elapsed_seconds ());
630+ });
685631}
686632
687633void WorkerService::GetLastStepResult (
0 commit comments