Skip to content

Commit 7de5d18

Browse files
authored
bugfix: fix coredump issue when both prefix cache and MTP are enabled. (#377)

* bugfix: fix coredump issue when both prefix cache and MTP are enabled.
* bugfix: fix coredump caused by incorrect token replacement.
1 parent 3e07317 commit 7de5d18

File tree

5 files changed

+43
-19
lines changed

5 files changed

+43
-19
lines changed

xllm/core/framework/request/sequence_kv_state.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
5858
if (blocks.empty()) {
5959
return;
6060
}
61-
6261
// The number of matched blocks may be fewer than the number of blocks held by
6362
// the sequence itself. In this case, try to replace the blocks computed by
6463
// the sequence with blocks from the prefix_cache and release the computed
@@ -86,6 +85,10 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
8685
CHECK_GT(block_size, 0);
8786
num_shared_tokens =
8887
((current_total_num_tokens - 1) / block_size) * block_size;
88+
if (num_owned_shared_blocks_ > 0) {
89+
num_owned_shared_blocks_--;
90+
blocks_.pop_back();
91+
}
8992
}
9093
CHECK_LT(num_shared_tokens, current_total_num_tokens);
9194
// update the kv cache position

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(
182182
// should be in same prefill stage, so, to judge empty_kv_cache,
183183
// just use micro batch 0 here
184184
if (options_.enable_speculative_decode() && !is_spec_draft_) {
185-
if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) {
185+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
186186
output.sample_output.embeddings = hidden_states;
187187
} else if (concated_sampling_params.sample_idxes.defined()) {
188188
// auto sample_idxes =

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
173173
}
174174

175175
// TODO: support data parallel case
176-
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
176+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
177177
return step_prefill(inputs);
178178
} else {
179179
return step_decode(inputs);
@@ -182,7 +182,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
182182

183183
std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty(
184184
const BatchedForwardInputs& inputs) {
185-
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
185+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
186186
auto output = impl_->step(inputs);
187187
auto draft_output = draft_impl_->step(inputs);
188188
return output;
@@ -230,7 +230,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
230230
if (token_offset > 0) {
231231
prefill_inputs.micro_inputs[i].input_params.mm_data = MMData(
232232
MMType::EMBEDDING,
233-
{{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}});
233+
{{"embedding",
234+
embeddings.narrow(0, token_start_idx, token_offset).clone()}});
234235
}
235236
if (next_tokens.defined()) {
236237
auto& token_ids = prefill_inputs.micro_inputs[i].token_ids;
@@ -293,6 +294,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
293294
void SpeculativeWorkerImpl::prepare_prefill_inputs(
294295
const BatchedForwardInputs& inputs,
295296
BatchedForwardInputs& prefill_inputs) {
297+
prefill_inputs.micro_inputs.clear();
296298
prefill_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
297299
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
298300
auto& input = inputs.micro_inputs[i];
@@ -308,16 +310,16 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
308310
int32_t start_idx = 0;
309311
std::vector<int32_t> new_token_ids;
310312
new_token_ids.reserve(input.token_ids.numel());
311-
for (size_t i = 0; i < input_params.num_sequences; ++i) {
313+
for (size_t j = 0; j < input_params.num_sequences; ++j) {
312314
int32_t q_len = 0;
313-
q_len = input_params.q_seq_lens_vec[i];
315+
q_len = input_params.q_seq_lens_vec[j];
314316
Slice<int32_t> tokens_ids_slice_i =
315317
tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
316318
start_idx += q_len;
317319
new_token_ids.insert(new_token_ids.end(),
318320
tokens_ids_slice_i.begin(),
319321
tokens_ids_slice_i.end());
320-
new_token_ids.emplace_back(extra_token_ids[i]);
322+
new_token_ids.emplace_back(extra_token_ids[j]);
321323
}
322324
prefill_input.token_ids =
323325
torch::tensor(new_token_ids, prefill_input.positions.options());
@@ -359,7 +361,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
359361
// final step
360362
prepare_validate_inputs(inputs, validate_inputs, true);
361363
} else {
362-
prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
364+
if (i == 0) {
365+
prepare_draft_inputs(inputs, next_step_input, 1, device_);
366+
} else {
367+
prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
368+
}
363369
}
364370
draft_outputs.push_back(std::move(future).get().value());
365371
// update input of next step
@@ -368,8 +374,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
368374
auto last_output = draft_outputs.back().sample_output;
369375
auto start_idx = 0;
370376
auto token_start_idx = 0;
371-
for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) {
372-
auto& draft_input = draft_inputs.micro_inputs[i];
377+
for (auto j = 0; j < draft_inputs.micro_inputs.size(); ++j) {
378+
auto& draft_input = draft_inputs.micro_inputs[j];
373379
auto offset = draft_input.input_params.num_sequences;
374380
auto token_offset = draft_input.token_ids.size(0);
375381
draft_input.token_ids = safe_to(
@@ -379,6 +385,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
379385
MMType::EMBEDDING,
380386
{{"embedding",
381387
last_output.embeddings.narrow(0, token_start_idx, token_offset)
388+
.clone()
382389
.to(device_)}});
383390
}
384391
start_idx += offset;
@@ -394,9 +401,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
394401
auto next_tokens =
395402
safe_to(draft_output.sample_output.next_tokens, torch::kInt);
396403
int32_t start_idx = 0;
397-
for (auto i = 0; i < validate_inputs.micro_inputs.size(); ++i) {
398-
int32_t offset = draft_inputs.micro_inputs[i].input_params.num_sequences;
399-
auto& validate_input = validate_inputs.micro_inputs[i];
404+
for (auto j = 0; j < validate_inputs.micro_inputs.size(); ++j) {
405+
int32_t offset =
406+
validate_inputs.micro_inputs[j].input_params.num_sequences /
407+
(options_.num_speculative_tokens() + 1);
408+
auto& validate_input = validate_inputs.micro_inputs[j];
400409
auto& token_ids = validate_input.token_ids;
401410
auto mask = (token_ids == -1 * (i + 1));
402411
token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
@@ -447,9 +456,10 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(
447456
const int64_t offset,
448457
const torch::Device device) {
449458
// prepare input for MTP in decoding phase (Like Eagle).
459+
draft_inputs.micro_inputs.clear();
450460
draft_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
451-
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
452-
auto& input = inputs.micro_inputs[i];
461+
for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
462+
auto& input = inputs.micro_inputs[idx];
453463
ForwardInput draft_input = input.to(device, dtype_);
454464

455465
auto& input_params = draft_input.input_params;
@@ -504,8 +514,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
504514
BatchedForwardInputs& validate_inputs,
505515
bool enable_schedule_overlap) {
506516
validate_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
507-
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
508-
auto& input = inputs.micro_inputs[i];
517+
for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
518+
auto& input = inputs.micro_inputs[idx];
509519

510520
ForwardInput validate_input = input.to(device_, dtype_);
511521
auto& input_params = validate_input.input_params;
@@ -823,7 +833,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
823833
void SpeculativeWorkerImpl::prepare_work_before_execute(
824834
const BatchedForwardInputs& inputs,
825835
BatchedForwardInputs& processed_inputs) {
826-
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
836+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
827837
WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
828838
} else {
829839
if (enable_schedule_overlap()) {

xllm/core/runtime/worker_impl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,4 +759,13 @@ int64_t WorkerImpl::get_active_activation_memory() {
759759
.active_activation_memory;
760760
}
761761

762+
bool WorkerImpl::check_is_prefill(const std::vector<int>& q_seq_lens_vec) {
763+
for (auto q_len : q_seq_lens_vec) {
764+
if (q_len > 1) {
765+
return true;
766+
}
767+
}
768+
return false;
769+
}
770+
762771
} // namespace xllm

xllm/core/runtime/worker_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class WorkerImpl {
166166

167167
torch::ScalarType dtype() const { return dtype_; }
168168

169+
bool check_is_prefill(const std::vector<int>& q_seq_lens_vec);
170+
169171
int32_t hidden_size() const {
170172
return context_.get_model_args().hidden_size();
171173
}

0 commit comments

Comments (0)