
Conversation

@nvchenghaoz (Collaborator) commented Nov 12, 2025

If a model has both linear attention and KV-cache attention, the causal-conv cache may need a different dtype than the KV cache.

We may consider renaming cache_config.dtype to cache_config.kv_dtype and adding a new field, cache_config.causal_conv_dtype.
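As a rough illustration of that proposal (not existing code), here is a minimal sketch of a cache config with separate fields; CacheConfig, kv_dtype, and causal_conv_dtype are hypothetical names used only for illustration:

    # Hypothetical sketch of the proposed split; the current code only exposes
    # cache_config.dtype, so these field names are illustrative, not the real API.
    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class CacheConfig:
        kv_dtype: Optional[torch.dtype] = None  # dtype of the attention KV cache (e.g., fp8)
        causal_conv_dtype: Optional[torch.dtype] = None  # dtype of the causal-conv state cache


    # Example: fp8 KV cache while the conv-state cache stays in bf16.
    cfg = CacheConfig(kv_dtype=torch.float8_e4m3fn, causal_conv_dtype=torch.bfloat16)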

fixes #9102

Summary by CodeRabbit

  • Bug Fixes
    • Fixed cache tensor dtype handling in Mamba convolution operations to consistently follow the input tensor dtype across the CUDA and PyTorch backends, removing the previous fallback to the cache-config (KV-cache) dtype.

Signed-off-by: Chenghao Zhang <[email protected]>
@coderabbitai bot (Contributor) commented Nov 12, 2025

📝 Walkthrough

The conv-state cache generators in both CUDA and Torch backend implementations were updated to consistently use inp_fake.dtype for cache tensor dtype selection, removing the previous fallback to cache_config.dtype.

Changes

Conv-state cache dtype selection
  • Files: tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • Summary: Modified _get_conv_cache to always use inp_fake.dtype for the cache tensor dtype instead of falling back to cache_config.dtype.
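A minimal runnable sketch of the new behavior summarized above; select_conv_cache_dtype is a hypothetical helper used only for illustration, not a function in the codebase:

    import torch


    def select_conv_cache_dtype(inp_fake: torch.Tensor) -> torch.dtype:
        """Conv-state cache dtype now always follows the input tensor dtype."""
        # Per the summary above, the previous fallback to cache_config.dtype
        # (which may be set to the fp8 KV-cache dtype) is no longer used here.
        return inp_fake.dtype


    inp = torch.empty(2, 64, dtype=torch.bfloat16)
    conv_cache = torch.empty(8, 64, 3, dtype=select_conv_cache_dtype(inp))  # bf16 cache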

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~5 minutes

  • Repetitive pattern changes across two similar backend implementations
  • Straightforward dtype logic simplification with no control-flow modifications
  • No public API changes or external impact

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check ⚠️ Warning — The PR description lacks required sections from the template: missing PR title format with ticket/type, an incomplete Description section, and no Test Coverage details. Resolution: add the proper PR title format (e.g., [TRTLLM-9102][feat] Summary), expand the Description section with a detailed explanation, and provide a Test Coverage section listing relevant tests.
✅ Passed checks (1 passed)
  • Title check ✅ Passed — The title clearly indicates the main change, adding fp8 KV cache support in AutoDeploy, which aligns with the raw summary showing dtype-handling changes in causal-convolution cache generation.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
      • Create PR with unit tests
      • Post copyable unit tests in a comment


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)

280-287: Approve dtype change; consider adding documentation.

The change to use inp_fake.dtype for the conv cache is correct and aligns with the PR's objective to decouple the convolution cache dtype from the KV cache dtype. This enables FP8 KV cache support while keeping the conv cache at the input precision.

Consider adding a brief comment explaining the dtype selection rationale:

     def _get_conv_cache(si: SequenceInfo):
+        # Use input dtype for conv cache to decouple from KV cache dtype,
+        # enabling FP8 KV cache with different conv precision
         return torch.empty(
             si.max_batch_size,
             in_channels,
             max(1, kernel_size - 1),
             device=si.device,
             dtype=inp_fake.dtype,
         )
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1)

338-345: Approve dtype change; consistent with CUDA backend.

The change to use inp_fake.dtype for the conv cache is correct and mirrors the identical change in cuda_backend_causal_conv.py. This consistency across backends is good practice and ensures both implementations handle dtype separation uniformly.

Consider adding the same documentation comment as suggested for the CUDA backend:

     def _get_conv_cache(si: SequenceInfo):
+        # Use input dtype for conv cache to decouple from KV cache dtype,
+        # enabling FP8 KV cache with different conv precision
         return torch.empty(
             si.max_batch_size,
             in_channels,
             kernel_size,
             device=si.device,
             dtype=inp_fake.dtype,
         )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f1d637e and 7c8f2b0.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
🧠 Learnings (2)
📓 Common learnings
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 8263
File: examples/models/contrib/sdxl/run_sdxl.py:0-0
Timestamp: 2025-10-13T13:55:04.170Z
Learning: The `diffusers` library (e.g., `DiffusionPipeline`, `StableDiffusionXLPipeline`, `StableDiffusion3Pipeline`) uses the `torch_dtype` parameter in `from_pretrained()` calls, not `dtype`. Only the `transformers` library has migrated to using `dtype`.
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@suyoggupta (Collaborator)

Does nano_v3.yaml need an update?
Also, please post perf numbers.

@github-project-automation github-project-automation bot moved this from Backlog to In review in AutoDeploy Board Nov 13, 2025
@nvchenghaoz nvchenghaoz changed the title [TRTLLM-9102][feat] AutoDeploy: Support fp8 kv cache [#9102][feat] AutoDeploy: Support fp8 kv cache Nov 13, 2025
@nvchenghaoz (Collaborator, Author)

> Does nano_v3.yaml need an update?

Nope, it just needs the update from the quant config.

> Also, please post perf numbers.

FP8 model, 1k/1k/256 case:
  • with fp8 KV cache: Request Throughput (req/sec): 6.9519
  • with BF16 KV cache: Request Throughput (req/sec): 6.6938
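For reference, a hedged sketch of the kind of quant-config update alluded to above, assuming a ModelOpt-style layout where kv_cache_quant_algo selects the KV-cache dtype; the exact schema is not shown in this PR, so treat the field names as assumptions:

    import json

    # Assumed ModelOpt-style quantization config enabling an fp8 KV cache;
    # field names are illustrative and may not match the actual checkpoint format.
    quant_config = {
        "quantization": {
            "quant_algo": "FP8",           # weights/activations quantized to fp8
            "kv_cache_quant_algo": "FP8",  # fp8 KV cache exercised by this PR
        }
    }

    print(json.dumps(quant_config, indent=2))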

@nvchenghaoz (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #24527 [ run ] triggered by Bot. Commit: 7c8f2b0

@tensorrt-cicd (Collaborator)

PR_Github #24527 [ run ] completed with state SUCCESS. Commit: 7c8f2b0
/LLM/main/L0_MergeRequest_PR pipeline #18512 completed with status: 'FAILURE'

@nvchenghaoz (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #24556 [ run ] triggered by Bot. Commit: 7c8f2b0

@tensorrt-cicd (Collaborator)

PR_Github #24556 [ run ] completed with state SUCCESS. Commit: 7c8f2b0
/LLM/main/L0_MergeRequest_PR pipeline #18536 completed with status: 'SUCCESS'

@suyoggupta suyoggupta merged commit f6f6e1f into NVIDIA:main Nov 14, 2025
11 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in AutoDeploy Board Nov 14, 2025
zheyuf pushed a commit to zheyuf/TensorRT-LLM that referenced this pull request Nov 19, 2025
greg-kwasniewski1 pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Nov 20, 2025


Development

Successfully merging this pull request may close these issues.

[Feature]: AutoDeploy: Enable fp8 kv cache for Nano-v3 fp8 model
