10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -197,6 +197,16 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
"properly passed through.",
)
enable_iter_perf_stats: bool = Field(
default=False, description="Enable iteration performance statistics.", status="prototype"
)

enable_iter_req_stats: bool = Field(
default=False,
description="If true, enables per request stats per iteration. Must also set "
"enable_iter_perf_stats to true to get request stats.",
status="prototype",
)

### VALIDATION #################################################################################
@model_validator(mode="after")
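The two new fields mirror the iteration-stats switches already exposed by the PyTorch backend. A minimal usage sketch, assuming the flags are passed straight through as keyword arguments to the AutoDeploy `LLM` entry point (the model path is a placeholder):

```python
# Sketch only: assumes the new AutoDeployConfig fields can be set as LLM kwargs.
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM

llm = AutoDeployLLM(
    model="/path/to/model",        # placeholder model path
    enable_iter_perf_stats=True,   # per-iteration performance stats
    enable_iter_req_stats=True,    # per-request stats; needs perf stats enabled too
)
```

As the field description notes, `enable_iter_req_stats` only takes effect when `enable_iter_perf_stats` is also set to true.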
34 changes: 30 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -10,6 +10,7 @@
# limitations under the License.

from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple

@@ -46,6 +47,13 @@
from .interface import CachedSequenceInterface, GetInferenceModel


@dataclass
class ReportingInfo:
print_log: bool = False
enable_iter_perf_stats: bool = False
enable_iter_req_stats: bool = False


class _CacheManagerWithFakePool(KVCacheManager):
"""We use the default KVCacheManager but with a fake pool by setting head_dim=0.

@@ -123,14 +131,19 @@ def build_from_config(cls, ad_config: LlmArgs):
vocab_size_padded=factory.vocab_size_padded,
chunk_size=factory.chunk_size,
)
reporting_info = ReportingInfo(
print_log=False,
enable_iter_perf_stats=ad_config.enable_iter_perf_stats,
enable_iter_req_stats=ad_config.enable_iter_req_stats,
)
# TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.

# construct inference optimizer
build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)

# construct engine
return cls(build_and_optimize, seq_info, device, max_beam_width)
return cls(build_and_optimize, seq_info, device, max_beam_width, reporting_info)

@torch.inference_mode()
def __init__(
@@ -139,20 +152,23 @@ def __init__(
seq_info: SequenceInfo,
device: DeviceLikeType,
max_beam_width: int = 1,
reporting_info: ReportingInfo = ReportingInfo(),
) -> None:
"""Initialize the engine with model and sequence information."""
# NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
# This is not correctly declared in the base ModelEngine class though...
self.llm_args = SimpleNamespace()
self.llm_args.print_iter_log = False
self.llm_args.enable_iter_perf_stats = False
self.llm_args.enable_iter_req_stats = False
self.llm_args.print_iter_log = reporting_info.print_log
self.llm_args.enable_iter_perf_stats = reporting_info.enable_iter_perf_stats
self.llm_args.enable_iter_req_stats = reporting_info.enable_iter_req_stats
self.llm_args.stream_interval = 1
self.llm_args.attention_dp_config = None
self.llm_args.batch_wait_timeout_ms = 0
self.llm_args.batch_wait_timeout_iters = 0
self.llm_args.batch_wait_max_tokens_ratio = 0.0
self.llm_args.max_num_tokens = seq_info.max_num_tokens
self.iter_counter = 0
self.iter_states = {}

# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
self.max_beam_width = max_beam_width
@@ -196,6 +212,9 @@ def _prepare_inputs(
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)

dummy_token = -1
num_ctx_requests = len(context_requests)
num_ctx_tokens = 0
num_generation_tokens = 0

# look at context requests first
for request in context_requests:
@@ -206,6 +225,7 @@
begin_compute = request.context_current_position
end_compute = begin_compute + request.context_chunk_size
prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
num_ctx_tokens += len(prompt_tokens)

input_ids.append(prompt_tokens)
input_pos.append(begin_compute)
@@ -238,6 +258,7 @@
input_pos.append(request.max_beam_num_tokens)
flat_gather_idx.append(request.py_batch_idx)

num_generation_tokens += 1
request.py_batch_idx = request.seq_slot

# store seq slot idx
@@ -267,6 +288,10 @@
scatter_ref=dummy_token,
)

self.iter_states["num_ctx_requests"] = num_ctx_requests
self.iter_states["num_ctx_tokens"] = num_ctx_tokens
# TODO: handle extend requests and draft requests for specdec
self.iter_states["num_generation_tokens"] = num_generation_tokens
return last_logit_only

@nvtx_range("ad_compute_logits")
@@ -294,6 +319,7 @@ def forward(
# convert requests and store in sequence info object
new_tokens = getattr(new_tensors_device, "new_tokens", None)
last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
self.iter_counter += 1

# compute all logits
logits = self._compute_logits()
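Condensed sketch of the plumbing this file adds: the config flags travel via `ReportingInfo` into the fake `llm_args` namespace that PyExecutor reads, and `_prepare_inputs` records per-iteration counters in `iter_states`. The class below is a simplified stand-in for ADEngine, not the real implementation:

```python
# Simplified stand-in for ADEngine, keeping only the reporting plumbing.
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class ReportingInfo:
    print_log: bool = False
    enable_iter_perf_stats: bool = False
    enable_iter_req_stats: bool = False


class MiniEngine:
    def __init__(self, reporting_info: ReportingInfo = ReportingInfo()) -> None:
        # PyExecutor reads these attributes off llm_args when building stats.
        self.llm_args = SimpleNamespace()
        self.llm_args.print_iter_log = reporting_info.print_log
        self.llm_args.enable_iter_perf_stats = reporting_info.enable_iter_perf_stats
        self.llm_args.enable_iter_req_stats = reporting_info.enable_iter_req_stats
        self.iter_counter = 0
        self.iter_states = {}

    def _prepare_inputs(self, context_token_lists, num_generation_requests):
        # Counters consumed when assembling per-iteration stats.
        self.iter_states["num_ctx_requests"] = len(context_token_lists)
        self.iter_states["num_ctx_tokens"] = sum(len(t) for t in context_token_lists)
        self.iter_states["num_generation_tokens"] = num_generation_requests

    def forward(self, context_token_lists, num_generation_requests):
        self._prepare_inputs(context_token_lists, num_generation_requests)
        self.iter_counter += 1


# One prefill request with 4 prompt tokens plus two decode requests.
engine = MiniEngine(ReportingInfo(enable_iter_perf_stats=True, enable_iter_req_stats=True))
engine.forward([[1, 2, 3, 4]], num_generation_requests=2)
print(engine.iter_counter, engine.iter_states)
# 1 {'num_ctx_requests': 1, 'num_ctx_tokens': 4, 'num_generation_tokens': 2}
```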
6 changes: 3 additions & 3 deletions tensorrt_llm/bench/benchmark/__init__.py
@@ -107,12 +107,12 @@ def get_llm(runtime_config: RuntimeConfig, kwargs: dict):
if runtime_config.backend != None:
ignore_trt_only_args(kwargs, runtime_config.backend)

if runtime_config.iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True

if runtime_config.backend == 'pytorch':
llm_cls = PyTorchLLM

if runtime_config.iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True

elif runtime_config.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM

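Moving the `iteration_log` check above the backend branch is what lets `_autodeploy` benchmark runs pick up iteration perf stats as well. A simplified sketch of the resulting flow (illustrative only; the real `get_llm` also filters TRT-only args and constructs the LLM object):

```python
# Illustrative reduction of the control flow after this change.
from typing import Optional


def resolve_stats_kwargs(backend: str, iteration_log: Optional[str]) -> dict:
    kwargs: dict = {}
    # Applied before the backend branch, so both "pytorch" and "_autodeploy"
    # enable per-iteration perf stats when an iteration log file is requested.
    if iteration_log is not None:
        kwargs["enable_iter_perf_stats"] = True
    if backend == "pytorch":
        pass  # previously the flag was set only inside this branch
    elif backend == "_autodeploy":
        pass
    return kwargs


assert resolve_stats_kwargs("_autodeploy", "iteration_log.log") == {"enable_iter_perf_stats": True}
```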
@@ -31,13 +31,20 @@ def run_benchmark(
"_autodeploy",
"--dataset",
dataset_path,
"--iteration_log",
"iteration_log.log",
"--extra_llm_api_options",
f"{extra_llm_api_options_path}",
]
)
result = runner.invoke(main, args, catch_exceptions=False)
assert result.exit_code == 0

with open("iteration_log.log", "r") as f:
lines = f.readlines()
assert len(lines) > 0
# TODO: add more checks


def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str):
_DATASET_NAME = "synthetic_128_128.txt"
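The new assertion only verifies that the iteration log is non-empty. A possible stricter check is sketched below; it parses defensively because the test does not pin down the log's line format (JSON per line is an assumption, not something the diff guarantees):

```python
# Sketch of a stricter check; the exact iteration-log format is assumed.
import json


def count_parsable_iterations(path: str = "iteration_log.log") -> int:
    parsed = 0
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                json.loads(line)
                parsed += 1
            except json.JSONDecodeError:
                continue  # tolerate non-JSON lines instead of failing
    return parsed
```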