@@ -2774,13 +2774,13 @@ kernel void kernel_flash_attn_ext(
     const short NW = N_SIMDWIDTH;
     const short SH = (C + Q); // shared memory per simdgroup in (half)

-    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
-    const short TF = T/2;          // shared memory size per query in (float)
+    const short T  = D + nsg*SH;   // shared memory size per query in (half)
+    const short TF = T;            // shared memory size per query in (float)
     const short T4 = T/4;          // shared memory size per query in (half4)

-    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
-    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup half  * sq  = (threadgroup half  *) (shared +            0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +            0*D); // same as above but in half4
+    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix

     threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
     threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
@@ -2809,7 +2809,7 @@ kernel void kernel_flash_attn_ext(
     // zero out shared memory SH
     for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TF + i] = 0.0f;
+            ss[j*TF + i] = 0.0h;
         }
     }

@@ -2874,7 +2874,7 @@ kernel void kernel_flash_attn_ext(
         // Q*K^T
         {
             for (short cc = 0; cc < C/8; ++cc) {
-                simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+                simdgroup_half8x8  mqk = make_filled_simdgroup_matrix<half,  8>(0.h);

                 if (is_same<block_q, half4x4>::value) {
                     // we can read directly from global memory
@@ -2944,7 +2944,7 @@ kernel void kernel_flash_attn_ext(
             const float m = M[j];

             // scale and apply the logitcap / mask
-            float s = ss[j*TF + tiisg]*scale;
+            float s = ((float)(ss[j*TF + tiisg]))*scale;

             if (logit_softcap != 0.0f) {
                 s = logit_softcap*precise::tanh(s);
@@ -2980,7 +2980,7 @@ kernel void kernel_flash_attn_ext(

             // O = diag(ms)*O
             {
-                simdgroup_float8x8 mm;
+                simdgroup_half8x8 mm;
                 simdgroup_load(mm, ss + C, TF, 0, false);

                 for (short i = 0; i < D8; ++i) {
@@ -2991,7 +2991,7 @@ kernel void kernel_flash_attn_ext(
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 ms;
+                    simdgroup_half8x8 ms;
                     simdgroup_load(ms, ss + 8*cc, TF, 0, false);

                     if (is_same<block_q, half4x4>::value) {
@@ -3103,8 +3103,8 @@ kernel void kernel_flash_attn_ext(
         // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
         {
             simdgroup_half8x8 t;
-            simdgroup_float8x8 ms0;
-            simdgroup_float8x8 ms1;
+            simdgroup_half8x8 ms0;
+            simdgroup_half8x8 ms1;

             simdgroup_load(ms0, ss + C,         TF, 0, false);
             simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
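
Taken together, the hunks above switch the attention scratch buffer ss from float to half, halving its footprint in threadgroup memory: the per-query row shrinks from D + 2*nsg*SH to D + nsg*SH half elements, and TF becomes a stride counted in half units (TF = T) rather than float units (TF = T/2). A minimal sketch of that sizing arithmetic follows, written as plain C++; the values of D, nsg, C, and Q and the names T_before/T_after are illustrative assumptions, not taken from this diff.

// Sketch of the shared-memory sizing before and after this change.
// D, nsg, C, Q are example values only (assumptions, not from the diff).
#include <cstdio>

int main() {
    const short D   = 128;      // head dimension (illustrative)
    const short nsg = 4;        // simdgroups per threadgroup (illustrative)
    const short C   = 32;       // cache items per simdgroup (illustrative)
    const short Q   = 8;        // queries per threadgroup (illustrative)
    const short SH  = C + Q;    // shared memory per simdgroup, in half elements

    const short T_before = D + 2*nsg*SH; // ss as float: 2 halves per element
    const short T_after  = D +   nsg*SH; // ss as half:  1 half  per element

    // each half element is 2 bytes, so the row saves 2*(T_before - T_after) bytes
    std::printf("halves per query row: %d -> %d (saves %d bytes)\n",
                T_before, T_after, 2*(T_before - T_after));
}

With these example values the row goes from 448 to 288 half elements, i.e. 320 bytes saved per query, which is where the removed factor of 2 in T and the TF = T/2 -> TF = T stride change come from.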