@@ -2819,22 +2819,25 @@ kernel void kernel_flash_attn_ext(
28192819 float S[Q] = { [0 ... Q-1 ] = 0 .0h };
28202820 float M[Q] = { [0 ... Q-1 ] = -FLT_MAX/2 };
28212821
2822+ // thread indices inside the simdgroup
2823+ const short tx = tiisg%4 ;
2824+ const short ty = tiisg/4 ;
2825+
28222826 // assume K and V are same shape
28232827 const short ne22 = ne12;
28242828 const short ne23 = ne13;
28252829
2826- // broadcast
2830+ // broadcast k
28272831 const short rk2 = ne02/ne12;
28282832 const short rk3 = ne03/ne13;
28292833
2830- const short rv2 = ne02/ne22;
2831- const short rv3 = ne03/ne23;
2832-
2833- // k indices
28342834 const short ik2 = iq2/rk2;
28352835 const short ik3 = iq3/rk3;
28362836
2837- // v indices
2837+ // broadcast v
2838+ const short rv2 = ne02/ne22;
2839+ const short rv3 = ne03/ne23;
2840+
28382841 const short iv2 = iq2/rv2;
28392842 const short iv3 = iq3/rv3;
28402843
@@ -2885,15 +2888,12 @@ kernel void kernel_flash_attn_ext(
28852888 }
28862889 } else {
28872890 for (short ii = 0 ; ii < D16; ii += 4 ) {
2888- const short i = tiisg%4 ;
2889- const short j = tiisg/4 ;
2890-
2891- device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8 *cc + j)*nb11 + ik2*nb12 + ik3*nb13));
2891+ device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8 *cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
28922892
28932893 if (D16%4 == 0 ) {
28942894 half4x4 tmp;
2895- dequantize_func (pk4 + (ii + i )/nl, (ii + i )%nl, tmp);
2896- skv4[4 *j + i ] = tmp;
2895+ dequantize_func (pk4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
2896+ skv4[4 *ty + tx ] = tmp;
28972897
28982898 simdgroup_barrier (mem_flags::mem_threadgroup);
28992899
@@ -2908,10 +2908,10 @@ kernel void kernel_flash_attn_ext(
29082908 simdgroup_multiply_accumulate (mqk, mq[2 *(ii + k) + 1 ], mk, mqk);
29092909 }
29102910 } else {
2911- if (ii + i < D16) {
2911+ if (ii + tx < D16) {
29122912 half4x4 tmp;
2913- dequantize_func (pk4 + (ii + i )/nl, (ii + i )%nl, tmp);
2914- skv4[4 *j + i ] = tmp;
2913+ dequantize_func (pk4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
2914+ skv4[4 *ty + tx ] = tmp;
29152915 }
29162916
29172917 simdgroup_barrier (mem_flags::mem_threadgroup);
@@ -3006,15 +3006,12 @@ kernel void kernel_flash_attn_ext(
30063006 }
30073007 } else {
30083008 for (short ii = 0 ; ii < D16; ii += 4 ) {
3009- const short i = tiisg%4 ;
3010- const short j = tiisg/4 ;
3011-
3012- device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8 *cc + j)*nb21 + iv2*nb22 + iv3*nb23));
3009+ device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8 *cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
30133010
30143011 if (D16%4 == 0 ) {
30153012 half4x4 tmp;
3016- dequantize_func (pv4 + (ii + i )/nl, (ii + i )%nl, tmp);
3017- skv4[4 *j + i ] = tmp;
3013+ dequantize_func (pv4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
3014+ skv4[4 *ty + tx ] = tmp;
30183015
30193016 simdgroup_barrier (mem_flags::mem_threadgroup);
30203017
@@ -3029,10 +3026,10 @@ kernel void kernel_flash_attn_ext(
30293026 simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms, mv, lo[2 *(ii + k) + 1 ]);
30303027 }
30313028 } else {
3032- if (ii + i < D16) {
3029+ if (ii + tx < D16) {
30333030 half4x4 tmp;
3034- dequantize_func (pv4 + (ii + i )/nl, (ii + i )%nl, tmp);
3035- skv4[4 *j + i ] = tmp;
3031+ dequantize_func (pv4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
3032+ skv4[4 *ty + tx ] = tmp;
30363033 }
30373034
30383035 simdgroup_barrier (mem_flags::mem_threadgroup);
@@ -3187,6 +3184,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_
31873184template [[host_name(" kernel_flash_attn_ext_q8_0_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2 , dequantize_q8_0, 128 >;
31883185template [[host_name(" kernel_flash_attn_ext_q8_0_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2 , dequantize_q8_0, 256 >;
31893186
3187+ // NOTE: can use half instead of float precision for some extra perf
31903188// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
31913189template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &), short D, short Q = 1 , short C = 32 >
31923190kernel void kernel_flash_attn_ext_vec (
@@ -3239,26 +3237,15 @@ kernel void kernel_flash_attn_ext_vec(
32393237
32403238 const short T = D + 2 *nsg*SH; // shared memory size per query in (half)
32413239
3242- float slope = 1 .0f ;
3243-
3244- // ALiBi
3245- if (max_bias > 0 .0f ) {
3246- const uint32_t h = iq2;
3247-
3248- const float base = h < n_head_log2 ? m0 : m1;
3249- const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
3250-
3251- slope = pow (base, exp);
3252- }
3253-
3254- // threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
3255- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
3256- threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
3257- threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2 *sgitg*SH + 1 *D); // same as above but in half4
3258- threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3240+ // threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
3241+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
3242+ threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0 *D); // same as above but in half4x4
3243+ threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
3244+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2 *sgitg*SH + 1 *D); // same as above but in half4
3245+ threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2 *sgitg*D + Q*T); // scratch buffer for the results
32593246
32603247 // store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
3261- half4x4 lo[D16/NW4];
3248+ float4x4 lo[D16/NW4];
32623249
32633250 // load heads from Q to shared memory
32643251 device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3273,7 +3260,7 @@ kernel void kernel_flash_attn_ext_vec(
32733260
32743261 // zero out lo
32753262 for (short i = 0 ; i < D16/NW4; i += NW4) {
3276- lo[i] = half4x4 (0 .0h);
3263+ lo[i] = float4x4 (0 .0h);
32773264 }
32783265
32793266 // zero out shared memory SH
@@ -3284,42 +3271,53 @@ kernel void kernel_flash_attn_ext_vec(
32843271 threadgroup_barrier (mem_flags::mem_threadgroup);
32853272
32863273 {
3287- float S = { 0 .0h };
3288- float M = { -FLT_MAX/2 };
3274+ float S = 0 .0h;
3275+ float M = -FLT_MAX/2 ;
3276+
3277+ // thread indices inside the simdgroup
3278+ const short tx = tiisg%8 ;
3279+ const short ty = tiisg/8 ;
32893280
32903281 // assume K and V are same shape
32913282 const short ne22 = ne12;
32923283 const short ne23 = ne13;
32933284
3294- // broadcast
3285+ // broadcast k
32953286 const short rk2 = ne02/ne12;
32963287 const short rk3 = ne03/ne13;
32973288
3289+ const short ik2 = iq2/rk2;
3290+ const short ik3 = iq3/rk3;
3291+
3292+ // broadcast v
32983293 const short rv2 = ne02/ne22;
32993294 const short rv3 = ne03/ne23;
33003295
3301- // k indices
3302- const short ik2 = iq2 / rk2;
3303- const short ik3 = iq3 / rk3;
3304-
3305- // v indices
3306- const short iv2 = iq2 / rv2;
3307- const short iv3 = iq3 / rv3;
3296+ const short iv2 = iq2/rv2;
3297+ const short iv3 = iq3/rv3;
33083298
33093299 // load the queries from shared memory into local memory
33103300 float4x4 mq[D16/NW4];
33113301
33123302 for (short ii = 0 ; ii < D16; ii += NW4) {
3313- short i = ii + tiisg%8 ;
3314- mq[ii/NW4][0 ] = (float4) sq4[4 *i + 0 ];
3315- mq[ii/NW4][1 ] = (float4) sq4[4 *i + 1 ];
3316- mq[ii/NW4][2 ] = (float4) sq4[4 *i + 2 ];
3317- mq[ii/NW4][3 ] = (float4) sq4[4 *i + 3 ];
3303+ mq[ii/NW4] = (float4x4) sq44[ii + tx];
33183304 }
33193305
33203306 // pointer to the mask
33213307 device const half * mp = (device const half *) (mask + iq1*nb31);
33223308
3309+ float slope = 1 .0f ;
3310+
3311+ // ALiBi
3312+ if (max_bias > 0 .0f ) {
3313+ const uint32_t h = iq2;
3314+
3315+ const float base = h < n_head_log2 ? m0 : m1;
3316+ const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
3317+
3318+ slope = pow (base, exp);
3319+ }
3320+
33233321 // loop over the KV cache
33243322 // each simdgroup handles blocks of Q rows and C columns
33253323 for (int ic0 = 0 ; ic0 < ne11; ic0 += C*nsg) {
@@ -3331,18 +3329,16 @@ kernel void kernel_flash_attn_ext_vec(
33313329 // Q*K^T
33323330 {
33333331 // each simdgroup processes 1 query and 4 keys
3334- const short j = tiisg/8 ;
3335- #pragma unroll
33363332 for (short cc = 0 ; cc < C/4 ; ++cc) {
33373333 float mqk = 0.0 ;
33383334
3339- device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4 *cc + j )*nb11 + ik2*nb12 + ik3*nb13));
3335+ device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4 *cc + ty )*nb11 + ik2*nb12 + ik3*nb13));
33403336
3341- float4x4 mk;
33423337#pragma unroll
33433338 for (short ii = 0 ; ii < D16; ii += NW4) {
3344- const short i = ii + tiisg% 8 ; // 0..7
3339+ const short i = ii + tx;
33453340
3341+ float4x4 mk;
33463342 dequantize_func (pk + i/nl, i%nl, mk);
33473343
33483344 mqk +=
@@ -3364,16 +3360,16 @@ kernel void kernel_flash_attn_ext_vec(
33643360 mqk += simd_shuffle_down (mqk, 1 );
33653361
33663362 // mqk = mqk*scale + mask*slope
3367- if (tiisg% 8 == 0 ) {
3363+ if (tx == 0 ) {
33683364 mqk *= scale;
33693365
33703366 if (logit_softcap != 0 .0f ) {
33713367 mqk = logit_softcap*precise::tanh (mqk);
33723368 }
33733369
3374- mqk += (mask != q) ? ((float ) mp[ic + 4 *cc + j ])*slope : (float ) 0 .0f ;
3370+ mqk += (mask != q) ? ((float ) mp[ic + 4 *cc + ty ])*slope : (float ) 0 .0f ;
33753371
3376- ss[4 *cc + j ] = mqk;
3372+ ss[4 *cc + ty ] = mqk;
33773373 }
33783374 }
33793375 }
@@ -3408,20 +3404,20 @@ kernel void kernel_flash_attn_ext_vec(
34083404
34093405 // O = O + (Q*K^T)*V
34103406 {
3411- const short j = tiisg/8 ;
34123407#pragma unroll
34133408 for (short cc = 0 ; cc < C/4 ; ++cc) {
3414- device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4 *cc + j)*nb21 + iv2*nb22 + iv3*nb23));
3409+ device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4 *cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
3410+
3411+ const float4x4 lss (ss[4 *cc + ty]);
34153412
3416- float4x4 mv;
3417- const float4x4 lss (ss[4 *cc + j]);
34183413#pragma unroll
34193414 for (short ii = 0 ; ii < D16; ii += NW4) {
3420- const short i = ii + tiisg% 8 ;
3415+ const short i = ii + tx ;
34213416
3417+ float4x4 mv;
34223418 dequantize_func (pv4 + i/nl, i%nl, mv);
34233419
3424- lo[ii/NW4] += (half4x4)( mv*lss) ;
3420+ lo[ii/NW4] += mv*lss;
34253421 }
34263422 }
34273423 }
@@ -3458,14 +3454,8 @@ kernel void kernel_flash_attn_ext_vec(
34583454 }
34593455
34603456 // store results to shared memory
3461- for (short ii = 0 ; ii < D16; ii += NW4) {
3462- short i = ii + tiisg;
3463- if (tiisg < 8 ) {
3464- sr4[4 *i + 0 ] = lo[ii/NW4][0 ];
3465- sr4[4 *i + 1 ] = lo[ii/NW4][1 ];
3466- sr4[4 *i + 2 ] = lo[ii/NW4][2 ];
3467- sr4[4 *i + 3 ] = lo[ii/NW4][3 ];
3468- }
3457+ for (short i = tiisg; i < D16; i += NW4) {
3458+ sr44[i] = lo[i/NW4];
34693459 }
34703460
34713461 threadgroup_barrier (mem_flags::mem_threadgroup);
@@ -3492,24 +3482,22 @@ kernel void kernel_flash_attn_ext_vec(
34923482 }
34933483
34943484 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3495- for (short ii = 0 ; ii < D4; ii += NW) {
3496- short i = ii + tiisg;
3497- sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
3485+ for (short i = tiisg; i < D16; i += NW) {
3486+ sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
34983487 }
34993488 }
35003489
35013490 threadgroup_barrier (mem_flags::mem_threadgroup);
35023491 }
35033492
3504- device float4 * dst4 = (device float4 *) dst;
3493+ device float4x4 * dst44 = (device float4x4 *) dst;
35053494
35063495 // final rescale with 1/S and store to global memory
35073496 if (sgitg == 0 ) {
35083497 const float S = ss[0 ];
35093498
3510- for (short ii = 0 ; ii < D4; ii += NW) {
3511- short i = ii + tiisg;
3512- dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
3499+ for (short i = tiisg; i < D16; i += NW) {
3500+ dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
35133501 }
35143502 }
35153503}
0 commit comments