From 25b5932ea39f3a53e2d59ae8749cf41db12db36b Mon Sep 17 00:00:00 2001
From: Shreyas Misra
Date: Tue, 18 Nov 2025 13:25:57 -0800
Subject: [PATCH] feat: enable iter stats in autodeploy

Signed-off-by: Shreyas Misra
---
 tensorrt_llm/_torch/auto_deploy/llm_args.py        | 10 ++++++
 .../_torch/auto_deploy/shim/ad_executor.py         | 34 ++++++++++++++++---
 tensorrt_llm/bench/benchmark/__init__.py           |  6 ++--
 .../unit/singlegpu/test_ad_trtllm_bench.py         |  7 ++++
 4 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py
index efa8a4c367f..ce08d0fcc44 100644
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -197,6 +197,16 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
         "properly passed through.",
     )
+    enable_iter_perf_stats: bool = Field(
+        default=False, description="Enable iteration performance statistics.", status="prototype"
+    )
+
+    enable_iter_req_stats: bool = Field(
+        default=False,
+        description="If true, enables per request stats per iteration. Must also set "
+        "enable_iter_perf_stats to true to get request stats.",
+        status="prototype",
+    )
 
     ### VALIDATION #################################################################################
 
     @model_validator(mode="after")
diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index f818cb76bce..8ea1af33b3d 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 from collections import defaultdict
+from dataclasses import dataclass
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple
 
@@ -46,6 +47,13 @@
 from .interface import CachedSequenceInterface, GetInferenceModel
 
 
+@dataclass
+class ReportingInfo:
+    print_log: bool = False
+    enable_iter_perf_stats: bool = False
+    enable_iter_req_stats: bool = False
+
+
 class _CacheManagerWithFakePool(KVCacheManager):
     """We use the default KVCacheManager but with a fake pool by setting head_dim=0.
 
@@ -123,6 +131,11 @@ def build_from_config(cls, ad_config: LlmArgs):
             vocab_size_padded=factory.vocab_size_padded,
             chunk_size=factory.chunk_size,
         )
+        reporting_info = ReportingInfo(
+            print_log=False,
+            enable_iter_perf_stats=ad_config.enable_iter_perf_stats,
+            enable_iter_req_stats=ad_config.enable_iter_req_stats,
+        )
 
         # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
         # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
@@ -130,7 +143,7 @@ def build_from_config(cls, ad_config: LlmArgs):
         build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)
 
         # construct engine
-        return cls(build_and_optimize, seq_info, device, max_beam_width)
+        return cls(build_and_optimize, seq_info, device, max_beam_width, reporting_info)
 
     @torch.inference_mode()
     def __init__(
@@ -139,20 +152,23 @@ def __init__(
         self,
         seq_info: SequenceInfo,
         device: DeviceLikeType,
         max_beam_width: int = 1,
+        reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
         # NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
         # This is not correctly declared in the base ModelEngine class though...
         self.llm_args = SimpleNamespace()
-        self.llm_args.print_iter_log = False
-        self.llm_args.enable_iter_perf_stats = False
-        self.llm_args.enable_iter_req_stats = False
+        self.llm_args.print_iter_log = reporting_info.print_log
+        self.llm_args.enable_iter_perf_stats = reporting_info.enable_iter_perf_stats
+        self.llm_args.enable_iter_req_stats = reporting_info.enable_iter_req_stats
         self.llm_args.stream_interval = 1
         self.llm_args.attention_dp_config = None
         self.llm_args.batch_wait_timeout_ms = 0
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.iter_counter = 0
+        self.iter_states = {}
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width
@@ -196,6 +212,9 @@ def _prepare_inputs(
         extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)
         dummy_token = -1
 
+        num_ctx_requests = len(context_requests)
+        num_ctx_tokens = 0
+        num_generation_tokens = 0
 
         # look at context requests first
         for request in context_requests:
@@ -206,6 +225,7 @@
             begin_compute = request.context_current_position
             end_compute = begin_compute + request.context_chunk_size
             prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
+            num_ctx_tokens += len(prompt_tokens)
             input_ids.append(prompt_tokens)
             input_pos.append(begin_compute)
 
@@ -238,6 +258,7 @@
             input_pos.append(request.max_beam_num_tokens)
 
             flat_gather_idx.append(request.py_batch_idx)
+            num_generation_tokens += 1
 
             request.py_batch_idx = request.seq_slot  # store seq slot idx
 
@@ -267,6 +288,10 @@
             scatter_ref=dummy_token,
         )
 
+        self.iter_states["num_ctx_requests"] = num_ctx_requests
+        self.iter_states["num_ctx_tokens"] = num_ctx_tokens
+        # TODO: handle extend requests and draft requests for specdec
+        self.iter_states["num_generation_tokens"] = num_generation_tokens
         return last_logit_only
 
     @nvtx_range("ad_compute_logits")
@@ -294,6 +319,7 @@ def forward(
         # convert requests and store in sequence info object
         new_tokens = getattr(new_tensors_device, "new_tokens", None)
         last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
+        self.iter_counter += 1
 
         # compute all logits
         logits = self._compute_logits()
diff --git a/tensorrt_llm/bench/benchmark/__init__.py b/tensorrt_llm/bench/benchmark/__init__.py
index 63404f675ab..4c07816a405 100644
--- a/tensorrt_llm/bench/benchmark/__init__.py
+++ b/tensorrt_llm/bench/benchmark/__init__.py
@@ -107,12 +107,12 @@ def get_llm(runtime_config: RuntimeConfig, kwargs: dict):
     if runtime_config.backend != None:
         ignore_trt_only_args(kwargs, runtime_config.backend)
 
+    if runtime_config.iteration_log is not None:
+        kwargs["enable_iter_perf_stats"] = True
+
     if runtime_config.backend == 'pytorch':
         llm_cls = PyTorchLLM
 
-        if runtime_config.iteration_log is not None:
-            kwargs["enable_iter_perf_stats"] = True
-
     elif runtime_config.backend == "_autodeploy":
         from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
 
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py
index e6515d3d802..a63eca22c99 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py
@@ -31,6 +31,8 @@ def run_benchmark(
         "_autodeploy",
         "--dataset",
         dataset_path,
+        "--iteration_log",
+        "iteration_log.log",
"--extra_llm_api_options", f"{extra_llm_api_options_path}", ] @@ -38,6 +40,11 @@ def run_benchmark( result = runner.invoke(main, args, catch_exceptions=False) assert result.exit_code == 0 + with open("iteration_log.log", "r") as f: + lines = f.readlines() + assert len(lines) > 0 + # TODO: add more checks + def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str): _DATASET_NAME = "synthetic_128_128.txt"