Skip to content

Commit 8a3c0b3

Browse files
committed
metal : opt mul_mm_id map0
1 parent 2bd90e2 commit 8a3c0b3

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3924,8 +3924,6 @@ static int ggml_metal_encode_node(
39243924
}
39253925

39263926
{
3927-
const int nth = MIN(1024, ne10/4);
3928-
39293927
ggml_metal_kargs_mul_mm_id_map0 args = {
39303928
ne10,
39313929
ne11, // n_expert_used (bcast)
@@ -3946,7 +3944,7 @@ static int ggml_metal_encode_node(
39463944
[encoder setBuffer: h_tpe offset:0 atIndex:2];
39473945
[encoder setBuffer: h_ids offset:0 atIndex:3];
39483946

3949-
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3947+
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
39503948
}
39513949

39523950
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7514,18 +7514,14 @@ kernel void kernel_mul_mm_id_map0(
75147514
continue;
75157515
}
75167516

7517-
if (tpitg.x == 0) {
7518-
ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
7519-
}
7517+
ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
75207518

75217519
++n_all;
75227520
}
75237521
}
75247522

7525-
if (tpitg.x == 0) {
7526-
device int32_t * tpe_i32 = (device int32_t *) (htpe);
7527-
tpe_i32[ide] = n_all;
7528-
}
7523+
device int32_t * tpe_i32 = (device int32_t *) (htpe);
7524+
tpe_i32[ide] = n_all;
75297525
}
75307526

75317527
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;

0 commit comments

Comments
 (0)