Skip to content

Commit 2bd90e2

Browse files
committed
metal : mul_mm_id simplify + add test
1 parent f2d9acd commit 2bd90e2

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7543,6 +7543,7 @@ kernel void kernel_mul_mm_id(
75437543
threadgroup char * shmem [[threadgroup(0)]],
75447544
uint3 tgpig[[threadgroup_position_in_grid]],
75457545
ushort tiitg[[thread_index_in_threadgroup]],
7546+
ushort tiisg[[thread_index_in_simdgroup]],
75467547
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
75477548

75487549
threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7648,36 +7649,36 @@ kernel void kernel_mul_mm_id(
76487649
}
76497650

76507651
threadgroup_barrier(mem_flags::mem_threadgroup);
7652+
76517653
threadgroup float * temp_str = ((threadgroup float *) shmem) \
76527654
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
7655+
76537656
for (short i = 0; i < 8; i++) {
76547657
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
76557658
}
76567659

76577660
threadgroup_barrier(mem_flags::mem_threadgroup);
76587661

7659-
if (sgitg == 0) {
7660-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7661-
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
7662+
for (int j = sgitg; j < n_cols; j += 4) {
7663+
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76627664

7663-
const int ide = id % args.ne20;
7664-
const int idt = id / args.ne20;
7665+
const int ide = id % args.ne20;
7666+
const int idt = id / args.ne20;
76657667

7666-
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
7667-
device float4 * D4 = (device float4 *) D;
7668+
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
7669+
device float4 * D4 = (device float4 *) D;
76687670

7669-
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7670-
threadgroup float4 * C4 = (threadgroup float4 *) C;
7671+
threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
7672+
threadgroup float4 * C4 = (threadgroup float4 *) C;
76717673

7672-
int i = 0;
7673-
for (; i < n_rows/4; i++) {
7674-
*(D4 + i) = *(C4 + i);
7675-
}
7674+
int i = tiisg;
7675+
for (; i < n_rows/4; i += 32) {
7676+
*(D4 + i) = *(C4 + i);
7677+
}
76767678

7677-
i *= 4;
7678-
for (; i < n_rows; i++) {
7679-
*(D + i) = *(C + i);
7680-
}
7679+
i = (4*(n_rows/4)) + tiisg;
7680+
for (; i < n_rows; i += 32) {
7681+
*(D + i) = *(C + i);
76817682
}
76827683
}
76837684
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6018,6 +6018,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
60186018
for (bool b : {false, true}) {
60196019
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
60206020
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
6021+
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
60216022
}
60226023

60236024
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));

0 commit comments

Comments
 (0)