Skip to content

Commit aa5aa01

Browse files
committed
CUDA: Improve flash decoding kernel occupancy for BS=1 case
Adds the following optimizations to the CUDA flash decoding code: - Determine the number of active blocks per SM using the cudaOccupancyMaxActiveBlocksPerMultiprocessor API, and use this value to choose the optimal parallel_blocks value. - Prefer vector flash attention kernels over the MMA kernel for BS=1 This results in up to 15% perf improvement in gen-phase throughput for large sequence lengths. Issue: #12182
1 parent 2d011e6 commit aa5aa01

File tree

4 files changed

+21
-8
lines changed

4 files changed

+21
-8
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ void launch_fattn(
698698
GGML_ASSERT(Q->ne[3] == 1);
699699

700700
GGML_ASSERT(stream_k || ncols2 == 1);
701-
const int parallel_blocks = Q->ne[1] <= ncols1 ? 4 : 1;
701+
const bool use_parallel_blocks = !stream_k && (Q->ne[1] <= ncols1) ? true : false;
702702

703703
ggml_cuda_pool & pool = ctx.pool();
704704
cudaStream_t main_stream = ctx.stream();
@@ -749,6 +749,8 @@ void launch_fattn(
749749
nb23 = nb23*bs*sizeof(half)/ts;
750750
}
751751

752+
int parallel_blocks = 1;
753+
752754
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
753755
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
754756

@@ -770,6 +772,21 @@ void launch_fattn(
770772

771773
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
772774
} else {
775+
if (use_parallel_blocks) {
776+
const int num_blocks_base = ntiles_x*Q->ne[2]*Q->ne[3];
777+
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
778+
const int seqlen_tiles = (K->ne[1] + D - 1) / D;
779+
780+
// Determine the number of active blocks per SM
781+
int numActiveBlocks = 1;
782+
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
783+
784+
// we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
785+
// this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
786+
// If there are not enough tiles to process, we can reduce the number of blocks
787+
parallel_blocks = std::max(std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles), 1);
788+
}
789+
773790
blocks_num.x = ntiles_x;
774791
blocks_num.y = parallel_blocks;
775792
blocks_num.z = Q->ne[2]*Q->ne[3];

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,6 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
244244
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
245245
const ggml_tensor * KQV = dst;
246246
const ggml_tensor * Q = dst->src[0];
247-
const ggml_tensor * K = dst->src[1];
248-
const ggml_tensor * V = dst->src[2];
249-
const ggml_tensor * mask = dst->src[3];
250247

251248
ggml_cuda_set_device(ctx.device);
252249
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -296,10 +293,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
296293
return;
297294
}
298295

299-
const int gqa_ratio = Q->ne[2] / K->ne[2];
300-
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
301-
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
302-
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
296+
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0) {
303297
if (prec == GGML_PREC_DEFAULT) {
304298
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
305299
return;

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
#define cudaGraph_t hipGraph_t
130130
#define cudaStream_t hipStream_t
131131
#define cudaSuccess hipSuccess
132+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
132133
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
133134
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
134135
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,6 @@
134134
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135135
#define cudaStreamBeginCapture musaStreamBeginCapture
136136
#define cudaStreamEndCapture musaStreamEndCapture
137+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
137138

138139
typedef mt_bfloat16 nv_bfloat16;

0 commit comments

Comments
 (0)