From a78d38ae495d1311926a4d631437b742138f4f7a Mon Sep 17 00:00:00 2001 From: Peiying Hua Date: Wed, 19 Nov 2025 10:21:22 -0800 Subject: [PATCH] Allow specifying the use of persistent kernel Summary: Added environment argument "use_persistent" (default is False) to explicitly turn off non-persistent kernel and use persistent kernel. Throws error when both "use_persistent" and "no_use_persistent" are specified in the arguments. Example usage: Persistent kernel-- buck2 run mode/{dev-nosan,amd-gpu} -c xlog.level=WARNING -m ovr_config//triton:trunk -m rocm7 -c fbcode.nvcc_arch=mi350 -c fbcode.enable_gpu_sections=true pytorch/tritonbench:run -- --op fp8_gemm_rowwise --no_use_tma --use_persistent Non-persistent kernel-- buck2 run mode/{dev-nosan,amd-gpu} -c xlog.level=WARNING -m ovr_config//triton:trunk -m rocm7 -c fbcode.nvcc_arch=mi350 -c fbcode.enable_gpu_sections=true pytorch/tritonbench:run -- --op fp8_gemm_rowwise --no_use_tma --no_use_persistent When both specified in the arguments: buck2 run mode/{dev-nosan,amd-gpu} -c xlog.level=WARNING -m ovr_config//triton:trunk -m rocm7 -c fbcode.nvcc_arch=mi350 -c fbcode.enable_gpu_sections=true pytorch/tritonbench:run -- --op fp8_gemm_rowwise --no_use_tma --use_persistent --no_use_persistent IT WILL THROW ERROR: Cannot specify both '--use_persistent' and '--no_use_persistent' at the same time. These options are mutually exclusive. Please use only one. 
Reviewed By: jwfromm Differential Revision: D86579911 --- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 7899854c3c..05b1dfd652 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -1212,6 +1212,8 @@ def matmul_fp8_row( imprecise_acc: bool = False, tma_persistent: bool = True, no_use_persistent: Optional[bool] = None, + # add an option to explicitly require the use of the persistent kernel + use_persistent: Optional[bool] = None, use_warp_specialization: bool = False, ) -> torch.Tensor: """ @@ -1238,6 +1240,9 @@ def matmul_fp8_row( no_use_persistent = True else: no_use_persistent = False + # if use_persistent is explicitly requested, set no_use_persistent to False + if use_persistent: + no_use_persistent = False # Get datatypes and constants to use. pt_fp8_dtype, _, _, _ = get_fp8_constants() # Handle 3D+ a shape