@@ -3355,15 +3355,15 @@ kernel void kernel_flash_attn_ext_vec(
33553355 const short NW4 = NW/4 ;
33563356 const short SH = 2 *C; // shared memory per simdgroup
33573357
3358- const short T = D + 2 * nsg*SH; // shared memory size per query in (half)
3358+ const short T = D + nsg*SH; // shared memory size per query in (half)
33593359
3360- // threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0 *D); // same as above but in q4_t
3362- threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0 *D); // same as above but in q4x4_t
3363- threadgroup s_t * ss = (threadgroup s_t *) (shared + 2 * sgitg*SH + Q*D); // scratch buffer for attention
3364- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2 * sgitg*SH + Q*D); // same as above but in s4_t
3365- threadgroup half * sm = (threadgroup half *) (shared + 2 * sgitg*SH + SH + Q*D); // scratch buffer for mask
3366- threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3360+ // threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0 *D); // same as above but in q4_t
3362+ threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0 *D); // same as above but in q4x4_t
3363+ threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
3364+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
3365+ threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
3366+ threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
33673367
33683368 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
33693369 o4x4_t lo[D16/NW4];
@@ -3522,7 +3522,7 @@ kernel void kernel_flash_attn_ext_vec(
35223522 for (short cc = 0 ; cc < C/4 ; ++cc) {
35233523 device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
35243524
3525- const v4x4_t ms (ss[4 *cc + ty]);
3525+ const s4x4_t ms (ss[4 *cc + ty]);
35263526
35273527#pragma unroll
35283528 for (short ii = 0 ; ii < D16; ii += NW4) {
@@ -3531,7 +3531,7 @@ kernel void kernel_flash_attn_ext_vec(
35313531 v4x4_t mv;
35323532 deq_v (pv4 + i/nl_v, i%nl_v, mv);
35333533
3534- lo[ii/NW4] += ( o4x4_t )( mv*ms) ;
3534+ lo[ii/NW4] += mv*ms;
35353535 }
35363536 }
35373537 }
@@ -3616,12 +3616,15 @@ kernel void kernel_flash_attn_ext_vec(
36163616 }
36173617}
36183618
3619+ // note: s_t can probably be half instead of float here, because the Q*K scaling is applied before the scores are stored to shared memory
3620+ //       in the other (non-vec) kernel, s_t has to stay float because the scaling is applied during the soft_max
3621+ //
36193622#define FA_TYPES \
3620- half4, half4x4, \
3621- half4x4, \
3622- half4x4, \
3623- float , \
3624- float , float4, float4x4 , \
3623+ half4, half4x4, \
3624+ half4x4, \
3625+ half4x4, \
3626+ float , \
3627+ half, half4, half4x4 , \
36253628 half4x4
36263629
36273630typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
0 commit comments