Skip to content

Commit acac821

Browse files
committed
cont : remove unnecessary unrolls
1 parent 268ae6c commit acac821

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7882,16 +7882,14 @@ kernel void kernel_mul_mm(
78827882
// no need for dequantization
78837883
if (FC_mul_mm_bounds_check) {
78847884
// bounds checks are required
7885-
#pragma unroll(16)
78867885
for (short i = 0; i < 16; i++) {
78877886
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
78887887
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
78897888
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
78907889
}
78917890
} else {
78927891
// do not perform bounds checks
7893-
#pragma unroll(16)
7894-
for (short i = 0; i < 16; i++) {
7892+
FOR_UNROLL (short i = 0; i < 16; i++) {
78957893
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
78967894
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
78977895
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = ((device T0 *) x)[i];
@@ -7903,8 +7901,7 @@ kernel void kernel_mul_mm(
79037901

79047902
threadgroup_barrier(mem_flags::mem_threadgroup);
79057903

7906-
#pragma unroll(16)
7907-
for (short i = 0; i < 16; i++) {
7904+
FOR_UNROLL (short i = 0; i < 16; i++) {
79087905
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
79097906
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
79107907
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
@@ -8137,16 +8134,14 @@ kernel void kernel_mul_mm_id(
81378134
// no need for dequantization
81388135
if (FC_mul_mm_bounds_check) {
81398136
// bounds checks are required
8140-
#pragma unroll(16)
81418137
for (short i = 0; i < 16; i++) {
81428138
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
81438139
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
81448140
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
81458141
}
81468142
} else {
81478143
// do not perform bounds checks
8148-
#pragma unroll(16)
8149-
for (short i = 0; i < 16; i++) {
8144+
FOR_UNROLL (short i = 0; i < 16; i++) {
81508145
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
81518146
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
81528147
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = ((device T0 *) x)[i];
@@ -8158,8 +8153,7 @@ kernel void kernel_mul_mm_id(
81588153

81598154
threadgroup_barrier(mem_flags::mem_threadgroup);
81608155

8161-
#pragma unroll(16)
8162-
for (short i = 0; i < 16; i++) {
8156+
FOR_UNROLL (short i = 0; i < 16; i++) {
81638157
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
81648158
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
81658159
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];

0 commit comments

Comments (0)