@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template <int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template <int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
@@ -271,40 +269,40 @@ static __global__ void flash_attn_tile_ext_f16(
             const int i0 = i00 + 2*threadIdx.x;
 
             half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
         }
 
-        if (parallel_blocks != 1 && threadIdx.x == 0) {
-            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         default: {
            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +322,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }
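For readers less familiar with the split-KV scheme this diff switches to, here is a small, self-contained CUDA sketch of the same indexing pattern. It is an illustration only, not the ggml kernel, and all names in it (toy_split_kernel, TILE, partial, meta) are made up: the KV sequence is strided over blockIdx.y in fixed-size tiles, blockIdx.z selects the head, and when gridDim.y > 1 each split writes a partial result plus (max, sum) metadata so a later combine pass can merge the splits.

// toy_kv_split.cu -- hypothetical, standalone illustration; not part of ggml.
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <vector>

#define TILE 64 // stand-in for FATTN_KQ_STRIDE_TILE_F16

// One thread per block keeps the toy reduction trivial; the point is the
// blockIdx.y grid-stride loop over KV tiles and the (head, split) output
// layout, not performance.
__global__ void toy_split_kernel(const float * kv, float * partial, float2 * meta, const int ne11) {
    const int head  = blockIdx.z; // head index (blockIdx.z after the change)
    const int split = blockIdx.y; // KV split index (replaces the compile-time parallel_blocks)

    float running_max = -INFINITY;
    float running_sum = 0.0f;

    // Each split starts at its own tile and then jumps gridDim.y tiles ahead,
    // so the splits cover disjoint parts of the KV sequence.
    for (int k0 = split*TILE; k0 < ne11; k0 += gridDim.y*TILE) {
        for (int k = k0; k < k0 + TILE && k < ne11; ++k) {
            const float v = kv[head*ne11 + k];
            running_max  = fmaxf(running_max, v);
            running_sum += v;
        }
    }

    // Consecutive splits of the same head are adjacent in the output,
    // mirroring the dst[...*gridDim.y + blockIdx.y] indexing in the diff.
    partial[head*gridDim.y + split] = running_sum;

    if (gridDim.y != 1) {
        // Per-split (max, sum) metadata for a later combine pass; with a
        // single split the result is already final and no metadata is needed.
        meta[head*gridDim.y + split] = make_float2(running_max, running_sum);
    }
}

int main() {
    const int nheads = 2, nsplits = 4, ne11 = 300;

    std::vector<float> h_kv(nheads*ne11);
    for (size_t i = 0; i < h_kv.size(); ++i) h_kv[i] = 0.001f*i;

    float * d_kv; float * d_partial; float2 * d_meta;
    cudaMalloc(&d_kv,      h_kv.size()*sizeof(float));
    cudaMalloc(&d_partial, nheads*nsplits*sizeof(float));
    cudaMalloc(&d_meta,    nheads*nsplits*sizeof(float2));
    cudaMemcpy(d_kv, h_kv.data(), h_kv.size()*sizeof(float), cudaMemcpyHostToDevice);

    // grid.y = number of KV splits, grid.z = number of heads.
    const dim3 grid(1, nsplits, nheads);
    toy_split_kernel<<<grid, 1>>>(d_kv, d_partial, d_meta, ne11);
    cudaDeviceSynchronize();

    std::vector<float> h_partial(nheads*nsplits);
    cudaMemcpy(h_partial.data(), d_partial, h_partial.size()*sizeof(float), cudaMemcpyDeviceToHost);
    for (int h = 0; h < nheads; ++h) {
        for (int s = 0; s < nsplits; ++s) {
            printf("head %d split %d: partial sum = %f\n", h, s, h_partial[h*nsplits + s]);
        }
    }

    cudaFree(d_kv); cudaFree(d_partial); cudaFree(d_meta);
    return 0;
}

Launching this sketch with grid.y == 1 degenerates to the old parallel_blocks == 1 path: a single block walks the whole KV length and its output is already final, which is why the metadata write is guarded by gridDim.y != 1, just as in the kernel above.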