@@ -173,7 +173,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
   }
 
   // TODO: support data parallel case
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
     return step_prefill(inputs);
   } else {
     return step_decode(inputs);
@@ -182,7 +182,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
 
 std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty(
     const BatchedForwardInputs& inputs) {
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
     auto output = impl_->step(inputs);
     auto draft_output = draft_impl_->step(inputs);
     return output;
@@ -230,7 +230,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
     if (token_offset > 0) {
       prefill_inputs.micro_inputs[i].input_params.mm_data = MMData(
           MMType::EMBEDDING,
-          {{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}});
+          {{"embedding",
+            embeddings.narrow(0, token_start_idx, token_offset).clone()}});
     }
     if (next_tokens.defined()) {
       auto& token_ids = prefill_inputs.micro_inputs[i].token_ids;
@@ -293,6 +294,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
 void SpeculativeWorkerImpl::prepare_prefill_inputs(
     const BatchedForwardInputs& inputs,
     BatchedForwardInputs& prefill_inputs) {
+  prefill_inputs.micro_inputs.clear();
   prefill_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
   for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
     auto& input = inputs.micro_inputs[i];
@@ -308,16 +310,16 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
     int32_t start_idx = 0;
     std::vector<int32_t> new_token_ids;
     new_token_ids.reserve(input.token_ids.numel());
-    for (size_t i = 0; i < input_params.num_sequences; ++i) {
+    for (size_t j = 0; j < input_params.num_sequences; ++j) {
       int32_t q_len = 0;
-      q_len = input_params.q_seq_lens_vec[i];
+      q_len = input_params.q_seq_lens_vec[j];
       Slice<int32_t> tokens_ids_slice_i =
           tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
       start_idx += q_len;
       new_token_ids.insert(new_token_ids.end(),
                            tokens_ids_slice_i.begin(),
                            tokens_ids_slice_i.end());
-      new_token_ids.emplace_back(extra_token_ids[i]);
+      new_token_ids.emplace_back(extra_token_ids[j]);
     }
     prefill_input.token_ids =
         torch::tensor(new_token_ids, prefill_input.positions.options());
@@ -359,7 +361,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
       // final step
       prepare_validate_inputs(inputs, validate_inputs, true);
     } else {
-      prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
+      if (i == 0) {
+        prepare_draft_inputs(inputs, next_step_input, 1, device_);
+      } else {
+        prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
+      }
     }
     draft_outputs.push_back(std::move(future).get().value());
     // update input of next step
@@ -368,8 +374,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
     auto last_output = draft_outputs.back().sample_output;
     auto start_idx = 0;
     auto token_start_idx = 0;
-    for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) {
-      auto& draft_input = draft_inputs.micro_inputs[i];
+    for (auto j = 0; j < draft_inputs.micro_inputs.size(); ++j) {
+      auto& draft_input = draft_inputs.micro_inputs[j];
       auto offset = draft_input.input_params.num_sequences;
       auto token_offset = draft_input.token_ids.size(0);
       draft_input.token_ids = safe_to(
@@ -379,6 +385,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
             MMType::EMBEDDING,
             {{"embedding",
               last_output.embeddings.narrow(0, token_start_idx, token_offset)
+                  .clone()
                   .to(device_)}});
       }
       start_idx += offset;
@@ -394,9 +401,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
     auto next_tokens =
         safe_to(draft_output.sample_output.next_tokens, torch::kInt);
     int32_t start_idx = 0;
-    for (auto i = 0; i < validate_inputs.micro_inputs.size(); ++i) {
-      int32_t offset = draft_inputs.micro_inputs[i].input_params.num_sequences;
-      auto& validate_input = validate_inputs.micro_inputs[i];
+    for (auto j = 0; j < validate_inputs.micro_inputs.size(); ++j) {
+      int32_t offset =
+          validate_inputs.micro_inputs[j].input_params.num_sequences /
+          (options_.num_speculative_tokens() + 1);
+      auto& validate_input = validate_inputs.micro_inputs[j];
       auto& token_ids = validate_input.token_ids;
       auto mask = (token_ids == -1 * (i + 1));
       token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
@@ -447,9 +456,10 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(
     const int64_t offset,
     const torch::Device device) {
   // prepare input for MTP in decoding phase (Like Eagle).
+  draft_inputs.micro_inputs.clear();
   draft_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
-  for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
-    auto& input = inputs.micro_inputs[i];
+  for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
+    auto& input = inputs.micro_inputs[idx];
     ForwardInput draft_input = input.to(device, dtype_);
 
     auto& input_params = draft_input.input_params;
@@ -504,8 +514,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
     BatchedForwardInputs& validate_inputs,
     bool enable_schedule_overlap) {
   validate_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
-  for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
-    auto& input = inputs.micro_inputs[i];
+  for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
+    auto& input = inputs.micro_inputs[idx];
 
     ForwardInput validate_input = input.to(device_, dtype_);
     auto& input_params = validate_input.input_params;
@@ -823,7 +833,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
 void SpeculativeWorkerImpl::prepare_work_before_execute(
     const BatchedForwardInputs& inputs,
     BatchedForwardInputs& processed_inputs) {
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
     WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
   } else {
     if (enable_schedule_overlap()) {
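The three call sites above swap the old `q_seq_lens_vec[0] > 1` test for `check_is_prefill(...)`, whose definition is not shown in these hunks. As a rough illustration only, a helper consistent with those call sites might look like the sketch below; the name comes from the diff, but the body is an assumption (treat the batch as prefill when any sequence has a query length greater than one, rather than checking only the first sequence).

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: the real check_is_prefill() lives elsewhere in this
// change. Decode steps feed exactly one query token per sequence, so any
// sequence with q_len > 1 would mark the batch as a prefill batch.
inline bool check_is_prefill(const std::vector<int32_t>& q_seq_lens_vec) {
  for (const int32_t q_len : q_seq_lens_vec) {
    if (q_len > 1) {
      return true;
    }
  }
  return false;
}
```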
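The two added `.clone()` calls relate to the fact that `torch::Tensor::narrow()` returns a view sharing storage with the source tensor, so stashing the bare view in `MMData` leaves it exposed to later reuse or mutation of the embeddings buffer. The standalone snippet below only demonstrates that libtorch behavior; it is not code from this change.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto embeddings = torch::ones({8, 4});

  // narrow() returns a view that aliases the storage of `embeddings`.
  auto view = embeddings.narrow(/*dim=*/0, /*start=*/2, /*length=*/3);
  // clone() materializes an independent copy with its own storage.
  auto copy = view.clone();

  embeddings.zero_();  // in-place update of the source tensor

  // The view now reads zeros, while the clone keeps the original values.
  std::cout << view.sum().item<float>() << std::endl;  // 0
  std::cout << copy.sum().item<float>() << std::endl;  // 12
  return 0;
}
```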