@@ -7543,6 +7543,7 @@ kernel void kernel_mul_mm_id(
75437543 threadgroup char * shmem [[threadgroup(0 )]],
75447544 uint3 tgpig[[threadgroup_position_in_grid]],
75457545 ushort tiitg[[thread_index_in_threadgroup]],
7546+ ushort tiisg[[thread_index_in_simdgroup]],
75467547 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
75477548
75487549 threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7648,36 +7649,36 @@ kernel void kernel_mul_mm_id(
76487649 }
76497650
76507651 threadgroup_barrier (mem_flags::mem_threadgroup);
7652+
76517653 threadgroup float * temp_str = ((threadgroup float *) shmem) \
76527654 + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
7655+
76537656 for (short i = 0 ; i < 8 ; i++) {
76547657 simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
76557658 }
76567659
76577660 threadgroup_barrier (mem_flags::mem_threadgroup);
76587661
7659- if (sgitg == 0 ) {
7660- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7661- const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
7662+ for (int j = sgitg; j < n_cols; j += 4 ) {
7663+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76627664
7663- const int ide = id % args.ne20 ;
7664- const int idt = id / args.ne20 ;
7665+ const int ide = id % args.ne20 ;
7666+ const int idt = id / args.ne20 ;
76657667
7666- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
7667- device float4 * D4 = (device float4 *) D;
7668+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
7669+ device float4 * D4 = (device float4 *) D;
76687670
7669- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7670- threadgroup float4 * C4 = (threadgroup float4 *) C;
7671+ threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
7672+ threadgroup float4 * C4 = (threadgroup float4 *) C;
76717673
7672- int i = 0 ;
7673- for (; i < n_rows/4 ; i++ ) {
7674- *(D4 + i) = *(C4 + i);
7675- }
7674+ int i = tiisg ;
7675+ for (; i < n_rows/4 ; i += 32 ) {
7676+ *(D4 + i) = *(C4 + i);
7677+ }
76767678
7677- i *= 4 ;
7678- for (; i < n_rows; i++) {
7679- *(D + i) = *(C + i);
7680- }
7679+ i = (4 *(n_rows/4 )) + tiisg;
7680+ for (; i < n_rows; i += 32 ) {
7681+ *(D + i) = *(C + i);
76817682 }
76827683 }
76837684}
0 commit comments