@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
33
33
const int ne13,
34
34
const int ne31,
35
35
const int ne32,
36
+ const int ne33,
36
37
const int nb31,
37
38
const int nb32,
39
+ const int nb33,
38
40
const int nb01,
39
41
const int nb02,
40
42
const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
521
523
template <int D, int ncols1, int ncols2> // D == head size
522
524
__launch_bounds__ (D, 1 )
523
525
static __global__ void flash_attn_stream_k_fixup(
524
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
526
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
525
527
constexpr int ncols = ncols1*ncols2;
526
528
527
529
const int bidx0 = blockIdx .x ;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
535
537
const int iter_k = ne11 / FATTN_KQ_STRIDE;
536
538
const int iter_j = (ne01 + (ncols1 - 1 )) / ncols1;
537
539
538
- const int kbc0 = (bidx0 + 0 )*iter_k*iter_j*(ne02/ncols2) / gridDim .x ;
539
- const int kbc0_stop = (bidx0 + 1 )*iter_k*iter_j*(ne02/ncols2) / gridDim .x ;
540
+ const int kbc0 = (bidx0 + 0 )*( iter_k*iter_j*(ne02/ncols2)*ne03 ) / gridDim .x ;
541
+ const int kbc0_stop = (bidx0 + 1 )*( iter_k*iter_j*(ne02/ncols2)*ne03 ) / gridDim .x ;
540
542
541
543
const bool did_not_have_any_data = kbc0 == kbc0_stop;
542
544
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0 ;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
545
547
return ;
546
548
}
547
549
548
- const int channel = kbc0 / (iter_k*iter_j);
549
- const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
550
+ const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
551
+ const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
552
+ const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
550
553
551
554
if (jt*ncols1 + j >= ne01) {
552
555
return ;
553
556
}
554
557
555
- dst += jt*ne02*(ncols1*D) + channel *(ncols2*D) + (j*ne02 + c)*D + tid;
558
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head *(ncols2*D) + (j*ne02 + c)*D + tid;
556
559
557
560
// Load the partial result that needs a fixup:
558
561
float dst_val = 0 .0f ;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
571
574
int bidx = bidx0 - 1 ;
572
575
int kbc_stop = kbc0;
573
576
while (true ) {
574
- const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim .x ;
577
+ const int kbc = bidx*( iter_k*iter_j*(ne02/ncols2)*ne03 ) / gridDim .x ;
575
578
if (kbc == kbc_stop) { // Did not have any data.
576
579
bidx--;
577
580
kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
617
620
const float2 * __restrict__ VKQ_meta,
618
621
float * __restrict__ dst,
619
622
const int parallel_blocks) {
620
- VKQ_parts += parallel_blocks*D * gridDim .z *blockIdx .x ;
621
- VKQ_meta += parallel_blocks * gridDim .z *blockIdx .x ;
622
- dst += D * gridDim .z *blockIdx .x ;
623
+ // Dimension 0: threadIdx.x
624
+ // Dimension 1: blockIdx.x
625
+ // Dimension 2: blockIdx.y
626
+ // Dimension 3: blockIdx.z
627
+ // Memory layout is permuted with [0, 2, 1, 3]
628
+
629
+ const int ne01 = gridDim .x ;
630
+ const int ne02 = gridDim .y ;
631
+
632
+ const int col = blockIdx .x ;
633
+ const int head = blockIdx .y ;
634
+ const int sequence = blockIdx .z ;
635
+
636
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
637
+
638
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
639
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
640
+ dst += j_dst_unrolled * D;
623
641
624
642
const int tid = threadIdx .x ;
625
643
__builtin_assume (tid < D);
626
644
627
645
extern __shared__ float2 meta[];
628
646
for (int i = tid; i < 2 *parallel_blocks; i += D) {
629
- ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx . z *( 2 *parallel_blocks) + i];
647
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
630
648
}
631
649
632
650
__syncthreads ();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
644
662
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
645
663
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
646
664
647
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim . z *D + blockIdx . z * D + tid];
665
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
648
666
VKQ_denominator += KQ_max_scale * meta[l].y ;
649
667
}
650
668
651
- dst[blockIdx . z *D + tid] = VKQ_numerator / VKQ_denominator;
669
+ dst[tid] = VKQ_numerator / VKQ_denominator;
652
670
}
653
671
654
672
[[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
705
723
706
724
GGML_ASSERT (K->ne [1 ] % FATTN_KQ_STRIDE == 0 && " Incorrect KV cache padding." );
707
725
708
- GGML_ASSERT (Q->ne [3 ] == 1 );
709
-
710
726
ggml_cuda_pool & pool = ctx.pool ();
711
727
cudaStream_t main_stream = ctx.stream ();
712
728
const int id = ggml_cuda_get_device ();
@@ -853,8 +869,8 @@ void launch_fattn(
853
869
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
854
870
Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
855
871
K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
856
- mask ? mask->ne [1 ] : 0 , mask ? mask->ne [2 ] : 0 ,
857
- mask ? mask->nb [1 ] : 0 , mask ? mask->nb [2 ] : 0 ,
872
+ mask ? mask->ne [1 ] : 0 , mask ? mask->ne [2 ] : 0 , mask ? mask-> ne [ 3 ] : 0 ,
873
+ mask ? mask->nb [1 ] : 0 , mask ? mask->nb [2 ] : 0 , mask ? mask-> nb [ 3 ] : 0 ,
858
874
Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
859
875
nb11, nb12, nb13,
860
876
nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
869
885
870
886
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
871
887
<<<blocks_num_combine, block_dim_combine, 0 , main_stream>>>
872
- ((float *) KQV->data , dst_tmp_meta.ptr , Q->ne [1 ], Q->ne [2 ], K->ne [1 ]);
888
+ ((float *) KQV->data , dst_tmp_meta.ptr , Q->ne [1 ], Q->ne [2 ], Q-> ne [ 3 ], K->ne [1 ]);
873
889
}
874
890
} else if (parallel_blocks > 1 ) {
875
891
const dim3 block_dim_combine (DV, 1 , 1 );
876
- const dim3 blocks_num_combine (Q->ne [1 ], 1 , blocks_num. z );
892
+ const dim3 blocks_num_combine (Q->ne [1 ], Q-> ne [ 2 ], Q-> ne [ 3 ] );
877
893
const size_t nbytes_shared_combine = parallel_blocks*sizeof (float2 );
878
894
879
895
flash_attn_combine_results<DV>
0 commit comments