From 350a613d7014621acbbbada5dbcd2013c9cff31a Mon Sep 17 00:00:00 2001
From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Date: Thu, 6 Nov 2025 17:19:13 -0800
Subject: [PATCH 1/5] Perf improvement: Minor fixes

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
---
 .../custom_ops/mamba/cuda_backend_causal_conv.py | 10 ++++++----
 .../custom_ops/mamba/triton_backend_mamba.py     |  9 +++------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
index 014f8cc7e6b..e7ee030d24d 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
@@ -185,9 +185,9 @@ def _cuda_cached_causal_conv1d(
 
     # DECODE: batch update for single-token sequences
     if num_decode > 0:
-        # Use true start offsets for decode tokens (tail after prefills)
-        decode_idx = seq_start[num_prefill:].to(torch.long)
-        x_decode = inp_flat.index_select(0, decode_idx)  # [num_decode, C_in]
+        x_decode = inp_flat[
+            total_prefill_tokens : total_prefill_tokens + num_decode
+        ]  # [num_decode, C_in]
 
         y_dec = causal_conv1d_update(
             x_decode,  # [batch, dim]
@@ -202,7 +202,9 @@ def _cuda_cached_causal_conv1d(
         if y_dec.dim() == 3:
             y_dec = y_dec.squeeze(-1)
 
-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )
 
     # Custom op must not return an alias of any input; return a fresh tensor
     return y.contiguous().clone()

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
index 64b62419162..d5d1c6688d5 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@@ -144,23 +144,20 @@ def _triton_cached_ssm(
 
         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
-        dt_pre = torch.nn.functional.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype))
-        dt_pre = torch.clamp(dt_pre, time_step_limit[0], time_step_limit[1])
 
         A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
         D_full = D[..., None].expand(num_heads, head_dim)
 
-        dt_bias_zero = torch.zeros_like(dt_bias_hp)
         y_dec = selective_state_update(
             ssm_state_cache,
             x_decode,
-            dt_pre,
+            dt_hp,
             A_full,
             B_decode,
             C_decode,
             D=D_full,
             z=None,
-            dt_bias=dt_bias_zero,
-            dt_softplus=False,
+            dt_bias=dt_bias_hp,
+            dt_softplus=True,
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]

From 76530a4ee8be2f7566d3231b42913fabf7d53548 Mon Sep 17 00:00:00 2001
From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Date: Fri, 7 Nov 2025 17:09:57 -0800
Subject: [PATCH 2/5] Add conv act fusion

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
---
 .../_torch/auto_deploy/config/default.yaml        |   2 +
 .../custom_ops/mamba/cuda_backend_causal_conv.py  |  13 ++-
 .../custom_ops/mamba/torch_backend_causal_conv.py |  11 +-
 .../transform/library/fuse_causal_conv.py         | 120 ++++++++++++++++++
 4 files changed, 142 insertions(+), 4 deletions(-)
 create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py

diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index b9d001c3d72..fabde3af980 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -165,6 +165,8 @@ transforms:
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
+  fuse_causal_conv_activation:
+    stage: compile
   compile_model:
     stage: compile
     run_per_gm: false

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
index e7ee030d24d..1ddd39126cd 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
@@ -112,6 +112,7 @@ def _cuda_cached_causal_conv1d(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ) -> torch.Tensor:
     """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).
 
@@ -175,7 +176,7 @@ def _cuda_cached_causal_conv1d(
             cache_indices=cache_indices,
             has_initial_state=has_initial_state,
             conv_states=conv_state_cache,
-            activation=None,
+            activation=activation,
             pad_slot_id=PAD_SLOT_ID,
         )  # (dim, total_prefill_tokens)
 
@@ -194,7 +195,7 @@ def _cuda_cached_causal_conv1d(
             conv_state_cache,
             w2d,
             bias,
-            activation=None,
+            activation=activation,
             cache_seqlens=None,
             conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
             pad_slot_id=PAD_SLOT_ID,
@@ -229,6 +230,7 @@ def _cuda_cached_causal_conv1d_fake(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ):
     return torch.empty(
         input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
@@ -295,4 +297,9 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # activation parameter may not exist in the source node (added by fusion later)
+        try:
+            activation = extract_op_args(source_attn_node, "activation")[0]
+        except (RuntimeError, IndexError):
+            activation = None
+        return [stride, padding, dilation, groups, padding_mode, activation]

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
index a204c559f00..3ce623e066c 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
@@ -197,6 +197,7 @@ def _torch_cached_causal_conv1d(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ) -> torch.Tensor:
     """Flattened cached causal conv that respects slot-indexed state caches.
 
@@ -223,6 +224,7 @@ def _torch_cached_causal_conv1d(
             groups,
             padding_mode,
             cache_batch,
+            activation,
         )
         conv_state_cache.index_copy_(0, slot_idx_long, updated_state.to(conv_state_cache.dtype))
 
@@ -257,6 +259,7 @@ def _torch_cached_causal_conv1d(
                 dilation,
                 groups,
                 padding_mode,
+                activation,
             )
             y_flat.index_copy_(0, idx_i, y_seq[0].to(y_flat.dtype))
 
@@ -286,6 +289,7 @@ def _torch_cached_causal_conv1d_fake(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ):
     return torch.empty(
         input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
@@ -355,4 +359,9 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # activation parameter may not exist in the source node (added by fusion later)
+        try:
+            activation = extract_op_args(source_attn_node, "activation")[0]
+        except (RuntimeError, IndexError):
+            activation = None
+        return [stride, padding, dilation, groups, padding_mode, activation]

diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
new file mode 100644
index 00000000000..3acc8e1f80f
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
@@ -0,0 +1,120 @@
+"""Fusion transform for fusing activation functions into causal_conv1d operations."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.fx import Graph, GraphModule, Node
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ...utils.node_utils import is_op
+from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+
+
+def _match_causal_conv_activation_pattern(
+    graph: Graph,
+    target_op,
+) -> List[Tuple[Node, Node, str]]:
+    """
+    Match the causal_conv + activation pattern in the graph.
+
+    The pattern corresponds to:
+        conv_out = cuda_cached_causal_conv1d(...)
+        out = activation(conv_out)
+
+    Args:
+        graph: The FX graph to search (a torch.fx.Graph, e.g. gm.graph)
+        target_op: The target causal conv op to match
+
+    Returns:
+        A list of tuples (conv_node, activation_node, activation_name) for each match
+    """
+    matches = []
+
+    for node in graph.nodes:
+        if not is_op(node, target_op):
+            continue
+
+        # Check if this node has exactly one user and it's an activation
+        if len(node.users) != 1:
+            continue
+
+        activation_node = list(node.users.keys())[0]
+        if activation_node.op != "call_function":
+            continue
+
+        # Detect activation type
+        activation_name: Optional[str] = None
+        if activation_node.target in (torch.ops.aten.silu.default, F.silu):
+            activation_name = "silu"
+        # Can extend to support more activations here:
+        # elif activation_node.target in (torch.ops.aten.gelu.default, F.gelu):
+        #     activation_name = "gelu"
+
+        if activation_name is not None:
+            matches.append((node, activation_node, activation_name))
+
+    return matches
+
+
+@TransformRegistry.register("fuse_causal_conv_activation")
+class FuseCausalConvActivation(BaseTransform):
+    """Fuses activation functions into cached CUDA causal_conv1d operations.
+
+    This transform detects patterns like:
+        conv_out = cuda_cached_causal_conv1d(...)
+        out = silu(conv_out)
+
+    And replaces them with:
+        out = cuda_cached_causal_conv1d(..., activation="silu")
+
+    This optimization allows the backend CUDA kernels to fuse the activation,
+    reducing memory bandwidth and improving performance.
+
+    Note: This runs AFTER insert_cached_causal_conv, so it operates on the
+    cached CUDA operations, not the uncached torch operations.
+    """
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+
+        # Step 1: Identify causal_conv + activation pattern
+        matches = _match_causal_conv_activation_pattern(
+            graph,
+            target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
+        )
+
+        # Step 2: Replace matched patterns with fused version
+        for conv_node, activation_node, activation_name in matches:
+            with graph.inserting_after(conv_node):
+                # Create new call with fused activation
+                # Replace the last arg (activation=None) with activation_name
+                new_args = list(conv_node.args[:-1]) + [activation_name]
+                fused_node = graph.call_function(
+                    torch.ops.auto_deploy.cuda_cached_causal_conv1d,
+                    args=tuple(new_args),
+                )
+
+            # Replace all uses of activation_node with fused_node
+            activation_node.replace_all_uses_with(fused_node)
+
+            # Remove the old nodes
+            graph.erase_node(activation_node)
+            graph.erase_node(conv_node)
+
+        gm.recompile()
+
+        info = TransformInfo(
+            skipped=False,
+            num_matches=len(matches),
+            is_clean=False,
+            has_valid_shapes=False,
+        )
+        return gm, info

From 8eb0c252e81caba3e046b363d7849914c70d0cb6 Mon Sep 17 00:00:00 2001
From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Date: Sat, 8 Nov 2025 17:15:12 -0800
Subject: [PATCH 3/5] fix unit tests

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
---
 tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py      | 1 +
 .../singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py  | 1 +
 .../singlegpu/custom_ops/test_torch_causal_conv_cached_op.py | 2 ++
 3 files changed, 4 insertions(+)

diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
index 85e997c615b..c2efeb2c347 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
@@ -69,6 +69,7 @@ def _bamba_mixer_torch_forward(
                 self.conv1d.dilation[0],
                 self.conv1d.groups,
                 self.conv1d.padding_mode,
+                None,
             )
         )
     else:

diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
index 2c9e4a70720..81f76e8a669 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
@@ -82,6 +82,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
         d,
         g,
         pm,
+        None,
     )
 
     assert y.shape == (batch, seq, c)

diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
index 3988595346b..4d693135132 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
@@ -78,6 +78,7 @@ def test_generate_only_with_slot_mapping(conv_env):
         d,
         g,
         pm,
+        None,
     )
 
     assert y.shape == (batch, seq, c_out)
@@ -138,6 +139,7 @@ def test_context_flattened_and_state_writeback(conv_env):
         d,
         g,
         pm,
+        None,
     )
 
     assert y.shape == (batch, seq, c_out)

From 324181a0ddde1e8cbe84ffd29092e135241f4e05 Mon Sep 17 00:00:00 2001
From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Date: Sat, 8 Nov 2025 19:31:49 -0800
Subject: [PATCH 4/5] fix tests

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
---
 .../custom_ops/mamba/torch_backend_causal_conv.py            | 4 ----
 .../custom_ops/mamba/triton_backend_mamba.py                 | 4 +++-
 tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py      | 1 -
 .../singlegpu/custom_ops/test_torch_causal_conv_cached_op.py | 2 --
 4 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
index 3ce623e066c..8b94dd967ef 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
@@ -197,7 +197,6 @@ def _torch_cached_causal_conv1d(
     dilation: int,
     groups: int,
     padding_mode: str,
-    activation: Optional[str],
 ) -> torch.Tensor:
     """Flattened cached causal conv that respects slot-indexed state caches.
 
@@ -223,7 +223,6 @@ def _torch_cached_causal_conv1d(
             groups,
             padding_mode,
             cache_batch,
-            activation,
         )
         conv_state_cache.index_copy_(0, slot_idx_long, updated_state.to(conv_state_cache.dtype))
 
@@ -257,7 +257,6 @@ def _torch_cached_causal_conv1d(
                 dilation,
                 groups,
                 padding_mode,
-                activation,
             )
             y_flat.index_copy_(0, idx_i, y_seq[0].to(y_flat.dtype))
 
@@ -286,7 +286,6 @@ def _torch_cached_causal_conv1d_fake(
     dilation: int,
     groups: int,
     padding_mode: str,
-    activation: Optional[str],
 ):
     return torch.empty(
         input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
index d5d1c6688d5..e89530b1dd9 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@@ -161,7 +161,9 @@ def _triton_cached_ssm(
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]
 
-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )
 
     return y
 

diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
index c2efeb2c347..85e997c615b 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
@@ -69,7 +69,6 @@ def _bamba_mixer_torch_forward(
                 self.conv1d.dilation[0],
                 self.conv1d.groups,
                 self.conv1d.padding_mode,
-                None,
             )
         )
     else:

diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
index 4d693135132..3988595346b 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py
@@ -78,7 +78,6 @@ def test_generate_only_with_slot_mapping(conv_env):
         d,
         g,
         pm,
-        None,
     )
 
     assert y.shape == (batch, seq, c_out)
@@ -139,7 +138,6 @@ def test_context_flattened_and_state_writeback(conv_env):
         d,
         g,
         pm,
-        None,
     )
 
     assert y.shape == (batch, seq, c_out)

From e833d2fdfef31512e8624a1d1aa84f59a02f291f Mon Sep 17 00:00:00 2001
From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Date: Mon, 10 Nov 2025 11:06:00 -0800
Subject: [PATCH 5/5] Address reviewer's comments

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
---
 .../custom_ops/mamba/cuda_backend_causal_conv.py  | 8 ++------
 .../custom_ops/mamba/torch_backend_causal_conv.py | 8 ++------
 2 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
index 1ddd39126cd..0803375847f 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
@@ -297,9 +297,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        # activation parameter may not exist in the source node (added by fusion later)
-        try:
-            activation = extract_op_args(source_attn_node, "activation")[0]
-        except (RuntimeError, IndexError):
-            activation = None
-        return [stride, padding, dilation, groups, padding_mode, activation]
+        # None fills the activation parameter, which is absent on the source node (the fusion transform adds it later)
+        return [stride, padding, dilation, groups, padding_mode, None]

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
index 8b94dd967ef..6f0059d250d 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
@@ -355,9 +355,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        # activation parameter may not exist in the source node (added by fusion later)
-        try:
-            activation = extract_op_args(source_attn_node, "activation")[0]
-        except (RuntimeError, IndexError):
-            activation = None
-        return [stride, padding, dilation, groups, padding_mode, activation]
+        # None fills the activation parameter, which is absent on the source node (the fusion transform adds it later)
+        return [stride, padding, dilation, groups, padding_mode, None]
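
The rewrite performed by fuse_causal_conv.py (patch 2) is a standard single-user activation fold over an FX graph. The sketch below reproduces the same match-and-replace shape on a toy graph so it can be run outside TensorRT-LLM; `conv` and its doubling body are hypothetical stand-ins invented for this demo, while the real transform targets torch.ops.auto_deploy.cuda_cached_causal_conv1d and fills in its trailing activation argument.

    import torch
    import torch.fx


    def conv(x, activation=None):
        # Stand-in for the cached conv op: optional fused activation.
        y = x * 2.0  # pretend this is the convolution
        return torch.nn.functional.silu(y) if activation == "silu" else y


    # Keep `conv` as a single opaque call_function node during tracing.
    torch.fx.wrap("conv")


    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.silu(conv(x))


    gm = torch.fx.symbolic_trace(Toy())

    for node in list(gm.graph.nodes):
        # Match: conv whose only user is a silu call (same shape as the transform).
        if node.op == "call_function" and node.target is conv and len(node.users) == 1:
            act = next(iter(node.users))
            if act.op == "call_function" and act.target is torch.nn.functional.silu:
                with gm.graph.inserting_after(node):
                    fused = gm.graph.call_function(
                        conv, args=node.args, kwargs={"activation": "silu"}
                    )
                act.replace_all_uses_with(fused)
                gm.graph.erase_node(act)
                gm.graph.erase_node(node)

    gm.recompile()

    x = torch.randn(4)
    assert torch.allclose(gm(x), torch.nn.functional.silu(x * 2.0))

As in the patch, ordering matters: the fused node is created before the original pair is erased, and erasure goes activation node first, then conv node, so that no node is removed while it still has users.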