
Commit ba62789 (parent: aad330a)

feat: revert the original code before refactoring the multi-stream[2/2].

Signed-off-by: Tao Peng <[email protected]>

38 files changed: 598 additions, 832 deletions
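
For orientation before the per-file diffs: this revert collapses the multi-stream (micro-batch) input plumbing back to a single-input path, replacing BatchedForwardInputs / proto::BatchedForwardInputs with plain ForwardInput / proto::ForwardInput across the RPC channel, the remote worker, the worker service, and the ACL graph executor. A reconstructed sketch of the two shapes involved, using only the fields visible in this diff (stub types stand in for the real xllm definitions, which carry more members):

    #include <torch/torch.h>
    #include <vector>

    struct SamplingParameters {  // stub
      void concat(const SamplingParameters& /*other*/) {}
    };
    struct ModelInputParams {  // stub
      int num_sequences = 0;
      int prefill_seq_len = 0;
    };

    // The single-input shape the code reverts to.
    struct ForwardInput {
      SamplingParameters sampling_params;
      torch::Tensor acc_logprob;  // accumulated logprobs for beam search
      ModelInputParams input_params;
    };

    // The multi-stream shape being removed: one ForwardInput per micro-batch
    // plus pre-concatenated views consumed by sampling and beam search.
    struct BatchedForwardInputs {
      std::vector<ForwardInput> micro_inputs;
      SamplingParameters concated_sampling_params;
      torch::Tensor acc_logprob;  // torch::cat of per-micro-batch tensors
    };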

xllm/core/distributed_runtime/comm_channel.cpp (6 additions, 14 deletions)

@@ -507,22 +507,14 @@ bool CommChannel::get_active_activation_memory_async(
 bool CommChannel::execute_model_with_brpc(
     const std::vector<RawForwardInput>& inputs,
     folly::Promise<std::optional<RawForwardOutput>>& promise) {
-  // convert to proto::BatchedForwardInputs
-  proto::BatchedForwardInputs pb_batched_fwd_inputs;
-  std::vector<proto::ForwardInput> batched_fwd_inputs_vec;
-  batched_fwd_inputs_vec.reserve(inputs.size());
-  for (auto i = 0; i < inputs.size(); ++i) {
-    proto::ForwardInput pb_fwd_input;
-    forward_input_to_proto(inputs[i], &pb_fwd_input);
-    batched_fwd_inputs_vec.push_back(std::move(pb_fwd_input));
-  }
-  ADD_VECTOR_TO_PROTO(pb_batched_fwd_inputs.mutable_micro_inputs(),
-                      batched_fwd_inputs_vec);
+  // convert to proto::ForwardInput
+  proto::ForwardInput pb_forward_input;
+  forward_input_to_proto(inputs[0], &pb_forward_input);
+
   // call ExecuteModel with callback
   auto done = new ExecuteModelClosure();
   done->promise = std::move(promise);
-  stub_->ExecuteModel(
-      &done->cntl, &pb_batched_fwd_inputs, &done->pb_output, done);
+  stub_->ExecuteModel(&done->cntl, &pb_forward_input, &done->pb_output, done);
   return true;
 }
@@ -567,4 +559,4 @@ void TransferBlocksClosure::Run() {
   return;
 }

-}  // namespace xllm
+}  // namespace xllm
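
The client now serializes only inputs[0]; any further elements would be silently dropped, so the implied contract is inputs.size() == 1 (see RemoteWorker::step_async below, which passes a one-element list). A self-contained sketch of that contract in generic form (illustrative names, not xllm APIs):

    #include <cassert>
    #include <vector>

    // The API still accepts a vector, but only the first element crosses the
    // wire; the revert therefore assumes exactly one micro-batch per RPC.
    template <typename T>
    const T& serialize_first(const std::vector<T>& inputs) {
      assert(inputs.size() == 1 && "revert assumes a single micro-batch");
      return inputs[0];
    }

    int main() {
      std::vector<int> one{42};
      assert(serialize_first(one) == 42);
      return 0;
    }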

xllm/core/distributed_runtime/comm_channel.h (1 addition, 1 deletion)

@@ -145,4 +145,4 @@ class TransferBlocksClosure : public google::protobuf::Closure {
   brpc::Controller cntl;
   folly::Promise<uint32_t> promise;
 };
-}  // namespace xllm
+}  // namespace xllm

xllm/core/distributed_runtime/remote_worker.cpp (6 additions, 5 deletions)

@@ -167,13 +167,14 @@ folly::SemiFuture<std::optional<ForwardOutput>> RemoteWorker::step_async(
 }

 folly::SemiFuture<std::optional<RawForwardOutput>> RemoteWorker::step_async(
-    const std::vector<RawForwardInput>& inputs) {
+    const RawForwardInput& inputs) {
   folly::Promise<std::optional<RawForwardOutput>> promise;
   auto future = promise.getSemiFuture();
-  threadpool_.schedule(
-      [this, inputs = inputs, promise = std::move(promise)]() mutable {
-        channel_->execute_model_async(inputs, promise);
-      });
+  threadpool_.schedule([this,
+                        inputs = std::move(inputs),
+                        promise = std::move(promise)]() mutable {
+    channel_->execute_model_async({inputs}, promise);
+  });

   return future;
 }
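
One subtlety in the new lambda: the parameter is now const RawForwardInput&, so the capture inputs = std::move(inputs) copies rather than moves (overload resolution sends a const rvalue to the copy constructor). A self-contained illustration of that C++ behavior:

    #include <string>
    #include <utility>

    void capture_demo(const std::string& s) {
      // std::move on a const reference yields const std::string&&, which
      // binds to the copy constructor: `captured` is a copy, not a move.
      auto lambda = [captured = std::move(s)] { return captured.size(); };
      (void)lambda;
    }

    int main() {
      capture_demo("hello");
      return 0;
    }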

xllm/core/distributed_runtime/remote_worker.h (1 addition, 1 deletion)

@@ -127,7 +127,7 @@ class RemoteWorker : public WorkerClient {
       const ForwardInput& inputs) override;

   virtual folly::SemiFuture<std::optional<RawForwardOutput>> step_async(
-      const std::vector<RawForwardInput>& inputs) override;
+      const RawForwardInput& inputs) override;

   virtual folly::SemiFuture<folly::Unit> process_group_test_async() override;

xllm/core/distributed_runtime/worker_service.cpp (60 additions, 114 deletions; file mode changed from 100755 to 100644)

@@ -66,7 +66,7 @@ void WorkerService::set_worker(std::unique_ptr<Worker> worker) {
   initialized_ = true;
 }

-void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
+void WorkerService::step(ForwardInput& fwd_input,
                          torch::Tensor& next_tokens,
                          torch::Tensor& logprobs,
                          torch::Tensor& top_tokens,
@@ -78,7 +78,7 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
                          torch::Tensor& out_tokens,
                          torch::Tensor& out_logprobs) {
   // execute model
-  auto future = worker_->step_async(batched_fwd_inputs);
+  auto future = worker_->step_async(fwd_input);

   if (!options_.enable_schedule_overlap()) {
     auto forward_outputs = std::move(future).get();
@@ -142,10 +142,10 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
       torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
   auto total_prefill_seq_len = 0;
   auto total_num_sequences = 0;
-  for (auto& input : batched_fwd_inputs.micro_inputs) {
-    total_num_sequences += input.input_params.num_sequences;
-    total_prefill_seq_len += input.input_params.prefill_seq_len;
-  }
+
+  total_num_sequences += fwd_input.input_params.num_sequences;
+  total_prefill_seq_len += fwd_input.input_params.prefill_seq_len;
+
   next_tokens =
       torch::arange(-1,
                     -1 * (total_num_sequences - total_prefill_seq_len + 1),
@@ -166,7 +166,7 @@ void WorkerService::create_polling_shm_thread(
        output_shm_manager = std::move(output_shm_manager)]() mutable {
        Timer timer;
        while (true) {
-         BatchedForwardInputs batched_fwd_inputs;
+         ForwardInput fwd_input;
          std::vector<ForwardInput> inputs;
          input_shm_manager->raw_input_read(inputs);
          timer.reset();
@@ -184,31 +184,9 @@ void WorkerService::create_polling_shm_thread(
          torch::Tensor out_tokens;
          torch::Tensor out_logprobs;

-         auto micro_batches_num = inputs.size();
-         batched_fwd_inputs.micro_inputs = std::move(inputs);
-         batched_fwd_inputs.concated_sampling_params =
-             batched_fwd_inputs.micro_inputs[0].sampling_params;
-         for (auto i = 1; i < micro_batches_num; ++i) {
-           batched_fwd_inputs.concated_sampling_params.concat(
-               batched_fwd_inputs.micro_inputs[i].sampling_params);
-         }
-
-         // concat acc_logprob here for beam search together
-         if (micro_batches_num > 1) {
-           std::vector<torch::Tensor> acc_logprob_vec;
-           acc_logprob_vec.reserve(micro_batches_num);
-           for (auto i = 0; i < micro_batches_num; ++i) {
-             acc_logprob_vec.push_back(
-                 batched_fwd_inputs.micro_inputs[i].acc_logprob);
-           }
-           batched_fwd_inputs.acc_logprob =
-               torch::cat(acc_logprob_vec, /*dim=*/-1);
-         } else {
-           batched_fwd_inputs.acc_logprob =
-               batched_fwd_inputs.micro_inputs[0].acc_logprob;
-         }
+         fwd_input = std::move(inputs[0]);

-         step(batched_fwd_inputs,
+         step(fwd_input,
               next_tokens,
               logprobs,
               top_tokens,
@@ -598,90 +576,58 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
   return;
 }

-void WorkerService::ExecuteModel(
-    ::google::protobuf::RpcController* controller,
-    const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
-    proto::ForwardOutput* pb_forward_output,
-    ::google::protobuf::Closure* done) {
-  threadpool_->schedule([this,
-                         controller,
-                         pb_batched_fwd_inputs,
-                         pb_forward_output,
-                         done]() mutable {
-    brpc::ClosureGuard done_guard(done);
-    Timer timer;
-    // convert proto::BatchedForwardInputs to BatchedForwardInputs
-    auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs().size();
-    BatchedForwardInputs batched_fwd_inputs;
-    batched_fwd_inputs.micro_inputs.reserve(micro_batches_num);
-    for (auto i = 0; i < micro_batches_num; ++i) {
-      ForwardInput forward_input;
-      proto_to_forward_input(&(pb_batched_fwd_inputs->micro_inputs()[i]),
-                             forward_input,
-                             options_.num_decoding_tokens());
-      batched_fwd_inputs.micro_inputs.push_back(std::move(forward_input));
-    }
-
-    // concat sampling parameters
-    batched_fwd_inputs.concated_sampling_params =
-        batched_fwd_inputs.micro_inputs[0].sampling_params;
-    for (auto i = 1; i < micro_batches_num; ++i) {
-      batched_fwd_inputs.concated_sampling_params.concat(
-          batched_fwd_inputs.micro_inputs[i].sampling_params);
-    }
-
-    // concat acc_logprob here for beam search together
-    if (micro_batches_num > 1) {
-      std::vector<torch::Tensor> acc_logprob_vec;
-      acc_logprob_vec.reserve(micro_batches_num);
-      for (auto i = 0; i < micro_batches_num; ++i) {
-        acc_logprob_vec.push_back(
-            batched_fwd_inputs.micro_inputs[i].acc_logprob);
-      }
-      batched_fwd_inputs.acc_logprob = torch::cat(acc_logprob_vec, /*dim=*/-1);
-    } else {
-      batched_fwd_inputs.acc_logprob =
-          batched_fwd_inputs.micro_inputs[0].acc_logprob;
-    }
+void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
+                                 const proto::ForwardInput* pb_forward_input,
+                                 proto::ForwardOutput* pb_forward_output,
+                                 ::google::protobuf::Closure* done) {
+  threadpool_->schedule(
+      [this, controller, pb_forward_input, pb_forward_output, done]() mutable {
+        brpc::ClosureGuard done_guard(done);
+        // convert proto::ForwardInput to ForwardInput

-    // model output
-    torch::Tensor next_tokens;
-    torch::Tensor logprobs;
-    torch::Tensor top_tokens;
-    torch::Tensor top_logprobs;
-    torch::Tensor embeddings;
-    torch::Tensor expert_load_data;
-    int32_t prepared_layer_id = -1;
-    // beam search kernel output
-    torch::Tensor src_seq_idxes;
-    torch::Tensor out_tokens;
-    torch::Tensor out_logprobs;
-
-    step(batched_fwd_inputs,
-         next_tokens,
-         logprobs,
-         top_tokens,
-         top_logprobs,
-         embeddings,
-         expert_load_data,
-         prepared_layer_id,
-         src_seq_idxes,
-         out_tokens,
-         out_logprobs);
-    // convert to proto output
-    forward_output_to_proto(next_tokens,
-                            logprobs,
-                            top_tokens,
-                            top_logprobs,
-                            embeddings,
-                            expert_load_data,
-                            prepared_layer_id,
-                            src_seq_idxes,
-                            out_tokens,
-                            out_logprobs,
-                            pb_forward_output);
-    COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
-  });
+        Timer timer;
+        ForwardInput forward_input;
+        proto_to_forward_input(
+            pb_forward_input, forward_input, options_.num_decoding_tokens());
+
+        // model output
+        torch::Tensor next_tokens;
+        torch::Tensor logprobs;
+        torch::Tensor top_tokens;
+        torch::Tensor top_logprobs;
+        torch::Tensor embeddings;
+        torch::Tensor expert_load_data;
+        int32_t prepared_layer_id = -1;
+        // beam search kernel output
+        torch::Tensor src_seq_idxes;
+        torch::Tensor out_tokens;
+        torch::Tensor out_logprobs;
+
+        step(forward_input,
+             next_tokens,
+             logprobs,
+             top_tokens,
+             top_logprobs,
+             embeddings,
+             expert_load_data,
+             prepared_layer_id,
+             src_seq_idxes,
+             out_tokens,
+             out_logprobs);
+        // convert to proto output
+        forward_output_to_proto(next_tokens,
+                                logprobs,
+                                top_tokens,
+                                top_logprobs,
+                                embeddings,
+                                expert_load_data,
+                                prepared_layer_id,
+                                src_seq_idxes,
+                                out_tokens,
+                                out_logprobs,
+                                pb_forward_output);
+        COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
+      });
 }

 void WorkerService::GetLastStepResult(
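
The deleted blocks in this file concatenated per-micro-batch sampling parameters and acc_logprob tensors before stepping the model. With a single micro-batch both operations reduce to a pass-through, which is why the revert can replace them with std::move(inputs[0]) on the shared-memory path and a direct proto conversion on the RPC path. A small libtorch check of the tensor half of that claim:

    #include <torch/torch.h>

    int main() {
      // torch::cat over a one-element list is the identity (same values,
      // same shape), so the deleted `else` branch (plain assignment) was
      // equivalent to the cat branch whenever micro_batches_num == 1.
      auto acc_logprob = torch::randn({4});
      auto concatenated = torch::cat({acc_logprob}, /*dim=*/-1);
      TORCH_CHECK(torch::equal(acc_logprob, concatenated));
      return 0;
    }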

xllm/core/distributed_runtime/worker_service.h (2 additions, 2 deletions)

@@ -111,7 +111,7 @@ class WorkerService : public proto::DistributeWorker {
                     ::google::protobuf::Closure* done) override;

   void ExecuteModel(::google::protobuf::RpcController* controller,
-                    const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
+                    const proto::ForwardInput* pb_fwd_input,
                     proto::ForwardOutput* pb_forward_output,
                     ::google::protobuf::Closure* done) override;
@@ -126,7 +126,7 @@ class WorkerService : public proto::DistributeWorker {
                     ::google::protobuf::Closure* done) override;

 private:
-  void step(BatchedForwardInputs& batched_fwd_inputs,
+  void step(ForwardInput& fwd_input,
             torch::Tensor& next_tokens,
             torch::Tensor& logprobs,
             torch::Tensor& top_tokens,

xllm/core/runtime/acl_graph_executor_impl.cpp (11 additions, 15 deletions)

@@ -187,15 +187,14 @@ ForwardInput AclGraphExecutorImpl::prepare_inputs(Batch& batch) {
 // tokens: [num_decode_tokens]
 // positions: [num_decode_tokens] token pos in the sequence
 // returns: [num_decode_tokens, hidden_size]
-torch::Tensor AclGraphExecutorImpl::run(
-    const std::vector<torch::Tensor>& tokens,
-    const std::vector<torch::Tensor>& positions,
-    std::vector<KVCache>& kv_caches,
-    const std::vector<ModelInputParams>& params) {
+torch::Tensor AclGraphExecutorImpl::run(const torch::Tensor& tokens,
+                                        const torch::Tensor& positions,
+                                        std::vector<KVCache>& kv_caches,
+                                        const ModelInputParams& params) {
   // no mirco batch in decode phase
-  const torch::Tensor& tokens_tensor = tokens[0];
-  const torch::Tensor& positions_tensor = positions[0];
-  const ModelInputParams& params_single = params[0];
+  const torch::Tensor& tokens_tensor = tokens;
+  const torch::Tensor& positions_tensor = positions;
+  const ModelInputParams& params_single = params;
   // Identify decode phase using q_max_seq_len for precise detection
   // Decode phase: all sequences have q_seq_len == 1 (generating one token at a
   // time) Prefill phase: sequences have q_seq_len > 1 (processing multiple
@@ -207,7 +206,7 @@ torch::Tensor AclGraphExecutorImpl::run(
   // If not in decode phase, use eager mode directly without acl graph
   if (!in_decoding_phase) {
     COUNTER_INC(num_model_execution_total_eager);
-    return model_->forward(tokens[0], positions[0], kv_caches, params[0]);
+    return model_->forward(tokens, positions, kv_caches, params);
   }

   // Only use acl graph in decode phase for performance optimization
@@ -229,15 +228,12 @@ torch::Tensor AclGraphExecutorImpl::run(

   // Combined condition for graph capture support
   // ACL graph executor only supports single tensor inputs (no micro-batching)
-  const bool single_input =
-      (tokens.size() == 1) && (positions.size() == 1) && (params.size() == 1);
-  const bool capture_supported =
-      single_input && seq_len_supported && same_num_decoding_tokens;
+  const bool capture_supported = seq_len_supported && same_num_decoding_tokens;

   // Early return if conditions are not suitable for graph operations
   if (!capture_supported) {
     COUNTER_INC(num_model_execution_total_eager);
-    return model_->forward(tokens[0], positions[0], kv_caches, params[0]);
+    return model_->forward(tokens, positions, kv_caches, params);
   }

   // Check if captured graph exists for this bucket size
@@ -273,7 +269,7 @@ torch::Tensor AclGraphExecutorImpl::run(
   // Fallback to eager mode if capture fails
   LOG(ERROR) << "Failed to capture ACL graph for bucket size: " << bucket_size;
   COUNTER_INC(num_model_execution_total_eager);
-  return model_->forward(tokens[0], positions[0], kv_caches, params[0]);
+  return model_->forward(tokens, positions, kv_caches, params);
 }

 void AclGraph::copy_data_to_graph_buffer(const torch::Tensor& tokens,
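
With vectors gone from the signature, the old runtime single_input guard becomes unnecessary: the type system now enforces what the deleted check asserted on every call. The general pattern, in a self-contained form (illustrative names, not xllm code):

    #include <cassert>
    #include <vector>

    // Before: the no-micro-batching invariant is checked at runtime.
    int run_batched(const std::vector<int>& tokens) {
      assert(tokens.size() == 1 && "no micro-batching supported");
      return tokens[0];
    }

    // After: the signature admits only one input, so both the check and
    // the [0] indexing disappear.
    int run_single(int token) { return token; }

    int main() {
      assert(run_batched({7}) == run_single(7));
      return 0;
    }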

xllm/core/runtime/acl_graph_executor_impl.h (4 additions, 4 deletions)

@@ -101,10 +101,10 @@ class AclGraphExecutorImpl : public ExecutorImpl {
   ForwardInput prepare_inputs(Batch& batch) override;

   // Execute model with graph optimization for decode phase
-  torch::Tensor run(const std::vector<torch::Tensor>& tokens,
-                    const std::vector<torch::Tensor>& positions,
+  torch::Tensor run(const torch::Tensor& tokens,
+                    const torch::Tensor& positions,
                     std::vector<KVCache>& kv_caches,
-                    const std::vector<ModelInputParams>& params) override;
+                    const ModelInputParams& params) override;

 private:
  // not own
@@ -123,4 +123,4 @@ class AclGraphExecutorImpl : public ExecutorImpl {
   uint32_t get_bucket_size(uint32_t batch_size) const;
 };

-}  // namespace xllm
+}  // namespace xllm
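
Since run() is declared override here, the matching virtual in the ExecutorImpl base (changed in one of the commit's other files, not shown in this excerpt) must have been reverted in lockstep; otherwise compilation fails. A minimal illustration of that coupling (illustrative types, not the xllm hierarchy):

    struct ExecutorBase {
      virtual ~ExecutorBase() = default;
      virtual int run(int token) = 0;  // base signature changes first...
    };

    struct GraphExecutor : ExecutorBase {
      // ...and `override` only compiles if this declaration matches it,
      // which forces the header and every implementer to move together.
      int run(int token) override { return token; }
    };

    int main() {
      GraphExecutor e;
      return e.run(0);
    }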
