@@ -176,36 +176,73 @@ kernel void kernel_soft_max(
176176 constant int64_t & ne00,
177177 constant int64_t & ne01,
178178 constant int64_t & ne02,
179- uint3 tgpig[[threadgroup_position_in_grid]],
180- uint3 tpitg[[thread_position_in_threadgroup]],
181- uint3 ntg[[threads_per_threadgroup]]) {
182- const int64_t i03 = tgpig[2 ];
183- const int64_t i02 = tgpig[1 ];
184- const int64_t i01 = tgpig[0 ];
179+ threadgroup float * buf [[threadgroup(0 )]],
180+ uint tgpig[[threadgroup_position_in_grid]],
181+ uint tpitg[[thread_position_in_threadgroup]],
182+ uint sgitg[[simdgroup_index_in_threadgroup]],
183+ uint tiisg[[thread_index_in_simdgroup]],
184+ uint ntg[[threads_per_threadgroup]]) {
185+ const int64_t i03 = (tgpig) / (ne02*ne01);
186+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
187+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
185188
186189 device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
187190 device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
188191
189192 // parallel max
190- float lmax = tpitg[0 ] < ne00 ? psrc0[tpitg[0 ]] : -INFINITY;
191- for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00; i00 += ntg[0 ]) {
193+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
194+
195+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
192196 lmax = MAX (lmax, psrc0[i00]);
193197 }
194- const float max = simd_max (lmax);
198+
199+ float max = simd_max (lmax);
200+ if (tiisg == 0 ) {
201+ buf[sgitg] = max;
202+ }
203+
204+ threadgroup_barrier (mem_flags::mem_threadgroup);
205+
206+ // broadcast, simd group number is ntg / 32
207+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
208+ if (tpitg < i) {
209+ buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
210+ }
211+ }
212+
213+ threadgroup_barrier (mem_flags::mem_threadgroup);
214+
215+ max = buf[0 ];
195216
196217 // parallel sum
197218 float lsum = 0 .0f ;
198- for (int i00 = tpitg[ 0 ] ; i00 < ne00; i00 += ntg[ 0 ] ) {
219+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
199220 const float exp_psrc0 = exp (psrc0[i00] - max);
200221 lsum += exp_psrc0;
201222 // Remember the result of exp here. exp is expensive, so we really do not
202- // whish to compute it twice.
223+ // wish to compute it twice.
203224 pdst[i00] = exp_psrc0;
204225 }
205226
206- const float sum = simd_sum (lsum);
227+ float sum = simd_sum (lsum);
228+ if (tiisg == 0 ) {
229+ buf[sgitg] = sum;
230+ }
231+
232+ threadgroup_barrier (mem_flags::mem_threadgroup);
233+
234+ // broadcast, simd group number is ntg / 32
235+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
236+ if (tpitg < i) {
237+ buf[tpitg] += buf[tpitg + i];
238+ }
239+ }
240+
241+ threadgroup_barrier (mem_flags::mem_threadgroup);
242+
243+ sum = buf[0 ];
207244
208- for (int i00 = tpitg[ 0 ] ; i00 < ne00; i00 += ntg[ 0 ] ) {
245+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
209246 pdst[i00] /= sum;
210247 }
211248}
@@ -216,37 +253,73 @@ kernel void kernel_soft_max_4(
216253 constant int64_t & ne00,
217254 constant int64_t & ne01,
218255 constant int64_t & ne02,
219- uint3 tgpig[[threadgroup_position_in_grid]],
220- uint3 tpitg[[thread_position_in_threadgroup]],
221- uint3 ntg[[threads_per_threadgroup]]) {
222- const int64_t i03 = tgpig[2 ];
223- const int64_t i02 = tgpig[1 ];
224- const int64_t i01 = tgpig[0 ];
256+ threadgroup float * buf [[threadgroup(0 )]],
257+ uint tgpig[[threadgroup_position_in_grid]],
258+ uint tpitg[[thread_position_in_threadgroup]],
259+ uint sgitg[[simdgroup_index_in_threadgroup]],
260+ uint tiisg[[thread_index_in_simdgroup]],
261+ uint ntg[[threads_per_threadgroup]]) {
262+ const int64_t i03 = (tgpig) / (ne02*ne01);
263+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
264+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
225265
226266 device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
227267 device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
228268
229269 // parallel max
230- float4 lmax4 = tpitg[0 ] < ne00/4 ? psrc4[tpitg[0 ]] : -INFINITY;
231- for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00/4 ; i00 += ntg[0 ]) {
270+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
271+
272+ for (int i00 = tpitg + ntg; i00 < ne00/4 ; i00 += ntg) {
232273 lmax4 = fmax (lmax4, psrc4[i00]);
233274 }
234- float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
235275
236- const float max = simd_max (lmax);
276+ const float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
277+ float max = simd_max (lmax);
278+ if (tiisg == 0 ) {
279+ buf[sgitg] = max;
280+ }
281+
282+ threadgroup_barrier (mem_flags::mem_threadgroup);
283+
284+ // broadcast, simd group number is ntg / 32
285+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
286+ if (tpitg < i) {
287+ buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
288+ }
289+ }
290+
291+ threadgroup_barrier (mem_flags::mem_threadgroup);
292+
293+ max = buf[0 ];
237294
238295 // parallel sum
239296 float4 lsum4 = 0 .0f ;
240- for (int i00 = tpitg[ 0 ] ; i00 < ne00/4 ; i00 += ntg[ 0 ] ) {
297+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
241298 const float4 exp_psrc4 = exp (psrc4[i00] - max);
242299 lsum4 += exp_psrc4;
243300 pdst4[i00] = exp_psrc4;
244301 }
245- float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
246302
247- const float sum = simd_sum (lsum);
303+ const float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
304+ float sum = simd_sum (lsum);
305+ if (tiisg == 0 ) {
306+ buf[sgitg] = sum;
307+ }
308+
309+ threadgroup_barrier (mem_flags::mem_threadgroup);
310+
311+ // broadcast, simd group number is ntg / 32
312+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
313+ if (tpitg < i) {
314+ buf[tpitg] += buf[tpitg + i];
315+ }
316+ }
317+
318+ threadgroup_barrier (mem_flags::mem_threadgroup);
319+
320+ sum = buf[0 ];
248321
249- for (int i00 = tpitg[ 0 ] ; i00 < ne00/4 ; i00 += ntg[ 0 ] ) {
322+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
250323 pdst4[i00] /= sum;
251324 }
252325}
@@ -266,7 +339,7 @@ kernel void kernel_diag_mask_inf(
266339 dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
267340 } else {
268341 dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
269- }
342+ }
270343}
271344
272345kernel void kernel_diag_mask_inf_8 (
0 commit comments