Skip to content

Commit 47c7b3b

Browse files
committed
metal : optimize mul_mm_id id gathering
1 parent 8a3c0b3 commit 47c7b3b

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ typedef struct {
320320
} ggml_metal_kargs_mul_mv_ext;
321321

322322
typedef struct {
323+
int32_t ne02;
323324
int32_t ne10;
324325
int32_t ne11; // n_expert_used (bcast)
325326
uint64_t nb11;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3925,6 +3925,7 @@ static int ggml_metal_encode_node(
39253925

39263926
{
39273927
ggml_metal_kargs_mul_mm_id_map0 args = {
3928+
ne02,
39283929
ne10,
39293930
ne11, // n_expert_used (bcast)
39303931
nb11,
@@ -3938,13 +3939,20 @@ static int ggml_metal_encode_node(
39383939

39393940
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
39403941

3942+
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
3943+
3944+
const size_t smem = ne02*ne20*sizeof(uint16_t);
3945+
3946+
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
3947+
39413948
[encoder setComputePipelineState:pipeline];
39423949
[encoder setBytes:&args length:sizeof(args) atIndex:0];
39433950
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
39443951
[encoder setBuffer: h_tpe offset:0 atIndex:2];
39453952
[encoder setBuffer: h_ids offset:0 atIndex:3];
3953+
[encoder setThreadgroupMemoryLength:ne02*ne20*sizeof(uint16_t) atIndex:0];
39463954

3947-
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3955+
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
39483956
}
39493957

39503958
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7497,31 +7497,45 @@ kernel void kernel_mul_mm_id_map0(
74977497
device const char * src2,
74987498
device char * htpe,
74997499
device char * hids,
7500-
uint3 tgpig[[threadgroup_position_in_grid]],
7501-
ushort3 tpitg[[thread_position_in_threadgroup]],
7502-
ushort3 ntg[[threads_per_threadgroup]]) {
7503-
const int ide = tgpig[0]; // expert id
7500+
threadgroup char * shmem [[threadgroup(0)]],
7501+
ushort tpitg[[thread_position_in_threadgroup]],
7502+
ushort ntg[[threads_per_threadgroup]]) {
7503+
const short ide = tpitg; // expert id
75047504

7505-
int n_all = 0;
7505+
uint32_t n_all = 0;
75067506

75077507
device int32_t * ids_i32 = (device int32_t *) (hids);
75087508

7509-
for (int i21 = 0; i21 < args.ne21; i21++) { // n_tokens
7510-
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
7509+
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
7510+
{
7511+
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
75117512

7512-
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
7513-
if (src2_i32[i20] != ide) {
7514-
continue;
7513+
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*args.ne20;
7514+
7515+
for (int i20 = 0; i20 < args.ne20 && i21 + tpitg < args.ne21; i20++) {
7516+
sids[i20] = src2_i32[i20];
75157517
}
7518+
}
7519+
7520+
threadgroup_barrier(mem_flags::mem_threadgroup);
75167521

7517-
ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
7522+
for (int t = 0; t < ntg && i21 + t < args.ne21; t++) {
7523+
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + t*args.ne20;
75187524

7519-
++n_all;
7525+
for (int i20 = 0; i20 < args.ne20; i20++) {
7526+
if (sids[i20] == ide) {
7527+
ids_i32[ide*args.ne21 + n_all] = (i21 + t)*args.ne20 + i20;
7528+
++n_all;
7529+
break;
7530+
}
7531+
}
75207532
}
7533+
7534+
threadgroup_barrier(mem_flags::mem_threadgroup);
75217535
}
75227536

7523-
device int32_t * tpe_i32 = (device int32_t *) (htpe);
7524-
tpe_i32[ide] = n_all;
7537+
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
7538+
tpe_u32[ide] = n_all;
75257539
}
75267540

75277541
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
@@ -7549,10 +7563,10 @@ kernel void kernel_mul_mm_id(
75497563
const int r1 = tgpig.x;
75507564
const int im = tgpig.z; // expert
75517565

7552-
device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
7553-
device const int32_t * ids_i32 = (device const int32_t *) (hids);
7566+
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
7567+
device const int32_t * ids_i32 = (device const int32_t *) (hids);
75547568

7555-
const int neh1 = tpe_i32[im];
7569+
const uint32_t neh1 = tpe_u32[im];
75567570

75577571
if (r1*BLOCK_SIZE_N >= neh1) {
75587572
return;
@@ -7578,9 +7592,9 @@ kernel void kernel_mul_mm_id(
75787592

75797593
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
75807594

7581-
const int i11 = (id % args.ne20) % args.ne11;
7582-
const int i12 = (id / args.ne20);
7583-
const int i13 = 0;
7595+
const short i11 = (id % args.ne20) % args.ne11;
7596+
const short i12 = (id / args.ne20);
7597+
const short i13 = 0;
75847598

75857599
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
75867600
const short offset1 = il/nl;
@@ -7649,17 +7663,18 @@ kernel void kernel_mul_mm_id(
76497663
threadgroup float * temp_str = ((threadgroup float *) shmem) \
76507664
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
76517665

7666+
#pragma unroll(8)
76527667
for (short i = 0; i < 8; i++) {
76537668
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
76547669
}
76557670

76567671
threadgroup_barrier(mem_flags::mem_threadgroup);
76577672

7658-
for (int j = sgitg; j < n_cols; j += 4) {
7673+
for (short j = sgitg; j < n_cols; j += 4) {
76597674
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76607675

7661-
const int ide = id % args.ne20;
7662-
const int idt = id / args.ne20;
7676+
const short ide = id % args.ne20;
7677+
const short idt = id / args.ne20;
76637678

76647679
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
76657680
device float4 * D4 = (device float4 *) D;

0 commit comments

Comments
 (0)