@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
184184 constant int64_t & ne00,
185185 constant int64_t & ne01,
186186 constant int64_t & ne02,
187- uint3 tgpig[[threadgroup_position_in_grid]],
188- uint3 tpitg[[thread_position_in_threadgroup]],
189- uint3 ntg[[threads_per_threadgroup]]) {
190- const int64_t i03 = tgpig[2 ];
191- const int64_t i02 = tgpig[1 ];
192- const int64_t i01 = tgpig[0 ];
187+ threadgroup float * buf [[threadgroup(0 )]],
188+ uint tgpig[[threadgroup_position_in_grid]],
189+ uint tpitg[[thread_position_in_threadgroup]],
190+ uint sgitg[[simdgroup_index_in_threadgroup]],
191+ uint tiisg[[thread_index_in_simdgroup]],
192+ uint ntg[[threads_per_threadgroup]]) {
193+ const int64_t i03 = (tgpig) / (ne02*ne01);
194+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
193196
194197 device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
195198 device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
196199
197200 // parallel max
198- float lmax = tpitg[0 ] < ne00 ? psrc0[tpitg[0 ]] : -INFINITY;
199- for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00; i00 += ntg[0 ]) {
201+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202+
203+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
200204 lmax = MAX (lmax, psrc0[i00]);
201205 }
202- const float max = simd_max (lmax);
206+
207+ float max = simd_max (lmax);
208+ if (tiisg == 0 ) {
209+ buf[sgitg] = max;
210+ }
211+
212+ threadgroup_barrier (mem_flags::mem_threadgroup);
213+
214+ // reduce the per-simdgroup maxima stored in buf down to buf[0]; the number of simd groups is ntg / 32
215+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
216+ if (tpitg < i) {
217+ buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
218+ }
219+ }
220+
221+ threadgroup_barrier (mem_flags::mem_threadgroup);
222+
223+ max = buf[0 ];
203224
204225 // parallel sum
205226 float lsum = 0 .0f ;
206- for (int i00 = tpitg[ 0 ] ; i00 < ne00; i00 += ntg[ 0 ] ) {
227+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
207228 const float exp_psrc0 = exp (psrc0[i00] - max);
208229 lsum += exp_psrc0;
209230 // Remember the result of exp here. exp is expensive, so we really do not
210- // whish to compute it twice.
231+ // wish to compute it twice.
211232 pdst[i00] = exp_psrc0;
212233 }
213234
214- const float sum = simd_sum (lsum);
235+ float sum = simd_sum (lsum);
236+ if (tiisg == 0 ) {
237+ buf[sgitg] = sum;
238+ }
239+
240+ threadgroup_barrier (mem_flags::mem_threadgroup);
241+
242+ // reduce the per-simdgroup partial sums stored in buf down to buf[0]; the number of simd groups is ntg / 32
243+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
244+ if (tpitg < i) {
245+ buf[tpitg] += buf[tpitg + i];
246+ }
247+ }
248+
249+ threadgroup_barrier (mem_flags::mem_threadgroup);
250+
251+ sum = buf[0 ];
215252
216- for (int i00 = tpitg[ 0 ] ; i00 < ne00; i00 += ntg[ 0 ] ) {
253+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
217254 pdst[i00] /= sum;
218255 }
219256}
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
224261 constant int64_t & ne00,
225262 constant int64_t & ne01,
226263 constant int64_t & ne02,
227- uint3 tgpig[[threadgroup_position_in_grid]],
228- uint3 tpitg[[thread_position_in_threadgroup]],
229- uint3 ntg[[threads_per_threadgroup]]) {
230- const int64_t i03 = tgpig[2 ];
231- const int64_t i02 = tgpig[1 ];
232- const int64_t i01 = tgpig[0 ];
264+ threadgroup float * buf [[threadgroup(0 )]],
265+ uint tgpig[[threadgroup_position_in_grid]],
266+ uint tpitg[[thread_position_in_threadgroup]],
267+ uint sgitg[[simdgroup_index_in_threadgroup]],
268+ uint tiisg[[thread_index_in_simdgroup]],
269+ uint ntg[[threads_per_threadgroup]]) {
270+ const int64_t i03 = (tgpig) / (ne02*ne01);
271+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
233273
234274 device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
235275 device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
236276
237277 // parallel max
238- float4 lmax4 = tpitg[0 ] < ne00/4 ? psrc4[tpitg[0 ]] : -INFINITY;
239- for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00/4 ; i00 += ntg[0 ]) {
278+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279+
280+ for (int i00 = tpitg + ntg; i00 < ne00/4 ; i00 += ntg) {
240281 lmax4 = fmax (lmax4, psrc4[i00]);
241282 }
242- float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
243283
244- const float max = simd_max (lmax);
284+ const float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
285+ float max = simd_max (lmax);
286+ if (tiisg == 0 ) {
287+ buf[sgitg] = max;
288+ }
289+
290+ threadgroup_barrier (mem_flags::mem_threadgroup);
291+
292+ // reduce the per-simdgroup maxima stored in buf down to buf[0]; the number of simd groups is ntg / 32
293+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
294+ if (tpitg < i) {
295+ buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
296+ }
297+ }
298+
299+ threadgroup_barrier (mem_flags::mem_threadgroup);
300+
301+ max = buf[0 ];
245302
246303 // parallel sum
247304 float4 lsum4 = 0 .0f ;
248- for (int i00 = tpitg[ 0 ] ; i00 < ne00/4 ; i00 += ntg[ 0 ] ) {
305+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
249306 const float4 exp_psrc4 = exp (psrc4[i00] - max);
250307 lsum4 += exp_psrc4;
251308 pdst4[i00] = exp_psrc4;
252309 }
253- float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
254310
255- const float sum = simd_sum (lsum);
311+ const float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
312+ float sum = simd_sum (lsum);
313+ if (tiisg == 0 ) {
314+ buf[sgitg] = sum;
315+ }
316+
317+ threadgroup_barrier (mem_flags::mem_threadgroup);
318+
319+ // reduce the per-simdgroup partial sums stored in buf down to buf[0]; the number of simd groups is ntg / 32
320+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
321+ if (tpitg < i) {
322+ buf[tpitg] += buf[tpitg + i];
323+ }
324+ }
325+
326+ threadgroup_barrier (mem_flags::mem_threadgroup);
327+
328+ sum = buf[0 ];
256329
257- for (int i00 = tpitg[ 0 ] ; i00 < ne00/4 ; i00 += ntg[ 0 ] ) {
330+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
258331 pdst4[i00] /= sum;
259332 }
260333}
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
274347 dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
275348 } else {
276349 dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
277- }
350+ }
278351}
279352
280353kernel void kernel_diag_mask_inf_8 (
0 commit comments