Skip to content

Commit c092316

Browse files
committed
metal : mul/div opt
1 parent 47c7b3b commit c092316

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3950,7 +3950,7 @@ static int ggml_metal_encode_node(
39503950
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
39513951
[encoder setBuffer: h_tpe offset:0 atIndex:2];
39523952
[encoder setBuffer: h_ids offset:0 atIndex:3];
3953-
[encoder setThreadgroupMemoryLength:ne02*ne20*sizeof(uint16_t) atIndex:0];
3953+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
39543954

39553955
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
39563956
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -974,9 +974,16 @@ kernel void kernel_mul(
974974
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
975975
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
976976

977-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
978-
const int i10 = i0%args.ne10;
979-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
977+
if (args.ne10 == 1) {
978+
const float x = *((device float *)(src1_ptr));
979+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
980+
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
981+
}
982+
} else {
983+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
984+
const int i10 = i0%args.ne10;
985+
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
986+
}
980987
}
981988
}
982989

@@ -1000,9 +1007,16 @@ kernel void kernel_div(
10001007
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
10011008
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
10021009

1003-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1004-
const int i10 = i0%args.ne10;
1005-
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
1010+
if (args.ne10 == 1) {
1011+
const float x = 1.0f / *((device float *)(src1_ptr));
1012+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1013+
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
1014+
}
1015+
} else {
1016+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1017+
const int i10 = i0%args.ne10;
1018+
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
1019+
}
10061020
}
10071021
}
10081022

0 commit comments

Comments
 (0)