@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
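The first hunk threads the mask's third dimension through the kernel signature: ne33 is its extent (the number of sequences the mask covers) and nb33 its stride in bytes. A minimal sketch of how a kernel body could consume these arguments; the helper name is hypothetical, the modulo broadcast is an assumption, and none of this code is part of the patch:

#include <cuda_fp16.h>

// Hypothetical: locate the mask slice for one sequence. nb33 is a byte
// stride, hence the char* arithmetic; sequence % ne33 (an assumption here)
// would let a mask with ne33 == 1 broadcast across all sequences.
static __device__ const half * mask_for_sequence(
        const char * __restrict__ mask, const int ne33, const int nb33, const int sequence) {
    return (const half *) (mask + (sequence % ne33)*(size_t) nb33);
}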
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template <int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;

     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }

-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

     if (jt*ncols1 + j >= ne01) {
         return;
     }

-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
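The fixup kernel now decomposes the flattened work-item index kbc0 into (sequence, head, jt) instead of (channel, jt), and the dst pointer bump gains a leading sequence*ne02*ne01*D term. A standalone check of both identities, assuming the [ne03][ne01][ne02][D] output layout; the driver below is illustrative, not part of the patch:

#include <cassert>

int main() {
    const int iter_k = 4, iter_j = 3, ne02 = 8, ncols2 = 2, ne03 = 5;
    const int heads = ne02/ncols2; // head-group count, as in ne02/ncols2 above
    for (int kbc = 0; kbc < iter_k*iter_j*heads*ne03; ++kbc) {
        const int sequence = kbc / (iter_k*iter_j*heads);
        const int head     = (kbc - iter_k*iter_j*heads*sequence) / (iter_k*iter_j);
        const int jt       = (kbc - iter_k*iter_j*heads*sequence - iter_k*iter_j*head) / iter_k;
        const int k        = kbc % iter_k;
        // Recomposition must reproduce the flattened index:
        assert(((sequence*heads + head)*iter_j + jt)*iter_k + k == kbc);
    }
    // The dst pointer bump equals one flat [ne03][ne01][ne02][D] offset,
    // with row = jt*ncols1 + j and channel = head*ncols2 + c:
    const int D = 16, ncols1 = 4, ne01 = iter_j*ncols1;
    const int sequence = 2, jt = 1, head = 3, j = 2, c = 1, tid = 5;
    const int bump = sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
    assert(bump == ((sequence*ne01 + jt*ncols1 + j)*ne02 + head*ncols2 + c)*D + tid);
    return 0;
}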
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while (true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
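The walk-back loop gets the same ne03 factor so its per-block boundaries match kbc0 above. The stream-k property being relied on is that (bidx*total)/gridDim.x partitions [0, total) contiguously across blocks; a small hypothetical check, with stand-in shapes:

#include <cassert>

int main() {
    const int total  = 4*3*4*5; // iter_k*iter_j*(ne02/ncols2)*ne03
    const int grid_x = 7;       // stand-in for gridDim.x
    int prev_stop = 0;
    for (int bidx = 0; bidx < grid_x; ++bidx) {
        const int kbc      = (bidx + 0)*total / grid_x;
        const int kbc_stop = (bidx + 1)*total / grid_x;
        assert(kbc == prev_stop); // no gaps, no overlap
        prev_stop = kbc_stop;
    }
    assert(prev_stop == total); // every work item is covered
    return 0;
}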
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       += D                 * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled * D;

     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta)[blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta)[i];
     }

     __syncthreads();
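flash_attn_combine_results now derives everything from the launch grid: one block per (col, head, sequence) triple, linearized as j_dst_unrolled. Because each row's parallel_blocks partial results are contiguous after the pointer bumps, the meta load drops its blockIdx.z term. A hypothetical helper (not in the patch) spelling out the layout that both kernels now assume, the logical [D, ne01, ne02, ne03] result stored with the [0, 2, 1, 3] permutation, i.e. flat [ne03][ne01][ne02][D]:

#include <cstdint>
#include <cstdio>

// Flat offset of dst[sequence][col][head][tid] in the [ne03][ne01][ne02][D]
// layout; j_dst_unrolled in the kernel above is the row part of this index.
static int64_t dst_offset(const int sequence, const int col, const int head,
                          const int tid, const int ne01, const int ne02, const int D) {
    return (((int64_t) sequence*ne01 + col)*ne02 + head)*D + tid;
}

int main() {
    // Example shapes (hypothetical): 7 columns, 32 heads, head size 128.
    printf("%lld\n", (long long) dst_offset(1, 0, 3, 0, 7, 32, 128));
    return 0;
}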
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;

-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }

-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }

 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(

     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(

             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

         flash_attn_combine_results<DV>
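On the host side, the combine grid now spans heads and sequences directly (Q->ne[2] and Q->ne[3]) instead of reusing blocks_num.z, which is what lets the kernel read ne01 and ne02 back from gridDim. A hypothetical standalone sketch of the launch shape, with stand-in values:

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    // Hypothetical shapes: DV = head size, then columns, heads, sequences.
    const int DV = 128, ne01 = 7, ne02 = 32, ne03 = 4;
    const dim3 block_dim_combine(DV, 1, 1);
    const dim3 blocks_num_combine(ne01, ne02, ne03);
    // One block per output row; total rows == ne01*ne02*ne03.
    printf("blocks: %u x %u x %u, threads: %u\n",
           blocks_num_combine.x, blocks_num_combine.y, blocks_num_combine.z,
           block_dim_combine.x);
    return 0;
}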