@@ -7506,8 +7506,8 @@ static __global__ void flash_attn_f32(
     }
 }

-template <int D, int ncols> // D head size
-__launch_bounds__(ncols == 8 ? (D + D % 32) : 2*D, 1)
+template <int D, int ncols> // D == head size
+__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -7545,9 +7545,11 @@ static __global__ void flash_attn_ext_f16(
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c;

-    constexpr int nwarps = D / frag_m;
+    constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
     constexpr int nthreads = nwarps*WARP_SIZE;
     static_assert(nthreads % D == 0, "nthreads not divisible by D.");
+    constexpr int tc_vals_per_iter = nwarps*frag_m;
+    static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter.");
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
     __builtin_assume(tid < nthreads);
     constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
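The two hunks above change the launch geometry: for ncols == 16 with D > 128 the kernel now runs with half as many warps, and tc_vals_per_iter records how many tile rows one pass of all warps' tensor core fragments covers. Below is a minimal host-side C++ sketch of that arithmetic, not part of the patch; it assumes WARP_SIZE == 32 and frag_m == (ncols == 8 ? 32 : 16), as in the surrounding kernel.

// Hedged sketch: mirrors the kernel's constexpr thread-count arithmetic on the host; plain C++.
#include <cstdio>
#include <initializer_list>

constexpr int WARP_SIZE = 32;

constexpr int frag_m_of(int ncols)          { return ncols == 8 ? 32 : 16; }  // assumed, as in the kernel
constexpr int nwarps_of(int D, int ncols)   { return (D <= 128 || ncols == 8 ? D : D/2) / frag_m_of(ncols); }
constexpr int nthreads_of(int D, int ncols) { return nwarps_of(D, ncols)*WARP_SIZE; }
constexpr int bounds_of(int D, int ncols)   { return ncols == 8 || D > 128 ? D : 2*D; } // new __launch_bounds__ value

// The declared launch bounds equal the actual block size for these head sizes / column counts.
static_assert(nthreads_of( 64, 16) == bounds_of( 64, 16), "block size mismatch");
static_assert(nthreads_of(128, 16) == bounds_of(128, 16), "block size mismatch");
static_assert(nthreads_of(256, 16) == bounds_of(256, 16), "block size mismatch");
static_assert(nthreads_of(256,  8) == bounds_of(256,  8), "block size mismatch");

int main() {
    for (int D : {64, 128, 256}) {
        for (int ncols : {8, 16}) {
            printf("D=%3d ncols=%2d -> nwarps=%2d nthreads=%3d __launch_bounds__=%3d\n",
                   D, ncols, nwarps_of(D, ncols), nthreads_of(D, ncols), bounds_of(D, ncols));
        }
    }
    return 0;
}

For D == 256 with 16 columns per block this drops the block from 16 warps (512 threads) to 8 warps (256 threads); the other listed configurations keep their previous size.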
@@ -7608,25 +7610,28 @@ static __global__ void flash_attn_ext_f16(
         const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11;

         // Calculate tile of KQ:
-        frag_c KQ_c[ncols/frag_n];
 #pragma unroll
-        for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-        }
-        if (has_valid_data) {
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+            frag_c KQ_c[ncols/frag_n];
 #pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+            if (has_valid_data) {
 #pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                    frag_a_K K_a;
+                    nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                    for (int j = 0; j < ncols/frag_n; ++j) {
+                        nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    }
                 }
             }
-        }
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+            }
         }

         __syncthreads();
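Because nwarps*frag_m can now be smaller than D, a single pass of the warps no longer produces the whole D-row KQ tile; the new outer i_KQ_0 loop steps through the rows in chunks of tc_vals_per_iter. A small stand-alone sketch of that row mapping follows, illustrative only, with an assumed D = 256, frag_m = 16, nwarps = 8 configuration.

// Hedged sketch (not kernel code): enumerate which KQ tile rows each warp's fragment covers
// under the i_KQ_0 tiling and check that every row is produced exactly once.
#include <cassert>
#include <vector>

int main() {
    const int D = 256, frag_m = 16, nwarps = 8;   // assumed example configuration (ncols == 16 path)
    const int tc_vals_per_iter = nwarps*frag_m;   // rows produced per iteration of the outer loop
    assert(D % tc_vals_per_iter == 0);            // mirrors the kernel's static_assert

    std::vector<int> hits(D, 0);
    for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
        for (int w = 0; w < nwarps; ++w) {        // w plays the role of threadIdx.y
            for (int r = 0; r < frag_m; ++r) {    // the frag_m rows of one wmma fragment
                ++hits[i_KQ_0 + frag_m*w + r];    // row offset used in load_matrix_sync/store_matrix_sync
            }
        }
    }
    for (int i = 0; i < D; ++i) {
        assert(hits[i] == 1);                     // every row of the KQ tile is written exactly once
    }
    return 0;
}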
@@ -7687,31 +7692,40 @@ static __global__ void flash_attn_ext_f16(
             }
         }

-        frag_c VKQ_c[ncols/frag_n];
+        frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n];
 #pragma unroll
-        for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(VKQ_c[j], 0.0f);
-        }
-
-#pragma unroll
-        for (int k0 = 0; k0 < D; k0 += 16) {
-            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                break;
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f);
             }

-            frag_a_V v_a;
-            nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + frag_m*threadIdx.y, stride_KV);
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(VKQ_c[j], v_a, KQ_b[k0/16][j], VKQ_c[j]);
+#pragma unroll
+            for (int k0 = 0; k0 < D; k0 += 16) {
+                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                    break;
+                }
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]);
+                }
             }
         }

         __syncthreads();

 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + frag_m*threadIdx.y, VKQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y,
+                    VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
         }

         __syncthreads();
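The same tiling is applied to the VKQ accumulation, so VKQ_c gains a leading dimension of D/tc_vals_per_iter and each warp keeps one row of fragments per i_KQ_0 chunk. A rough host-side sketch of the resulting per-warp fragment count follows, illustrative only; frag_n == (ncols == 8 ? 8 : 16) is assumed to match the kernel's fragment typedefs.

// Hedged sketch: size of the VKQ_c accumulator array per warp, in wmma fragments.
#include <cstdio>

int main() {
    const int configs[][2] = {{64, 16}, {128, 16}, {256, 16}, {256, 8}}; // {D, ncols}
    for (const auto & c : configs) {
        const int D = c[0], ncols = c[1];
        const int frag_m           = ncols == 8 ? 32 : 16;  // assumed, as in the kernel
        const int frag_n           = ncols == 8 ?  8 : 16;  // assumed, as in the kernel
        const int nwarps           = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
        const int tc_vals_per_iter = nwarps*frag_m;
        const int frags_per_warp   = (D/tc_vals_per_iter) * (ncols/frag_n); // elements of VKQ_c
        printf("D=%3d ncols=%2d -> nwarps=%d, VKQ_c fragments per warp = %d\n",
               D, ncols, nwarps, frags_per_warp);
    }
    return 0;
}

Under these assumptions only D == 256 with ncols == 16 holds more than one accumulator fragment per warp column; the smaller head sizes keep the same footprint as before the split.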
@@ -11453,7 +11467,7 @@ inline void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml
         cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
-    const int nwarps = Q->ne[0] / frag_m;
+    const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const size_t shmem = 0;
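On the host side the nwarps formula has to mirror the kernel's constexpr value so that block_dim.y matches the template instantiation. A small sketch comparing the old and new block sizes, assuming WARP_SIZE == 32 and not part of the patch:

// Hedged sketch: old vs. new per-block thread counts for a few head sizes.
#include <cstdio>
#include <initializer_list>

int main() {
    const int WARP_SIZE = 32;
    for (int cols_per_block : {8, 16}) {
        for (int D : {64, 128, 256}) {                                                     // D == Q->ne[0]
            const int frag_m     = cols_per_block == 8 ? 32 : 16;
            const int nwarps_old = D / frag_m;                                             // previous formula
            const int nwarps_new = (D <= 128 || cols_per_block == 8 ? D : D/2) / frag_m;   // patched formula
            printf("cols_per_block=%2d D=%3d: block threads %3d -> %3d\n",
                   cols_per_block, D, nwarps_old*WARP_SIZE, nwarps_new*WARP_SIZE);
        }
    }
    return 0;
}

Only the Q->ne[0] == 256, cols_per_block == 16 case changes (512 threads per block down to 256); the other configurations launch as before.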