diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index b9d001c3d72..fabde3af980 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -165,6 +165,8 @@ transforms:
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
+  fuse_causal_conv_activation:
+    stage: compile
   compile_model:
     stage: compile
     run_per_gm: false
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
index 014f8cc7e6b..0803375847f 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
@@ -112,6 +112,7 @@ def _cuda_cached_causal_conv1d(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ) -> torch.Tensor:
     """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).

@@ -175,7 +176,7 @@ def _cuda_cached_causal_conv1d(
             cache_indices=cache_indices,
             has_initial_state=has_initial_state,
             conv_states=conv_state_cache,
-            activation=None,
+            activation=activation,
             pad_slot_id=PAD_SLOT_ID,
         )  # (dim, total_prefill_tokens)

@@ -185,16 +186,16 @@ def _cuda_cached_causal_conv1d(

     # DECODE: batch update for single-token sequences
     if num_decode > 0:
-        # Use true start offsets for decode tokens (tail after prefills)
-        decode_idx = seq_start[num_prefill:].to(torch.long)
-        x_decode = inp_flat.index_select(0, decode_idx)  # [num_decode, C_in]
+        x_decode = inp_flat[
+            total_prefill_tokens : total_prefill_tokens + num_decode
+        ]  # [num_decode, C_in]

         y_dec = causal_conv1d_update(
             x_decode,  # [batch, dim]
             conv_state_cache,
             w2d,
             bias,
-            activation=None,
+            activation=activation,
             cache_seqlens=None,
             conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
             pad_slot_id=PAD_SLOT_ID,
@@ -202,7 +203,9 @@ def _cuda_cached_causal_conv1d(
         if y_dec.dim() == 3:
             y_dec = y_dec.squeeze(-1)

-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )

     # Custom op must not return an alias of any input; return a fresh tensor
     return y.contiguous().clone()
@@ -227,6 +230,7 @@ def _cuda_cached_causal_conv1d_fake(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ):
     return torch.empty(
         input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
@@ -293,4 +297,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # None is for activation parameter, which may not exist in the source node (added by fusion later)
+        return [stride, padding, dilation, groups, padding_mode, None]
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
index a204c559f00..6f0059d250d 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
@@ -355,4 +355,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # None is for activation parameter, which may not exist in the source node (added by fusion later)
+        return [stride, padding, dilation, groups, padding_mode, None]
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
index 64b62419162..e89530b1dd9 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@@ -144,27 +144,26 @@ def _triton_cached_ssm(

         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
-        dt_pre = torch.nn.functional.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype))
-        dt_pre = torch.clamp(dt_pre, time_step_limit[0], time_step_limit[1])

         A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
         D_full = D[..., None].expand(num_heads, head_dim)
-        dt_bias_zero = torch.zeros_like(dt_bias_hp)

         y_dec = selective_state_update(
             ssm_state_cache,
             x_decode,
-            dt_pre,
+            dt_hp,
             A_full,
             B_decode,
             C_decode,
             D=D_full,
             z=None,
-            dt_bias=dt_bias_zero,
-            dt_softplus=False,
+            dt_bias=dt_bias_hp,
+            dt_softplus=True,
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]

-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )

     return y
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
new file mode 100644
index 00000000000..3acc8e1f80f
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
@@ -0,0 +1,120 @@
+"""Fusion transform for fusing activation functions into causal_conv1d operations."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.fx import GraphModule, Node
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ...utils.node_utils import is_op
+from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+
+
+def _match_causal_conv_activation_pattern(
+    graph: GraphModule,
+    target_op,
+) -> List[Tuple[Node, Node, str]]:
+    """
+    Match the causal_conv + activation pattern in the graph.
+
+    The pattern corresponds to:
+        conv_out = cuda_cached_causal_conv1d(...)
+        out = activation(conv_out)
+
+    Args:
+        graph: The graph module to search
+        target_op: The target causal conv op to match
+
+    Returns:
+        A list of tuples (conv_node, activation_node, activation_name) for each match
+    """
+    matches = []
+
+    for node in graph.nodes:
+        if not is_op(node, target_op):
+            continue
+
+        # Check if this node has exactly one user and it's an activation
+        if len(node.users) != 1:
+            continue
+
+        activation_node = list(node.users.keys())[0]
+        if activation_node.op != "call_function":
+            continue
+
+        # Detect activation type
+        activation_name: Optional[str] = None
+        if activation_node.target in (torch.ops.aten.silu.default, F.silu):
+            activation_name = "silu"
+        # Can extend to support more activations here:
+        # elif activation_node.target in (torch.ops.aten.gelu.default, F.gelu):
+        #     activation_name = "gelu"
+
+        if activation_name is not None:
+            matches.append((node, activation_node, activation_name))
+
+    return matches
+
+
+@TransformRegistry.register("fuse_causal_conv_activation")
+class FuseCausalConvActivation(BaseTransform):
+    """Fuses activation functions into cached CUDA causal_conv1d operations.
+
+    This transform detects patterns like:
+        conv_out = cuda_cached_causal_conv1d(...)
+        out = silu(conv_out)
+
+    And replaces them with:
+        out = cuda_cached_causal_conv1d(..., activation="silu")
+
+    This optimization allows the backend CUDA kernels to fuse the activation,
+    reducing memory bandwidth and improving performance.
+
+    Note: This runs AFTER insert_cached_causal_conv, so it operates on the
+    cached CUDA operations, not the uncached torch operations.
+    """
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+
+        # Step 1: Identify causal_conv + activation pattern
+        matches = _match_causal_conv_activation_pattern(
+            graph,
+            target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
+        )
+
+        # Step 2: Replace matched patterns with fused version
+        for conv_node, activation_node, activation_name in matches:
+            with graph.inserting_after(conv_node):
+                # Create new call with fused activation
+                # Replace the last arg (activation=None) with activation_name
+                new_args = list(conv_node.args[:-1]) + [activation_name]
+                fused_node = graph.call_function(
+                    torch.ops.auto_deploy.cuda_cached_causal_conv1d,
+                    args=tuple(new_args),
+                )
+
+            # Replace all uses of activation_node with fused_node
+            activation_node.replace_all_uses_with(fused_node)
+
+            # Remove the old nodes
+            graph.erase_node(activation_node)
+            graph.erase_node(conv_node)
+
+        gm.recompile()
+
+        info = TransformInfo(
+            skipped=False,
+            num_matches=len(matches),
+            is_clean=False,
+            has_valid_shapes=False,
+        )
+        return gm, info
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
index 2c9e4a70720..81f76e8a669 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py
@@ -82,6 +82,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
         d,
         g,
         pm,
+        None,
     )

     assert y.shape == (batch, seq, c)
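
For reviewers who want to try the rewrite mechanics outside of TensorRT-LLM, below is a minimal, self-contained sketch of the same FX fusion idiom FuseCausalConvActivation uses: find a conv node whose only user is a SiLU, re-emit the conv with the activation folded in, rewire consumers, and erase the old pair. It uses plain torch.nn.functional ops and a node.meta tag as stand-ins for the auto_deploy custom op and its trailing activation argument, so TinyConvAct and the "fused_activation" tag are illustrative only, not part of this PR.

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace


class TinyConvAct(torch.nn.Module):
    """Stand-in for a cached causal conv followed by SiLU."""

    def forward(self, x, w, b):
        return F.silu(F.conv1d(x, w, b))


gm = symbolic_trace(TinyConvAct())

for conv_node in list(gm.graph.nodes):
    # Only consider conv calls whose single consumer is a SiLU.
    if conv_node.op != "call_function" or conv_node.target not in (F.conv1d, torch.conv1d):
        continue
    if len(conv_node.users) != 1:
        continue
    act_node = next(iter(conv_node.users))
    if act_node.op != "call_function" or act_node.target is not F.silu:
        continue

    # Re-emit the conv in place of the conv+silu pair. The real transform re-emits
    # cuda_cached_causal_conv1d with activation="silu" as its final argument; here a
    # meta tag stands in for that extra argument.
    with gm.graph.inserting_after(conv_node):
        fused_node = gm.graph.call_function(conv_node.target, args=conv_node.args)
        fused_node.meta["fused_activation"] = "silu"
    act_node.replace_all_uses_with(fused_node)
    gm.graph.erase_node(act_node)
    gm.graph.erase_node(conv_node)

gm.recompile()

# The separate silu node is gone; only the (tagged) conv call remains.
x = torch.randn(2, 4, 16)
w = torch.randn(8, 4, 3)
b = torch.randn(8)
print(gm.graph)
print(gm(x, w, b).shape)  # torch.Size([2, 8, 14])

The single-user guard is what keeps the rewrite safe: if the conv output feeds anything besides the activation, the pattern is left alone, mirroring the len(node.users) != 1 check in _match_causal_conv_activation_pattern.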