Skip to content

Commit 80aa73a

Browse files
committed
metal : optimize mul_mm_id_map0
ggml-ci
1 parent 175d435 commit 80aa73a

File tree

2 files changed, with 51 additions and 19 deletions.

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -396,7 +396,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
396396
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
397397
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
398398
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
399-
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
399+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
400+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
401+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
402+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
403+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
404+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
400405
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
401406
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
402407
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1425,7 +1430,12 @@ @implementation GGMLMetalClass
14251430
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14261431
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14271432
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1428-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1433+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1434+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1435+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1436+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1437+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1438+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14291439
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14301440
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14311441
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3935,7 +3945,17 @@ static int ggml_metal_encode_node(
39353945

39363946
id<MTLComputePipelineState> pipeline = nil;
39373947

3938-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
3948+
pipeline = nil;
3949+
3950+
switch (ne20) {
3951+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
3952+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
3953+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
3954+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
3955+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
3956+
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
3957+
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
3958+
}
39393959

39403960
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
39413961

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 28 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -7505,7 +7505,7 @@ kernel void kernel_mul_mm(
75057505
}
75067506
}
75077507

7508-
template<typename T4>
7508+
template<short ne20> // n_expert_used
75097509
kernel void kernel_mul_mm_id_map0(
75107510
constant ggml_metal_kargs_mul_mm_id_map0 & args,
75117511
device const char * src2,
@@ -7518,31 +7518,38 @@ kernel void kernel_mul_mm_id_map0(
75187518

75197519
uint32_t n_all = 0;
75207520

7521-
device int32_t * ids_i32 = (device int32_t *) (hids);
7521+
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
75227522

75237523
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
7524-
{
7524+
if (i21 + tpitg < args.ne21) {
75257525
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
75267526

7527-
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*args.ne20;
7527+
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
75287528

7529-
for (int i20 = 0; i20 < args.ne20 && i21 + tpitg < args.ne21; i20++) {
7529+
#pragma unroll(ne20)
7530+
for (short i20 = 0; i20 < ne20; i20++) {
75307531
sids[i20] = src2_i32[i20];
75317532
}
75327533
}
75337534

75347535
threadgroup_barrier(mem_flags::mem_threadgroup);
75357536

7536-
for (int t = 0; t < ntg && i21 + t < args.ne21; t++) {
7537-
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + t*args.ne20;
7537+
for (short t = 0; t < ntg; t++) {
7538+
if (i21 + t >= args.ne21) {
7539+
break;
7540+
}
75387541

7539-
for (int i20 = 0; i20 < args.ne20; i20++) {
7540-
if (sids[i20] == ide) {
7541-
ids_i32[ide*args.ne21 + n_all] = (i21 + t)*args.ne20 + i20;
7542-
++n_all;
7543-
break;
7544-
}
7542+
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
7543+
7544+
short sel = 0;
7545+
#pragma unroll(ne20)
7546+
for (short i20 = 0; i20 < ne20; i20++) {
7547+
sel += (sids[i20] == ide)*(i20 + 1);
75457548
}
7549+
7550+
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
7551+
7552+
n_all += sel > 0;
75467553
}
75477554

75487555
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -7552,9 +7559,14 @@ kernel void kernel_mul_mm_id_map0(
75527559
tpe_u32[ide] = n_all;
75537560
}
75547561

7555-
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
7562+
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
75567563

7557-
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
7564+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
7565+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
7566+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
7567+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
7568+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
7569+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
75587570

75597571
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
75607572
kernel void kernel_mul_mm_id(
@@ -7580,7 +7592,7 @@ kernel void kernel_mul_mm_id(
75807592
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
75817593
device const int32_t * ids_i32 = (device const int32_t *) (hids);
75827594

7583-
const uint32_t neh1 = tpe_u32[im];
7595+
const int32_t neh1 = tpe_u32[im];
75847596

75857597
if (r1*BLOCK_SIZE_N >= neh1) {
75867598
return;

0 commit comments

Comments
 (0)