From b82ac615be2a33c224ae017d5d81f422427a2697 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Sun, 22 Dec 2024 22:45:29 -0800 Subject: [PATCH] Fix out-of-bound load in row scaling Summary: Masking is needed to avoid out-of-bound reference for the last row. This fixes an illegal access error. Differential Revision: D67588103 --- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 23b501d45c..1f386c84c4 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -2232,7 +2232,9 @@ def _kernel_scale_fp8_row( # Iterate over chunks of the row and apply scales. for _k in range(0, tl.cdiv(N, BLOCK_SIZE)): - a = tl.load(A + pid * stride_am + n_offset * stride_an) + a = tl.load( + A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0 + ) col_scale = tl.load(w_scale + n_offset) scaled_a = a * row_scale * col_scale tl.store(