305 changes: 266 additions & 39 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py

Large diffs are not rendered by default.

179 changes: 124 additions & 55 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -39,6 +39,7 @@
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.decoder_layer import DecoderLayer
from ..speculative.drafter import Drafter
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.speculation_gate import SpeculationGate
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
@@ -276,7 +277,7 @@ def __init__(self,
if self.dist.pp_size > 1:
self.event_loop = self._executor_loop_pp
else:
self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
self.event_loop = trace_func(self.event_loop)

@@ -1060,14 +1061,11 @@ def _prepare_and_schedule_batch(self):
0
] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

# When the overlap scheduler is enabled and we already prepared the draft tokens in the previous batch,
# we don't need to initialize py_draft_tokens at this stage because we haven't appended the accepted tokens to the request yet.
if not self.has_previous_draft_tokens:
# If speculation is off, this function sets py_draft_tokens to []
# for all active requests. If it's on, we initialize py_draft_tokens
# with dummy draft tokens to make the scheduler aware of the fact
# that speculation is about to happen.
self._prepare_draft_requests()
# If speculation is off, this function sets py_draft_tokens to []
# for all active requests. If it's on, we initialize py_draft_tokens
# with dummy draft tokens to make the scheduler aware of the fact
# that speculation is about to happen.
self._prepare_draft_requests()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
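
For context, the comment above refers to `_prepare_draft_requests`; below is a minimal sketch of its assumed behavior, reconstructed from the dummy-token initialization visible at the top of this hunk. It is illustrative only, not the actual TensorRT-LLM implementation.

```python
# Illustrative sketch of _prepare_draft_requests (assumed behavior,
# reconstructed from the hunk above; not the real implementation).
def _prepare_draft_requests(self):
    max_total_draft_tokens = (self.max_total_draft_tokens
                              if self.use_spec_decode else 0)
    for req in self.active_requests:
        # Seed each request with placeholder draft tokens so the scheduler
        # budgets KV-cache space for speculation; with speculation off,
        # py_draft_tokens is cleared instead.
        req.py_draft_tokens = [
            0
        ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
```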
@@ -1316,6 +1314,8 @@ def _executor_loop_overlap(self):
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
target_inputs = None
previous_tensors_device = None
can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
while True:
profile_step()
@@ -1396,31 +1396,29 @@ def _executor_loop_overlap(self):
self.guided_decoder.init_disagg_gen_requests()

previous_tensors = self.previous_batch and self.previous_batch.sample_state
target_inputs = None
draft_outputs = None
# If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
# When there are any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
# so we set the target model's input to None and skip updating the target requests after the target model forward.
use_previous_draft_tokens = self.has_previous_draft_tokens
if self.drafter is not None and (self.use_spec_decode or
use_previous_draft_tokens):
target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
scheduled_batch, previous_tensors)
target_inputs = self._handle_speculative_decoding(
scheduled_batch, previous_tensors,
previous_tensors_device)

# Use the draft_model's outputs if we've launched the draft model.
# Otherwise, use the previous batch's outputs.
if target_inputs is not None or use_previous_draft_tokens:
if (target_inputs is not None
and target_inputs.next_draft_tokens
is not None) or use_previous_draft_tokens:
previous_tensors_device = target_inputs
else:
previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device

batch_outputs = self._forward_step(scheduled_batch,
previous_tensors_device)

if target_inputs is not None:
self._process_draft_results(scheduled_batch,
draft_outputs, draft_batch)
elif self.previous_batch is not None and not use_previous_draft_tokens:
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)

if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -1435,6 +1433,10 @@
(req, block_id,
self.ctx_in_transmission_counter))

if self.drafter is not None and self.use_spec_decode:
# Cleanup previous draft resources used in the draft model
self.drafter.cleanup_previous_draft_resources()

if self.guided_decoder is not None:
# add_batch must be called again so it sees the updated new tokens.
self.guided_decoder.add_batch(scheduled_batch)
@@ -1469,6 +1471,94 @@

self._kv_connector_terminate_requests()
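
For reference, `_accept_draft_tokens` below constructs a `SampleStateTensorsMTP`. The class itself is imported from `..speculative.mtp`; the sketch here only illustrates the fields this diff relies on, with the assumed tensor layouts flattened into one dataclass (not the real definition).

```python
from dataclasses import dataclass
from typing import Optional
import torch

# Assumed field layout of SampleStateTensorsMTP (illustrative sketch;
# the real class lives in tensorrt_llm/_torch/speculative/mtp.py).
@dataclass
class SampleStateTensorsMTP:
    new_tokens: torch.Tensor                          # [max_draft_len + 1, batch, beam]
    log_probs: Optional[torch.Tensor] = None
    new_tokens_lens: Optional[torch.Tensor] = None    # accepted count per request
    next_draft_tokens: Optional[torch.Tensor] = None  # [batch, max_draft_len]
    logits: Optional[torch.Tensor] = None
```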

def _accept_draft_tokens(
self, scheduled_batch: ScheduledRequests,
target_outputs: SampleStateTensors,
target_inputs: Optional[SampleStateTensors]
) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
"""
Prepare target device inputs after computing draft token acceptance.

This function:
1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
2. If no draft tokens: directly uses the first sampled token
3. Creates new_tokens by extracting accepted tokens per request

Args:
scheduled_batch: The scheduled requests
target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width]
or [1, batch_size, beam_width] if no draft tokens
target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
Returns:
Tuple of:
- SampleStateTensorsMTP with new_tokens set to accepted tokens,
new_tokens_lens and next_draft_tokens set to None
- num_accepted_tokens: [batch_size] tensor with acceptance counts per request,
or None if no draft tokens
"""
has_draft_tokens = target_inputs is not None and isinstance(
target_inputs, SampleStateTensorsMTP
) and target_inputs.next_draft_tokens is not None
target_tokens = target_outputs.new_tokens # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
new_tokens = torch.zeros_like(target_tokens)

# Squeeze the beam dimension (beam_width=1 for greedy or single beam)
target_tokens = target_tokens.squeeze(
-1) # [max_draft_len + 1, batch_size] or [1, batch_size]

batch_size = target_tokens.shape[1]
device = target_tokens.device
# Compute number of accepted tokens per request
num_accepted_tokens = torch.zeros(batch_size,
dtype=torch.int32,
device=device)

if has_draft_tokens:
# Draft tokens exist, compute acceptance
draft_tokens = target_inputs.next_draft_tokens # [batch_size, max_draft_len]
max_draft_len = draft_tokens.shape[1]

# Compute number of accepted tokens per request
# Generation requests: compare with draft tokens to find acceptance
num_contexts = len(scheduled_batch.context_requests)
if batch_size > num_contexts:
# Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
gen_target_tokens = target_tokens[:,
num_contexts:].T # [num_gens, max_draft_len + 1]

# Compare draft tokens with target tokens to find acceptance
# Use cumprod to find the first rejection point
draft_tokens_gen = draft_tokens[
num_contexts:, :].int() # [num_gens, max_draft_len]
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
).int(),
dim=-1).sum(dim=1)

# Vectorized extraction using advanced indexing (no GPU-CPU sync)
# Use num_accepted_tokens as indices to gather the right tokens
batch_indices = torch.arange(batch_size, device=device)
new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
batch_indices]
else:
# No draft tokens to accept, just use the first (and only) sampled token
batch_indices = torch.arange(batch_size, device=device)
new_tokens[0, :, 0] = target_tokens[0, batch_indices]

# Create the updated SampleStateTensorsMTP
# new_tokens_lens and next_draft_tokens are left as None
result_tensors = SampleStateTensorsMTP(
new_tokens=new_tokens,
log_probs=target_outputs.log_probs,
new_tokens_lens=None,
next_draft_tokens=None)

# Copy logits if available
if hasattr(target_outputs, 'logits'):
result_tensors.logits = target_outputs.logits

return result_tensors, num_accepted_tokens
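
As a standalone illustration, here is the cumprod acceptance rule from `_accept_draft_tokens` run on toy tensors. This is a minimal sketch: the layout is transposed to [batch, len] for readability (the method above works on [max_draft_len + 1, batch]), and all values are hypothetical.

```python
import torch

# Toy demonstration of the cumprod-based acceptance count.
draft_tokens = torch.tensor([[5, 7, 9],    # request 0's draft tokens
                             [4, 4, 4]])   # request 1's draft tokens
target_tokens = torch.tensor([[5, 7, 2, 8],    # target matches 5, 7 then diverges
                              [1, 0, 0, 3]])   # target rejects immediately

# cumprod zeroes every position after the first mismatch, so the sum is
# the length of the longest matching prefix, i.e. the accepted count.
matches = (draft_tokens == target_tokens[:, :3]).int()
num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)   # tensor([2, 0])

# The token to carry forward is the target's sample right after the
# accepted prefix, gathered with advanced indexing (no GPU-CPU sync).
batch_indices = torch.arange(target_tokens.shape[0])
new_tokens = target_tokens[batch_indices, num_accepted]    # tensor([2, 1])
print(num_accepted, new_tokens)
```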

def _process_previous_batch(self):
if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
for req in self.previous_batch.ctx_transmission_reqs:
@@ -2365,7 +2455,8 @@ def _remove_inflight_ids(self, scheduled_requests):
for req in scheduled_requests.all_requests():
self.inflight_req_ids.erase(req.request_id)

def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
target_inputs):
with request_context(is_draft=self.draft_model_engine is not None,
scheduled_requests=scheduled_batch):
# Do an early check to see if we need to forward the draft model.
Expand All @@ -2375,20 +2466,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
self.previous_batch is not None and self.use_spec_decode
and self.drafter.should_forward_draft_model(scheduled_batch))

if has_draft_batch or self.has_previous_draft_tokens:
self._update_requests(self.previous_batch.sample_state)
if self.has_previous_draft_tokens:
self._prepare_draft_requests()
new_target_inputs = None
if has_draft_batch:
target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
assert target_outputs is not None, "target_outputs should not be None"
new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
scheduled_batch=scheduled_batch,
target_inputs=target_inputs,
target_outputs=target_outputs)

if has_draft_batch:
target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
self.drafter.generate_draft_tokens_with_overlap(
scheduled_batch, self.resource_manager,
previous_tensors.device if previous_tensors else None)
previous_tensors.device if previous_tensors else None,
new_target_inputs, num_accepted_tokens_device)

self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
# Pad draft tokens to the max draft length for CUDA graph compatibility
self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
else:
self.has_previous_draft_tokens = False
target_inputs, draft_outputs, draft_batch = None, None, None
# We are not running the draft model. Remove the draft tokens and turn off spec
# decode so that the requests get handled correctly.
# One corner case: when we have at least one context request, we have to keep spec
@@ -2401,34 +2497,7 @@
for request in scheduled_batch.all_requests():
request.py_draft_tokens = []

return target_inputs, draft_outputs, draft_batch

def _process_draft_results(self, scheduled_batch, draft_outputs,
draft_batch):
"""
Append the draft tokens to the target requests, and clean up the draft resources.
"""
with request_context(is_draft=self.draft_model_engine is not None,
scheduled_requests=scheduled_batch):
req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_batch.all_requests()
}

if self.drafter.use_static_draft_loop:
self.drafter.process_static_draft_outputs(
draft_outputs, draft_batch, req_id_to_old_request)
elif draft_outputs is not None:
self.drafter.process_dynamic_draft_outputs(
draft_outputs, req_id_to_old_request)

# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
# add_batch must be called again to restore the target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()
return new_target_inputs

def reset_prefix_cache(self):
self.kv_cache_manager.reset_reuse_state()
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -25,6 +25,7 @@
from tensorrt_llm.quantization import QuantAlgo

from ..attention_backend.interface import AttentionRuntimeFeatures
from ..attention_backend.trtllm import TrtllmAttention
from ..distributed import MPIDist, TorchDist
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
@@ -390,6 +391,16 @@ def drafting_loop_wrapper(model):
else:
draft_model_engine = None

# TODO: The overlap scheduler is not supported in the cases below:
# 1. non-CDL is used
# 2. a non-TrtllmAttention attention backend is used
if has_draft_model_engine and (not use_chain_drafter or not issubclass(
draft_model_engine.attn_backend, TrtllmAttention)):
logger.warning(
"Overlap scheduler is not supported for non-CDL or non-TrtllmAttention backend."
)
llm_args.disable_overlap_scheduler = True

# PyTorchModelEngine modifies these fields, update them
model_engine_max_seq_len = model_engine.max_seq_len
net_max_seq_len = model_engine_max_seq_len
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -280,7 +280,9 @@ def _group_requests_by_strategy_key(
)
for req_index, req in enumerate(requests):
strategy = _request_strategy(req, vocab_size=vocab_size)
speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
# In the overlap path, py_draft_logits has not been updated yet,
# so we use get_draft_token_length() for this check.
speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY
strategy_key = strategy_to_key(strategy, speculation_needs_probs)
group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
group_dict_entry[0].append(req_index)
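
A hedged sketch of what `get_draft_token_length()` is assumed to return (illustrative; the real helper ships with the pyexecutor request utilities): it keys off `py_draft_tokens`, which the overlap path keeps populated even before `py_draft_logits` is refreshed.

```python
# Assumed behavior of get_draft_token_length (sketch, not the actual helper).
def get_draft_token_length(request) -> int:
    # py_draft_tokens survives across overlap iterations, unlike
    # py_draft_logits, so its length is the reliable speculation signal.
    if request.py_draft_tokens is None:
        return 0
    return len(request.py_draft_tokens)
```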