[#9102][feat] AutoDeploy: Support fp8 kv cache #9107
Conversation
Signed-off-by: Chenghao Zhang <[email protected]>
📝 Walkthrough: The conv-state cache generators in both the CUDA and Torch backend implementations were updated to consistently use `inp_fake.dtype` when allocating the convolution state cache, decoupling it from the KV cache dtype.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~5 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings) | ✅ Passed checks (1 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)
280-287: Approve dtype change; consider adding documentation. The change to use `inp_fake.dtype` for the conv cache is correct and aligns with the PR's objective to decouple the convolution cache dtype from the KV cache dtype. This enables FP8 KV cache support while keeping the conv cache at the input precision. Consider adding a brief comment explaining the dtype selection rationale:

```diff
 def _get_conv_cache(si: SequenceInfo):
+    # Use input dtype for conv cache to decouple from KV cache dtype,
+    # enabling FP8 KV cache with different conv precision
     return torch.empty(
         si.max_batch_size,
         in_channels,
         max(1, kernel_size - 1),
         device=si.device,
         dtype=inp_fake.dtype,
     )
```

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1)
338-345: Approve dtype change; consistent with CUDA backend. The change to use `inp_fake.dtype` for the conv cache is correct and mirrors the identical change in `cuda_backend_causal_conv.py`. This consistency across backends is good practice and ensures both implementations handle dtype separation uniformly. Consider adding the same documentation comment as suggested for the CUDA backend:

```diff
 def _get_conv_cache(si: SequenceInfo):
+    # Use input dtype for conv cache to decouple from KV cache dtype,
+    # enabling FP8 KV cache with different conv precision
     return torch.empty(
         si.max_batch_size,
         in_channels,
         kernel_size,
         device=si.device,
         dtype=inp_fake.dtype,
     )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1 hunks)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
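As a side note, the Python guidelines quoted above can be illustrated with a small hypothetical module (names invented here, not taken from the repo) showing namespace-preserving imports, PascalCase classes, snake_case methods, upper SNAKE_CASE constants, and Google-style docstrings:

```python
"""Hypothetical example module following the quoted coding guidelines."""

import typing

TOKENS_PER_BLOCK_DEFAULT = 64  # Constant: upper SNAKE_CASE.


class BlockCounter:
    """Computes how many cache blocks a batch of sequences needs.

    Attributes:
        tokens_per_block: Number of tokens stored per cache block.
    """

    def __init__(self, tokens_per_block: int = TOKENS_PER_BLOCK_DEFAULT):
        # Initialize all externally visible members in the constructor.
        self.tokens_per_block = tokens_per_block

    def num_blocks(self, seq_lens: typing.Sequence[int]) -> int:
        """Returns the total block count for the given sequence lengths.

        Args:
            seq_lens: Per-sequence token counts.
        """
        per_seq = ((n + self.tokens_per_block - 1) // self.tokens_per_block for n in seq_lens)
        return sum(per_seq)
```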
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
🧠 Learnings (2)
📓 Common learnings
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 8263
File: examples/models/contrib/sdxl/run_sdxl.py:0-0
Timestamp: 2025-10-13T13:55:04.170Z
Learning: The `diffusers` library (e.g., `DiffusionPipeline`, `StableDiffusionXLPipeline`, `StableDiffusion3Pipeline`) uses the `torch_dtype` parameter in `from_pretrained()` calls, not `dtype`. Only the `transformers` library has migrated to using `dtype`.
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Applied to files:
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
does nano_v3.yaml need an update?
Nope, just need the update from the quant config. FP8 model: 1k/1k/256 case
/bot run
PR_Github #24527 [ run ] triggered by Bot. Commit:
PR_Github #24527 [ run ] completed with state
/bot run
PR_Github #24556 [ run ] triggered by Bot. Commit:
PR_Github #24556 [ run ] completed with state
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
If the model has both linear attention and KV-cache attention, the causal-conv cache may need a different datatype than the KV cache.
We may consider renaming cache_config.dtype to cache_config.kv_dtype and then adding a new field, cache_config.causal_conv_dtype.
fixes #9102
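As a rough sketch of that follow-up idea (the field names below are hypothetical and do not reflect the current cache config API):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CacheConfig:
    """Hypothetical split of the cache dtype configuration."""

    kv_dtype: Optional[torch.dtype] = None           # e.g. torch.float8_e4m3fn for an FP8 KV cache
    causal_conv_dtype: Optional[torch.dtype] = None  # e.g. torch.bfloat16 for the conv-state cache


cfg = CacheConfig(kv_dtype=torch.float8_e4m3fn, causal_conv_dtype=torch.bfloat16)
```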