@@ -7882,16 +7882,14 @@ kernel void kernel_mul_mm(
78827882 // no need for dequantization
78837883 if (FC_mul_mm_bounds_check) {
78847884 // bounds checks are required
7885- #pragma unroll(16)
78867885 for (short i = 0 ; i < 16 ; i++) {
78877886 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
78887887 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
78897888 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = loop_k + 16 *il + i < args.ne00 ? ((device T0 *) x)[i] : 0 ;
78907889 }
78917890 } else {
78927891 // do not perform bounds checks
7893- #pragma unroll(16)
7894- for (short i = 0 ; i < 16 ; i++) {
7892+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
78957893 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
78967894 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
78977895 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = ((device T0 *) x)[i];
@@ -7903,8 +7901,7 @@ kernel void kernel_mul_mm(
79037901
79047902 threadgroup_barrier (mem_flags::mem_threadgroup);
79057903
7906- #pragma unroll(16)
7907- for (short i = 0 ; i < 16 ; i++) {
7904+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
79087905 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
79097906 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
79107907 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
@@ -8137,16 +8134,14 @@ kernel void kernel_mul_mm_id(
81378134 // no need for dequantization
81388135 if (FC_mul_mm_bounds_check) {
81398136 // bounds checks are required
8140- #pragma unroll(16)
81418137 for (short i = 0 ; i < 16 ; i++) {
81428138 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
81438139 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
81448140 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = loop_k + 16 *il + i < args.ne00 ? ((device T0 *) x)[i] : 0 ;
81458141 }
81468142 } else {
81478143 // do not perform bounds checks
8148- #pragma unroll(16)
8149- for (short i = 0 ; i < 16 ; i++) {
8144+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
81508145 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
81518146 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
81528147 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = ((device T0 *) x)[i];
@@ -8158,8 +8153,7 @@ kernel void kernel_mul_mm_id(
81588153
81598154 threadgroup_barrier (mem_flags::mem_threadgroup);
81608155
8161- #pragma unroll(16)
8162- for (short i = 0 ; i < 16 ; i++) {
8156+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
81638157 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
81648158 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
81658159 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
0 commit comments