diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 7899854c3c..37f0e663bd 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -1212,6 +1212,8 @@ def matmul_fp8_row(
     imprecise_acc: bool = False,
     tma_persistent: bool = True,
     no_use_persistent: Optional[bool] = None,
+    # add an option to explicitly require the use of the persistent kernel
+    use_persistent: Optional[bool] = None,
     use_warp_specialization: bool = False,
 ) -> torch.Tensor:
     """
@@ -1232,12 +1234,16 @@ def matmul_fp8_row(
     Returns:
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
     """
-    if no_use_persistent is None:
+    # if use_persistent is explicitly requested, set no_use_persistent to False
+    if use_persistent:
+        no_use_persistent = False
+    elif no_use_persistent is None:
         # Default True for AMD and False for Nvidia.
         if torch.version.hip is not None:
             no_use_persistent = True
         else:
             no_use_persistent = False
+
     # Get datatypes and constants to use.
     pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
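
For context, a minimal usage sketch of the new flag (not part of the diff). It assumes the quantize_fp8_row helper from the same module returns a row-wise quantized FP8 tensor together with its per-row scales; the shapes are illustrative only.

# Hypothetical usage sketch exercising the new use_persistent flag.
# Assumes quantize_fp8_row returns (fp8_tensor, row_scale).
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)  # [M, K]
b = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)  # [N, K]

a_fp8, a_scale = quantize_fp8_row(a)
b_fp8, b_scale = quantize_fp8_row(b)

# Explicitly request the persistent kernel, overriding the
# platform-dependent default derived from no_use_persistent.
out = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale, use_persistent=True)
print(out.shape)  # torch.Size([128, 512])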