@@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
         GGML_ASSERT(nqptg % 8  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);
 
+        // 16*32*(nsg)
+        // the shared memory needed for the simdgroups to load the KV cache
+        // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+        //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(nsg)*(ncpsg + nqptg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
         int64_t nsgmax = 2;
 
         while (true) {
-            // 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
-            // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
-            const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
+            const size_t smem = FATTN_SMEM(nsgmax);
             if (smem > device.maxThreadgroupMemoryLength) {
                 break;
             }
@@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
         // simdgroups per threadgroup (a.k.a. warps)
         const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
 
-        const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
+        const size_t smem = FATTN_SMEM(nsg);
 
         //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
         GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
-        [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     } else {
         // half4x4 kernel
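
To make the new sizing concrete, here is a small standalone C sketch (not part of the commit) of how FATTN_SMEM and the nsgmax doubling search interact on the first kernel path. The head size ne00 = 128, ncpsg = 32 and the 32 KiB threadgroup-memory limit are hypothetical example values, and GGML_PAD is reproduced locally so the sketch compiles on its own.

// Hedged sketch: shared-memory sizing for the 8-queries-per-threadgroup FA path.
// All numeric values below are illustrative assumptions, not taken from the commit.
#include <stdint.h>
#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) // same definition as in ggml.h

int main(void) {
    const int64_t ne00  = 128; // head dimension (assumed)
    const int64_t nqptg = 8;   // queries per threadgroup
    const int64_t ncpsg = 32;  // cache values processed per simdgroup (assumed)

    const size_t max_tg_mem = 32*1024; // stand-in for device.maxThreadgroupMemoryLength

#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(nsg)*(ncpsg + nqptg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

    // double nsgmax while the padded allocation still fits, then step back once
    // (equivalent to the while (true) / break form in the encoder)
    int64_t nsgmax = 2;
    while (FATTN_SMEM(nsgmax) <= max_tg_mem) {
        nsgmax *= 2;
    }
    nsgmax /= 2;

    // with the values above: nsgmax = 8, smem = 20480 bytes (already 16-byte aligned)
    printf("nsgmax = %lld, smem = %zu\n", (long long) nsgmax, FATTN_SMEM(nsgmax));
#undef FATTN_SMEM
    return 0;
}
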
@@ -3178,21 +3181,41 @@ static void ggml_metal_encode_node(
         GGML_ASSERT(nqptg % 1  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);
 
+        // ne00 + 2*ncpsg*(nsg)
+        // for each query, we load it as f16 in shared memory (ne00)
+        // and store the attention scores (nqptg x ncpsg) as f32
+        //
+        // 2*ne00*(nsg)
+        // each simdgroup has a full f32 head vector in shared mem to accumulate results
+        //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+
+        int64_t nsgmax = 2;
+
+        while (true) {
+            const size_t smem = FATTN_SMEM(nsgmax);
+            if (smem > device.maxThreadgroupMemoryLength) {
+                break;
+            }
+            nsgmax *= 2;
+        }
+        nsgmax /= 2;
+
         // simdgroups per threadgroup (a.k.a. warps)
-        const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+        const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
         int64_t nsg = 1;
         while (nsg <= nsgt) {
             nsg *= 2;
         }
         nsg /= 2;
 
-        const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
+        const size_t smem = FATTN_SMEM(nsg);
 
-        //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+        //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
         GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-        [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     }
 } break;
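
A similar hedged sketch for the half4x4 (vec) path: the simdgroup count is rounded down to a power of two and shared memory is sized with the second FATTN_SMEM form (f16 query plus f32 attention scores, plus one f32 head vector per simdgroup). Here ne00 = 128, nqptg = 1, ncpsg = 32 and the cap nsgt = 7 are assumed example values, not taken from the commit.

// Hedged sketch: shared-memory sizing for the half4x4 (vec) FA path.
// All numeric values below are assumptions for illustration only.
#include <stdint.h>
#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) // same definition as in ggml.h

int main(void) {
    const int64_t ne00  = 128; // head dimension (assumed)
    const int64_t nqptg = 1;   // one query row per threadgroup (assumed for the vec kernel)
    const int64_t ncpsg = 32;  // cache values per simdgroup (assumed)

#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))

    // stand-in for MIN(nsgmax, MIN(ne11/ncpsg, maxTotalThreadsPerThreadgroup/32))
    const int64_t nsgt = 7;

    // round nsgt down to a power of two, exactly as the encoder does
    int64_t nsg = 1;
    while (nsg <= nsgt) {
        nsg *= 2;
    }
    nsg /= 2;

    // with the values above: nsg = 4, smem = 2816 bytes
    printf("nsg = %lld, smem = %zu\n", (long long) nsg, FATTN_SMEM(nsg));
#undef FATTN_SMEM
    return 0;
}
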