Commit 66d873b

consider tail effects for parallel_blocks
1 parent aa5aa01 commit 66d873b

8 files changed: 57 additions & 31 deletions

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 34 additions & 19 deletions
@@ -672,9 +672,8 @@ static void on_no_fattn_vec_case(const int D) {

 template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
-        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V, const bool stream_k,
-        const int warp_size = WARP_SIZE
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+        const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;

@@ -697,9 +696,6 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

-    GGML_ASSERT(stream_k || ncols2 == 1);
-    const bool use_parallel_blocks = !stream_k && (Q->ne[1] <= ncols1) ? true : false;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -772,19 +768,38 @@ void launch_fattn(

         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
-        if (use_parallel_blocks) {
-            const int num_blocks_base = ntiles_x*Q->ne[2]*Q->ne[3];
-            const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-            const int seqlen_tiles = (K->ne[1] + D - 1) / D;
-
-            // Determine the number of active blocks per SM
-            int numActiveBlocks = 1;
-            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
-            // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
-            // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
-            // If there are not enough tiles to process, we can reduce the number of blocks
-            parallel_blocks = std::max(std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles), 1);
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
         }

         blocks_num.x = ntiles_x;
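
The selection logic above can be exercised on its own. The following host-only sketch mirrors the heuristic added in this hunk; the device parameters (nsm, max_blocks_per_sm) and the tile counts used in main are illustrative assumptions, not values taken from the commit.

#include <algorithm>
#include <cstdio>

// Sketch of the parallel_blocks heuristic: fill one wave, respect the tensor size,
// then search for a split with less tail-effect waste in the last wave.
static int choose_parallel_blocks(int nsm, int max_blocks_per_sm, int ntiles_total, int ntiles_KQ) {
    const int blocks_per_wave = nsm * max_blocks_per_sm;

    // Start with enough parallel blocks to fill one wave, capped by the tensor size.
    int parallel_blocks = std::max(blocks_per_wave / ntiles_total, 1);
    parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

    int nwaves_best = 0;
    int efficiency_percent_best = 0;
    for (int pb = parallel_blocks; pb <= ntiles_KQ; ++pb) {
        const int nblocks_total = ntiles_total * pb;
        const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
        const int efficiency_percent = 100 * nblocks_total / (nwaves * blocks_per_wave);

        // Stop once efficiency is already good and more blocks would only add waves.
        if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
            break;
        }
        if (efficiency_percent > efficiency_percent_best) {
            nwaves_best = nwaves;
            efficiency_percent_best = efficiency_percent;
            parallel_blocks = pb;
        }
    }
    return parallel_blocks;
}

int main() {
    // Assumed example: 108 SMs, 2 blocks/SM, 6 independent tiles of work, K length allowing up to 64 splits.
    printf("parallel_blocks = %d\n", choose_parallel_blocks(108, 2, 6, 64));
    return 0;
}

With these assumed inputs the search settles on 36 parallel blocks, which exactly fills one wave of 216 blocks; a larger value would spill into a second, mostly idle wave.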

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 2 additions & 1 deletion
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     }

-    launch_fattn<D, ncols1, ncols2, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true, true);
+    launch_fattn<D, ncols1, ncols2, KQ_per_iter>
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
 }

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 4 additions & 2 deletions
@@ -295,14 +295,16 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 4 additions & 2 deletions
@@ -294,14 +294,16 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 1 addition & 1 deletion
@@ -303,7 +303,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V, false);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }

 template <int D, ggml_type type_K, ggml_type type_V>

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 1 addition & 1 deletion
@@ -288,7 +288,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V, false);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }

 template <int D, ggml_type type_K, ggml_type type_V>

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 1 addition & 1 deletion
@@ -480,7 +480,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, false, warp_size);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
 }

 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

ggml/src/ggml-cuda/fattn.cu

Lines changed: 10 additions & 4 deletions
@@ -244,6 +244,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -293,14 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

-    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0) {
+    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
+    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
+    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            return;
-        } else if(Q->ne[0] <= 128) {
+        } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            return;
         }
+        return;
     }

     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
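
For readability, the batch-size-1 dispatch introduced above can be restated as a standalone predicate. This is a simplified host-side sketch only; fattn_params, its fields, and choose_bs1_kernel are hypothetical names used for illustration and are not part of ggml.

#include <cstdint>
#include <cstdio>

// Hypothetical container for the tensor/device properties the dispatch inspects.
struct fattn_params {
    int64_t n_q_tokens;     // Q->ne[1]
    int64_t head_dim;       // Q->ne[0]
    int64_t n_head_q;       // Q->ne[2]
    int64_t n_head_kv;      // K->ne[2]
    bool    has_mask;       // mask tensor present
    bool    kv_is_f16;      // K and V already F16, i.e. no data conversion needed for MMA
    bool    mma_available;  // new_mma_available(cc)
    bool    pre_ada;        // cc < GGML_CUDA_CC_ADA_LOVELACE
    bool    prec_default;   // prec == GGML_PREC_DEFAULT
    int     warp_size;
};

enum class fattn_kernel_choice { vec_f16, vec_f32, other };

// Mirrors the conditions of the hunk above for the single-query-token case.
static fattn_kernel_choice choose_bs1_kernel(const fattn_params & p) {
    const bool gqa_opt_applies       = ((p.n_head_q / p.n_head_kv) % 2 == 0) && p.has_mask;
    const bool mma_faster_for_bs1    = p.mma_available && gqa_opt_applies && p.pre_ada && p.kv_is_f16;
    const bool can_use_vector_kernel = (p.head_dim % (2*p.warp_size) == 0) &&
                                       (p.prec_default || p.head_dim <= 128);

    if (p.n_q_tokens == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        return p.prec_default ? fattn_kernel_choice::vec_f16 : fattn_kernel_choice::vec_f32;
    }
    return fattn_kernel_choice::other; // fall through to the WMMA/MMA/tile paths
}

int main() {
    // Made-up example: one query token, head size 128, 8:1 GQA, F16 KV cache, Ada-or-newer GPU.
    const fattn_params p = {1, 128, 32, 4, true, true, true, false, true, 32};
    printf("use vector f16 kernel: %d\n", choose_bs1_kernel(p) == fattn_kernel_choice::vec_f16);
    return 0;
}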
