diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 23b501d45c..1f386c84c4 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -2232,7 +2232,9 @@ def _kernel_scale_fp8_row( # Iterate over chunks of the row and apply scales. for _k in range(0, tl.cdiv(N, BLOCK_SIZE)): - a = tl.load(A + pid * stride_am + n_offset * stride_an) + a = tl.load( + A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0 + ) col_scale = tl.load(w_scale + n_offset) scaled_a = a * row_scale * col_scale tl.store(