tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 additions & 0 deletions)
@@ -165,6 +165,8 @@ transforms:
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
+  fuse_causal_conv_activation:
+    stage: compile
   compile_model:
     stage: compile
     run_per_gm: false
@@ -112,6 +112,7 @@ def _cuda_cached_causal_conv1d(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ) -> torch.Tensor:
     """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).
 
@@ -175,7 +176,7 @@ def _cuda_cached_causal_conv1d(
             cache_indices=cache_indices,
             has_initial_state=has_initial_state,
             conv_states=conv_state_cache,
-            activation=None,
+            activation=activation,
             pad_slot_id=PAD_SLOT_ID,
         )  # (dim, total_prefill_tokens)
 
@@ -185,24 +186,26 @@ def _cuda_cached_causal_conv1d(

     # DECODE: batch update for single-token sequences
     if num_decode > 0:
-        # Use true start offsets for decode tokens (tail after prefills)
-        decode_idx = seq_start[num_prefill:].to(torch.long)
-        x_decode = inp_flat.index_select(0, decode_idx)  # [num_decode, C_in]
+        x_decode = inp_flat[
+            total_prefill_tokens : total_prefill_tokens + num_decode
+        ]  # [num_decode, C_in]
 
         y_dec = causal_conv1d_update(
             x_decode,  # [batch, dim]
             conv_state_cache,
             w2d,
             bias,
-            activation=None,
+            activation=activation,
             cache_seqlens=None,
             conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
             pad_slot_id=PAD_SLOT_ID,
         )
 
         if y_dec.dim() == 3:
             y_dec = y_dec.squeeze(-1)
-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )
 
     # Custom op must not return an alias of any input; return a fresh tensor
     return y.contiguous().clone()
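The decode-path change above (a contiguous slice instead of index_select) relies on the flattened token layout used by the cached op: all prefill tokens come first, followed by exactly one token per decode sequence, so the decode rows form a contiguous tail starting at total_prefill_tokens. A minimal sketch of that layout (the sequence lengths are made up for illustration; variable names mirror those in the op):

import torch

# Three prefill sequences of lengths 4, 2, 3 followed by two decode sequences (1 token each).
seq_len = torch.tensor([4, 2, 3, 1, 1])
num_prefill = 3

seq_start = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.long), seq_len[:-1]]), dim=0)
total_prefill_tokens = int(seq_len[:num_prefill].sum())  # 9
num_decode = seq_len.numel() - num_prefill                # 2

# Under this layout, the per-sequence start offsets of the decode tokens are exactly
# the contiguous range [total_prefill_tokens, total_prefill_tokens + num_decode),
# which is why the slice and the index_select are interchangeable here.
decode_idx = seq_start[num_prefill:]
assert torch.equal(
    decode_idx, torch.arange(total_prefill_tokens, total_prefill_tokens + num_decode)
)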
@@ -227,6 +230,7 @@ def _cuda_cached_causal_conv1d_fake(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ):
     return torch.empty(
         input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
@@ -293,4 +297,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # None is for the activation parameter, which may not exist in the source node (added by the fusion pass later)
+        return [stride, padding, dilation, groups, padding_mode, None]
@@ -355,4 +355,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # None is for the activation parameter, which may not exist in the source node (added by the fusion pass later)
+        return [stride, padding, dilation, groups, padding_mode, None]
@@ -144,27 +144,26 @@ def _triton_cached_ssm(

         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
-        dt_pre = torch.nn.functional.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype))
-        dt_pre = torch.clamp(dt_pre, time_step_limit[0], time_step_limit[1])
         A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
         D_full = D[..., None].expand(num_heads, head_dim)
 
-        dt_bias_zero = torch.zeros_like(dt_bias_hp)
         y_dec = selective_state_update(
             ssm_state_cache,
             x_decode,
-            dt_pre,
+            dt_hp,
             A_full,
             B_decode,
             C_decode,
             D=D_full,
             z=None,
-            dt_bias=dt_bias_zero,
-            dt_softplus=False,
+            dt_bias=dt_bias_hp,
+            dt_softplus=True,
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]
 
-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )
 
     return y

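For reference, the decode-path change above delegates the dt preprocessing to the kernel: instead of computing softplus(dt + dt_bias) on the host (and clamping to time_step_limit) before calling selective_state_update with dt_softplus=False and a zero bias, the raw dt is now passed together with dt_bias and dt_softplus=True. A minimal sketch of the host-side math being replaced (shapes and the time_step_limit values are illustrative, not taken from the kernel):

import torch
import torch.nn.functional as F

num_heads, head_dim = 8, 16
dt = torch.randn(4, num_heads, head_dim)  # raw dt, expanded per head/dim
dt_bias = torch.randn(num_heads, head_dim)
time_step_limit = (0.001, 0.1)            # illustrative bounds

# Old path: preprocess on the host, then call the kernel with dt_softplus=False
# and dt_bias set to zeros.
dt_old = torch.clamp(F.softplus(dt + dt_bias), *time_step_limit)

# New path: pass the raw dt; the kernel applies bias + softplus internally
# (dt_bias=dt_bias, dt_softplus=True). The explicit clamp is no longer applied
# on this decode path.
dt_new = F.softplus(dt + dt_bias)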
@@ -0,0 +1,120 @@
"""Fusion transform for fusing activation functions into causal_conv1d operations."""

from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.fx import GraphModule, Node

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


def _match_causal_conv_activation_pattern(
graph: GraphModule,
target_op,
) -> List[Tuple[Node, Node, str]]:
"""
Match the causal_conv + activation pattern in the graph.

The pattern corresponds to:
conv_out = cuda_cached_causal_conv1d(...)
out = activation(conv_out)

Args:
graph: The graph module to search
target_op: The target causal conv op to match

Returns:
A list of tuples (conv_node, activation_node, activation_name) for each match
"""
matches = []

for node in graph.nodes:
if not is_op(node, target_op):
continue

# Check if this node has exactly one user and it's an activation
if len(node.users) != 1:
continue

activation_node = list(node.users.keys())[0]
if activation_node.op != "call_function":
continue

# Detect activation type
activation_name: Optional[str] = None
if activation_node.target in (torch.ops.aten.silu.default, F.silu):
activation_name = "silu"
# Can extend to support more activations here:
# elif activation_node.target in (torch.ops.aten.gelu.default, F.gelu):
# activation_name = "gelu"

if activation_name is not None:
matches.append((node, activation_node, activation_name))

return matches


@TransformRegistry.register("fuse_causal_conv_activation")
class FuseCausalConvActivation(BaseTransform):
"""Fuses activation functions into cached CUDA causal_conv1d operations.

This transform detects patterns like:
conv_out = cuda_cached_causal_conv1d(...)
out = silu(conv_out)

And replaces them with:
out = cuda_cached_causal_conv1d(..., activation="silu")

This optimization allows the backend CUDA kernels to fuse the activation,
reducing memory bandwidth and improving performance.

Note: This runs AFTER insert_cached_causal_conv, so it operates on the
cached CUDA operations, not the uncached torch operations.
"""

def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph

# Step 1: Identify causal_conv + activation pattern
matches = _match_causal_conv_activation_pattern(
graph,
target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
)

# Step 2: Replace matched patterns with fused version
for conv_node, activation_node, activation_name in matches:
with graph.inserting_after(conv_node):
# Create new call with fused activation
# Replace the last arg (activation=None) with activation_name
new_args = list(conv_node.args[:-1]) + [activation_name]
fused_node = graph.call_function(
torch.ops.auto_deploy.cuda_cached_causal_conv1d,
args=tuple(new_args),
)

# Replace all uses of activation_node with fused_node
activation_node.replace_all_uses_with(fused_node)

# Remove the old nodes
graph.erase_node(activation_node)
graph.erase_node(conv_node)

gm.recompile()

info = TransformInfo(
skipped=False,
num_matches=len(matches),
is_clean=False,
has_valid_shapes=False,
)
return gm, info
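To make the matching rule above concrete, here is a small, self-contained sketch using plain torch.fx with a hypothetical toy_conv stand-in (not the real auto_deploy op). It applies the same check the transform uses: a conv-like node is only matched when its single user is a known activation such as silu.

import torch
import torch.nn.functional as F
from torch import fx


def toy_conv(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the cached causal conv op; purely illustrative.
    return x * 2.0


# Keep toy_conv as a leaf call_function node in the traced graph instead of inlining it.
fx.wrap("toy_conv")


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return F.silu(toy_conv(x))


gm = fx.symbolic_trace(ToyModel())

# Find toy_conv nodes whose only user is a silu call -- the same
# single-user + known-activation check used by the transform above.
matches = []
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is toy_conv:
        if len(node.users) == 1:
            (user,) = node.users
            if user.op == "call_function" and user.target in (F.silu, torch.ops.aten.silu.default):
                matches.append((node, user))

print(f"matched {len(matches)} conv+silu pair(s)")  # expected: 1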
@@ -82,6 +82,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
         d,
         g,
         pm,
+        None,  # activation
     )
 
     assert y.shape == (batch, seq, c)