Commit 8ca8cfb

CUDA: determine FA parallel blocks at runtime
1 parent: 3d652bf · commit: 8ca8cfb

9 files changed: +116, -197 lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 21 additions & 17 deletions
@@ -45,7 +45,8 @@ typedef void (* fattn_kernel_t)(
         const int ne0,
         const int ne1,
         const int ne2,
-        const int ne3);
+        const int ne3,
+        const int parallel_blocks);
 
 typedef half (*vec_dot_KQ_f16_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -612,37 +613,36 @@ static __global__ void flash_attn_stream_k_fixup(
     *dst = dst_val / rowsum;
 }
 
-template<int D, int parallel_blocks> // D == head size
+template<int D> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
-        float * __restrict__ dst) {
+        float * __restrict__ dst,
+        const int parallel_blocks) {
     VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
     VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
     dst       += D * gridDim.y*blockIdx.x;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
-    __shared__ float2 meta[parallel_blocks];
+    extern __shared__ float2 meta[];
     if (tid < 2*parallel_blocks) {
         ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
     }
 
     __syncthreads();
 
     float kqmax = meta[0].x;
-#pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
         kqmax = max(kqmax, meta[l].x);
     }
 
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
-#pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
         const float diff = meta[l].x - kqmax;
         const float KQ_max_scale = expf(diff);
@@ -677,11 +677,10 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-// parallel_blocks == 0 is stream-k decomposition
-template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V, const bool stream_k
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -704,6 +703,9 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
+    GGML_ASSERT(stream_k || ncols2 == 1);
+    const int parallel_blocks = Q->ne[1] <= ncols1 ? 4 : 1;
+
     const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
 
     ggml_cuda_pool & pool = ctx.pool();
@@ -760,7 +762,7 @@ void launch_fattn(
 
     const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
-    if (parallel_blocks == 0) {
+    if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -811,19 +813,20 @@
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3],
+        parallel_blocks
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if constexpr (parallel_blocks == 0) {
+    if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -832,13 +835,14 @@
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
-    } else if constexpr (parallel_blocks > 1) {
+    } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(D, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D, parallel_blocks>
-            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D>
+            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
     CUDA_CHECK(cudaGetLastError());
 }
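
Note on the shared-memory change in flash_attn_combine_results above: a statically sized __shared__ array needs a compile-time constant, which is why parallel_blocks previously had to be a template parameter. With parallel_blocks as a runtime argument, the scratch array becomes extern __shared__ and its byte count is supplied as the third launch parameter (nbytes_shared_combine). Below is a minimal, self-contained sketch of that CUDA pattern, with hypothetical kernel and variable names (reduce_partial_maxima, partials), not the ggml kernel itself:

// Sketch of runtime-sized dynamic shared memory: the array size is chosen at launch time.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void reduce_partial_maxima(const float * __restrict__ partials,
                                      float * __restrict__ out,
                                      const int parallel_blocks) {
    extern __shared__ float buf[];   // sized via the 3rd kernel launch parameter

    const int tid = threadIdx.x;
    if (tid < parallel_blocks) {
        buf[tid] = partials[blockIdx.x*parallel_blocks + tid];
    }
    __syncthreads();

    if (tid == 0) {
        float m = buf[0];
        for (int l = 1; l < parallel_blocks; ++l) { // no #pragma unroll: the bound is a runtime value
            m = fmaxf(m, buf[l]);
        }
        out[blockIdx.x] = m;
    }
}

int main() {
    const int parallel_blocks = 4; // chosen at runtime, analogous to the heuristic in launch_fattn
    const int nrows = 2;
    const float h_partials[8] = {1, 3, 2, 0, 5, 4, 7, 6};
    float *d_partials, *d_out;
    cudaMalloc(&d_partials, sizeof(h_partials));
    cudaMalloc(&d_out, nrows*sizeof(float));
    cudaMemcpy(d_partials, h_partials, sizeof(h_partials), cudaMemcpyHostToDevice);

    const size_t nbytes_shared = parallel_blocks*sizeof(float); // like nbytes_shared_combine above
    reduce_partial_maxima<<<nrows, 32, nbytes_shared>>>(d_partials, d_out, parallel_blocks);

    float h_out[2];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("%g %g\n", h_out[0], h_out[1]); // expected: 3 7
    cudaFree(d_partials);
    cudaFree(d_out);
    return 0;
}

The same trade-off explains the two dropped #pragma unroll directives in the diff: the loop bound is no longer known at compile time.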

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 3 additions & 2 deletions
@@ -838,7 +838,8 @@ static __global__ void flash_attn_ext_f16(
         const int ne0,
         const int ne1,
         const int ne2,
-        const int ne3) {
+        const int ne3,
+        const int parallel_blocks) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -970,7 +971,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     }
 
-    launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+    launch_fattn<D, ncols1, ncols2, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true, true);
 }
 
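
For context on the call change above: the old launch_fattn took parallel_blocks as a template parameter and used 0 as a sentinel for stream-k decomposition, so every value compiled a separate instantiation; the new one drops that parameter, takes an explicit stream_k bool instead, and computes parallel_blocks at runtime from the batch size. A simplified, hypothetical sketch of that shape (not the actual ggml API):

// Hypothetical host-side sketch: parallel_blocks is ordinary runtime data, stream-k is an explicit flag.
#include <cstdio>

template <int D, int ncols1>
static void launch_fattn_sketch(const int n_q, const bool stream_k) {
    // same runtime heuristic as launch_fattn above: small batches get 4 parallel blocks
    const int parallel_blocks = n_q <= ncols1 ? 4 : 1;
    printf("D=%d ncols1=%d stream_k=%d parallel_blocks=%d\n", D, ncols1, stream_k ? 1 : 0, parallel_blocks);
}

int main() {
    launch_fattn_sketch< 64, 16>(/*n_q=*/1,   /*stream_k=*/false); // tile path, small batch -> 4
    launch_fattn_sketch< 64, 32>(/*n_q=*/512, /*stream_k=*/false); // tile path, large batch -> 1
    launch_fattn_sketch<128, 16>(/*n_q=*/1,   /*stream_k=*/true);  // mma path: stream-k decomposition
    return 0;
}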

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 13 additions & 28 deletions
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -43,7 +43,8 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne0,
         const int ne1,
         const int ne2,
-        const int ne3) {
+        const int ne3,
+        const int parallel_blocks) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -105,8 +106,7 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
@@ -288,23 +288,23 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
    }
 }
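
The dropped k_start special case above works because, with a runtime parallel_blocks, the same striding expression covers both configurations: when parallel_blocks == 1 the per-block index ip is necessarily 0, so the loop already starts at 0 and walks every KV tile; when it is 4, the blocks interleave tiles. The F32 variant below makes the same change with FATTN_KQ_STRIDE_TILE_F32 == 32. A small host-side illustration of that striding, in plain C++ with assumed values (not the kernel code):

// Illustrates the tile interleaving used by the tile kernels above:
// block ip of parallel_blocks processes every parallel_blocks-th KV tile.
#include <cstdio>

int main() {
    const int stride = 64;   // FATTN_KQ_STRIDE_TILE_F16 (32 for the F32 variant)
    const int ne11   = 256;  // KV sequence length, assumed for the example

    const int pb_values[2] = {1, 4};
    for (int parallel_blocks : pb_values) {
        printf("parallel_blocks = %d\n", parallel_blocks);
        for (int ip = 0; ip < parallel_blocks; ++ip) {
            printf("  block %d starts at:", ip);
            // same loop shape as the kernel: no special case for parallel_blocks == 1,
            // because ip == 0 there and ip*stride is already 0
            for (int k_VKQ_0 = ip*stride; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*stride) {
                printf(" %d", k_VKQ_0);
            }
            printf("\n");
        }
    }
    return 0;
}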

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 13 additions & 28 deletions
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -43,7 +43,8 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne0,
         const int ne1,
         const int ne2,
-        const int ne3) {
+        const int ne3,
+        const int parallel_blocks) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -103,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
        // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
@@ -287,23 +287,23 @@ static __global__ void flash_attn_tile_ext_f32(
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
    }
 }
