Skip to content

Commit 228a568

Browse files
committed
metal : multi-simd softmax
ggml-ci
1 parent 465219b commit 228a568

File tree

2 files changed

+108
-30
lines changed

2 files changed

+108
-30
lines changed

ggml-metal.m

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,20 +986,25 @@ void ggml_metal_graph_compute(
986986
} break;
987987
case GGML_OP_SOFT_MAX:
988988
{
989-
const int nth = MIN(32, ne00);
989+
int nth = 32; // SIMD width
990990

991991
if (ne00%4 == 0) {
992992
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
993993
} else {
994+
do {
995+
nth *= 2;
996+
} while (nth <= ne00 && nth <= 1024);
997+
nth /= 2;
994998
[encoder setComputePipelineState:ctx->pipeline_soft_max];
995999
}
9961000
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
9971001
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
9981002
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
9991003
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
10001004
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1005+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
10011006

1002-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1007+
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
10031008
} break;
10041009
case GGML_OP_DIAG_MASK_INF:
10051010
{

ggml-metal.metal

Lines changed: 101 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -176,36 +176,73 @@ kernel void kernel_soft_max(
176176
constant int64_t & ne00,
177177
constant int64_t & ne01,
178178
constant int64_t & ne02,
179-
uint3 tgpig[[threadgroup_position_in_grid]],
180-
uint3 tpitg[[thread_position_in_threadgroup]],
181-
uint3 ntg[[threads_per_threadgroup]]) {
182-
const int64_t i03 = tgpig[2];
183-
const int64_t i02 = tgpig[1];
184-
const int64_t i01 = tgpig[0];
179+
threadgroup float * buf [[threadgroup(0)]],
180+
uint tgpig[[threadgroup_position_in_grid]],
181+
uint tpitg[[thread_position_in_threadgroup]],
182+
uint sgitg[[simdgroup_index_in_threadgroup]],
183+
uint tiisg[[thread_index_in_simdgroup]],
184+
uint ntg[[threads_per_threadgroup]]) {
185+
const int64_t i03 = (tgpig) / (ne02*ne01);
186+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
187+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
185188

186189
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
187190
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
188191

189192
// parallel max
190-
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
191-
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
193+
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
194+
195+
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
192196
lmax = MAX(lmax, psrc0[i00]);
193197
}
194-
const float max = simd_max(lmax);
198+
199+
float max = simd_max(lmax);
200+
if (tiisg == 0) {
201+
buf[sgitg] = max;
202+
}
203+
204+
threadgroup_barrier(mem_flags::mem_threadgroup);
205+
206+
// reduce the per-simdgroup maxima stored in buf (ntg / 32 simdgroups) into buf[0]; every thread then reads the result from buf[0]
207+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
208+
if (tpitg < i) {
209+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
210+
}
211+
}
212+
213+
threadgroup_barrier(mem_flags::mem_threadgroup);
214+
215+
max = buf[0];
195216

196217
// parallel sum
197218
float lsum = 0.0f;
198-
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
219+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
199220
const float exp_psrc0 = exp(psrc0[i00] - max);
200221
lsum += exp_psrc0;
201222
// Remember the result of exp here. exp is expensive, so we really do not
202-
// whish to compute it twice.
223+
// wish to compute it twice.
203224
pdst[i00] = exp_psrc0;
204225
}
205226

206-
const float sum = simd_sum(lsum);
227+
float sum = simd_sum(lsum);
228+
if (tiisg == 0) {
229+
buf[sgitg] = sum;
230+
}
231+
232+
threadgroup_barrier(mem_flags::mem_threadgroup);
233+
234+
// reduce the per-simdgroup partial sums stored in buf (ntg / 32 simdgroups) into buf[0]; every thread then reads the result from buf[0]
235+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
236+
if (tpitg < i) {
237+
buf[tpitg] += buf[tpitg + i];
238+
}
239+
}
240+
241+
threadgroup_barrier(mem_flags::mem_threadgroup);
242+
243+
sum = buf[0];
207244

208-
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
245+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
209246
pdst[i00] /= sum;
210247
}
211248
}
@@ -216,37 +253,73 @@ kernel void kernel_soft_max_4(
216253
constant int64_t & ne00,
217254
constant int64_t & ne01,
218255
constant int64_t & ne02,
219-
uint3 tgpig[[threadgroup_position_in_grid]],
220-
uint3 tpitg[[thread_position_in_threadgroup]],
221-
uint3 ntg[[threads_per_threadgroup]]) {
222-
const int64_t i03 = tgpig[2];
223-
const int64_t i02 = tgpig[1];
224-
const int64_t i01 = tgpig[0];
256+
threadgroup float * buf [[threadgroup(0)]],
257+
uint tgpig[[threadgroup_position_in_grid]],
258+
uint tpitg[[thread_position_in_threadgroup]],
259+
uint sgitg[[simdgroup_index_in_threadgroup]],
260+
uint tiisg[[thread_index_in_simdgroup]],
261+
uint ntg[[threads_per_threadgroup]]) {
262+
const int64_t i03 = (tgpig) / (ne02*ne01);
263+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
264+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
225265

226266
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
227267
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
228268

229269
// parallel max
230-
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
231-
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
270+
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
271+
272+
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
232273
lmax4 = fmax(lmax4, psrc4[i00]);
233274
}
234-
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
235275

236-
const float max = simd_max(lmax);
276+
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
277+
float max = simd_max(lmax);
278+
if (tiisg == 0) {
279+
buf[sgitg] = max;
280+
}
281+
282+
threadgroup_barrier(mem_flags::mem_threadgroup);
283+
284+
// reduce the per-simdgroup maxima stored in buf (ntg / 32 simdgroups) into buf[0]; every thread then reads the result from buf[0]
285+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
286+
if (tpitg < i) {
287+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
288+
}
289+
}
290+
291+
threadgroup_barrier(mem_flags::mem_threadgroup);
292+
293+
max = buf[0];
237294

238295
// parallel sum
239296
float4 lsum4 = 0.0f;
240-
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
297+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
241298
const float4 exp_psrc4 = exp(psrc4[i00] - max);
242299
lsum4 += exp_psrc4;
243300
pdst4[i00] = exp_psrc4;
244301
}
245-
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
246302

247-
const float sum = simd_sum(lsum);
303+
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
304+
float sum = simd_sum(lsum);
305+
if (tiisg == 0) {
306+
buf[sgitg] = sum;
307+
}
308+
309+
threadgroup_barrier(mem_flags::mem_threadgroup);
310+
311+
// reduce the per-simdgroup partial sums stored in buf (ntg / 32 simdgroups) into buf[0]; every thread then reads the result from buf[0]
312+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
313+
if (tpitg < i) {
314+
buf[tpitg] += buf[tpitg + i];
315+
}
316+
}
317+
318+
threadgroup_barrier(mem_flags::mem_threadgroup);
319+
320+
sum = buf[0];
248321

249-
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
322+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
250323
pdst4[i00] /= sum;
251324
}
252325
}
@@ -266,7 +339,7 @@ kernel void kernel_diag_mask_inf(
266339
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
267340
} else {
268341
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
269-
}
342+
}
270343
}
271344

272345
kernel void kernel_diag_mask_inf_8(

0 commit comments

Comments
 (0)