diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore
index 36cf5a6dd49..6151dd92279 100644
--- a/examples/auto_deploy/.gitignore
+++ b/examples/auto_deploy/.gitignore
@@ -5,3 +5,4 @@ benchmark_results.json
 # ignore config files that users might put here for debugging
 *.yaml
 !nano_v3.yaml
+!nano_v3_bench.yaml
diff --git a/examples/auto_deploy/nano_v3_bench.yaml b/examples/auto_deploy/nano_v3_bench.yaml
new file mode 100644
index 00000000000..fc9b04e2640
--- /dev/null
+++ b/examples/auto_deploy/nano_v3_bench.yaml
@@ -0,0 +1,23 @@
+runtime: trtllm
+compile_backend: torch-cudagraph
+max_batch_size: 384 # tunable
+max_seq_len: 65536 # tunable
+enable_chunked_prefill: true
+attn_backend: flashinfer
+model_factory: AutoModelForCausalLM
+skip_loading_weights: false
+free_mem_ratio: 0.9
+cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
+kv_cache_config:
+  # disable kv_cache reuse since not supported for hybrid/ssm models
+  enable_block_reuse: false
+transforms:
+  detect_sharding:
+    sharding_source: ['factory', 'heuristic']
+    sharding_dims: ['ep', 'bmm']
+  # tunable mamba cache dtype
+  # --> use float32 for accuracy and default (null) for speed
+  insert_cached_ssm_attention:
+    cache_config:
+      # mamba_dtype: float32
+      mamba_dtype: null
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
index d219abd5951..90ea04db862 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -104,9 +104,15 @@ def trtllm_quant_fp8_linear(
         assert input_scale is not None
         input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)

+    enable_cuda_core = False
+    if torch.cuda.is_available():
+        capability = torch.cuda.get_device_capability(0)
+        enable_cuda_core = capability == (8, 9) or capability == (12, 0)
     # Use TensorRT-LLM FP8 scaled matrix multiply
     # Choose between CUDA core (for small M) and cuBLAS (for large M) implementations
-    if input_fp8.shape[0] <= 8:  # NOTE: this kernel work with n % 2 == 0 as well??
+    if (
+        input_fp8.shape[0] <= 8 and enable_cuda_core
+    ):  # NOTE: this kernel work with n % 2 == 0 as well??
         # Use CUDA core for small M dimension (better for small batch sizes)
         output = torch.ops.trtllm.cuda_scaled_mm(
             input_fp8,
diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
index 8518681a7db..8b7c370fb23 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
@@ -88,6 +88,34 @@ def _nemotron_h_block_forward(
     return hidden_states


+def _nemotron_h_topk_router_forward(self, hidden_states):
+    """
+    Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.
+
+    This replaces the original forward method which used pure PyTorch operations
+    with a fused CUDA kernel that performs:
+    1. Sigmoid activation of logits
+    2. Group-based expert selection
+    3. Top-k selection within selected groups
+    4. Normalized weight computation
+    """
+    hidden_states = hidden_states.view(-1, self.config.hidden_size)
+    router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+
+    # Use the fused noaux_tc_op kernel which applies sigmoid internally
+    # and performs group-based top-k selection with normalization
+    topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
+        router_logits,
+        self.e_score_correction_bias,
+        self.n_group,
+        self.topk_group,
+        self.top_k,
+        self.routed_scaling_factor,
+    )
+
+    return topk_indices, topk_weights
+
+
 # Note: we assume experts have no bias for now
 def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     """
@@ -138,6 +166,7 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     ],
     "NemotronHBlock": [("forward", _nemotron_h_block_forward)],
     "NemotronHMOE": [("forward", _nemotron_h_moe_forward)],
+    "NemotronHTopkRouter": [("forward", _nemotron_h_topk_router_forward)],
 }
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
index 84dce0d827c..b62fd9f2b9c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
@@ -19,7 +19,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
         if not is_op(node, torch.ops.auto_deploy.torch_moe):
             continue

-        (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn")
+        (mlp_style_val,) = extract_op_args(node, "mlp_style")

         hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = (
             extract_op_args(
@@ -50,7 +50,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
             fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
             new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
             # Triton fused MoE op supports mlp only.
-            replacement_op = torch.ops.auto_deploy.trtllm_moe_fused
+            replacement_op = torch.ops.auto_deploy.triton_moe_fused
         else:
             raise ValueError(f"Unknown mlp_style: {mlp_style_val}")

@@ -75,10 +75,6 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
                 graph.get_attr(new_key_w_up),
                 graph.get_attr(new_key_w_down),
             ),
-            kwargs={
-                "mlp_style": mlp_style_val,
-                "act_fn": act_fn_val,
-            },
         )

         node.replace_all_uses_with(new_node)
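
For context, the new gate in trtllm_quant_fp8_linear only routes small-M GEMMs to torch.ops.trtllm.cuda_scaled_mm when the device reports SM 8.9 or SM 12.0; everything else falls through to the cuBLAS path. A minimal standalone sketch of that selection logic follows; the helper names (_cuda_core_path_enabled, pick_fp8_gemm_path) and the threshold argument are illustrative and not part of the patch.

import torch


def _cuda_core_path_enabled() -> bool:
    """Mirror the gate above: enable the CUDA-core FP8 GEMM path only on SM 8.9 / 12.0."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability(0) in ((8, 9), (12, 0))


def pick_fp8_gemm_path(m: int, small_m_threshold: int = 8) -> str:
    """Choose between the CUDA-core (small M, latency-bound) and cuBLAS implementations."""
    if m <= small_m_threshold and _cuda_core_path_enabled():
        return "cuda_core"  # small-batch path, as in the patched condition
    return "cublas"  # general path for larger M or unsupported SM versions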
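
The NemotronHTopkRouter patch delegates routing to the fused torch.ops.trtllm.noaux_tc_op kernel. As a rough pure-PyTorch reference for the four steps listed in the docstring (sigmoid, group selection, top-k within the selected groups, weight normalization and scaling), the sketch below implements DeepSeek-style group-limited top-k routing. It is an illustration of the algorithm under that assumption, not the kernel's exact semantics, and grouped_topk_reference is a made-up name.

import torch


def grouped_topk_reference(
    router_logits: torch.Tensor,  # [num_tokens, num_experts], float32
    bias: torch.Tensor,           # e_score_correction_bias, [num_experts]
    n_group: int,
    topk_group: int,
    top_k: int,
    routed_scaling_factor: float,
):
    num_tokens, num_experts = router_logits.shape

    # 1. Sigmoid activation of the logits.
    scores = router_logits.sigmoid()
    # Bias-corrected scores are used for expert selection only.
    scores_for_choice = scores + bias.unsqueeze(0)

    # 2. Group-based selection: rank groups by the sum of their top-2 experts
    #    and keep the best `topk_group` groups (assumes group size >= 2).
    group_scores = (
        scores_for_choice.view(num_tokens, n_group, -1).topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, n_group, num_experts // n_group)
        .reshape(num_tokens, num_experts)
    )

    # 3. Top-k experts restricted to the selected groups.
    masked_scores = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_indices = masked_scores.topk(top_k, dim=-1).indices

    # 4. Normalized weights taken from the un-biased scores, then rescaled.
    topk_weights = scores.gather(1, topk_indices)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_indices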