@@ -7497,31 +7497,45 @@ kernel void kernel_mul_mm_id_map0(
74977497 device const char * src2,
74987498 device char * htpe,
74997499 device char * hids,
7500- uint3 tgpig[[threadgroup_position_in_grid ]],
7501- ushort3 tpitg[[thread_position_in_threadgroup]],
7502- ushort3 ntg[[threads_per_threadgroup]]) {
7503- const int ide = tgpig[ 0 ] ; // expert id
7500+ threadgroup char * shmem [[threadgroup( 0 ) ]],
7501+ ushort tpitg[[thread_position_in_threadgroup]],
7502+ ushort ntg[[threads_per_threadgroup]]) {
7503+ const short ide = tpitg ; // expert id
75047504
7505- int n_all = 0 ;
7505+ uint32_t n_all = 0 ;
75067506
75077507 device int32_t * ids_i32 = (device int32_t *) (hids);
75087508
7509- for (int i21 = 0 ; i21 < args.ne21 ; i21++) { // n_tokens
7510- device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21 );
7509+ for (int i21 = 0 ; i21 < args.ne21 ; i21 += ntg) { // n_tokens
7510+ {
7511+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21 );
75117512
7512- for (int i20 = 0 ; i20 < args.ne20 ; i20++) { // n_expert_used
7513- if (src2_i32[i20] != ide) {
7514- continue ;
7513+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*args.ne20 ;
7514+
7515+ for (int i20 = 0 ; i20 < args.ne20 && i21 + tpitg < args.ne21 ; i20++) {
7516+ sids[i20] = src2_i32[i20];
75157517 }
7518+ }
7519+
7520+ threadgroup_barrier (mem_flags::mem_threadgroup);
75167521
7517- ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
7522+ for (int t = 0 ; t < ntg && i21 + t < args.ne21 ; t++) {
7523+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + t*args.ne20 ;
75187524
7519- ++n_all;
7525+ for (int i20 = 0 ; i20 < args.ne20 ; i20++) {
7526+ if (sids[i20] == ide) {
7527+ ids_i32[ide*args.ne21 + n_all] = (i21 + t)*args.ne20 + i20;
7528+ ++n_all;
7529+ break ;
7530+ }
7531+ }
75207532 }
7533+
7534+ threadgroup_barrier (mem_flags::mem_threadgroup);
75217535 }
75227536
7523- device int32_t * tpe_i32 = (device int32_t *) (htpe);
7524- tpe_i32 [ide] = n_all;
7537+ device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
7538+ tpe_u32 [ide] = n_all;
75257539}
75267540
75277541typedef decltype (kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
@@ -7549,10 +7563,10 @@ kernel void kernel_mul_mm_id(
75497563 const int r1 = tgpig.x ;
75507564 const int im = tgpig.z ; // expert
75517565
7552- device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
7553- device const int32_t * ids_i32 = (device const int32_t *) (hids);
7566+ device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
7567+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
75547568
7555- const int neh1 = tpe_i32 [im];
7569+ const uint32_t neh1 = tpe_u32 [im];
75567570
75577571 if (r1*BLOCK_SIZE_N >= neh1) {
75587572 return ;
@@ -7578,9 +7592,9 @@ kernel void kernel_mul_mm_id(
75787592
75797593 const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
75807594
7581- const int i11 = (id % args.ne20 ) % args.ne11 ;
7582- const int i12 = (id / args.ne20 );
7583- const int i13 = 0 ;
7595+ const short i11 = (id % args.ne20 ) % args.ne11 ;
7596+ const short i12 = (id / args.ne20 );
7597+ const short i13 = 0 ;
75847598
75857599 const uint64_t offset0 = im*args.nb02 + i13*args.nb03 ;
75867600 const short offset1 = il/nl;
@@ -7649,17 +7663,18 @@ kernel void kernel_mul_mm_id(
76497663 threadgroup float * temp_str = ((threadgroup float *) shmem) \
76507664 + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
76517665
7666+ #pragma unroll(8)
76527667 for (short i = 0 ; i < 8 ; i++) {
76537668 simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
76547669 }
76557670
76567671 threadgroup_barrier (mem_flags::mem_threadgroup);
76577672
7658- for (int j = sgitg; j < n_cols; j += 4 ) {
7673+ for (short j = sgitg; j < n_cols; j += 4 ) {
76597674 const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76607675
7661- const int ide = id % args.ne20 ;
7662- const int idt = id / args.ne20 ;
7676+ const short ide = id % args.ne20 ;
7677+ const short idt = id / args.ne20 ;
76637678
76647679 device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
76657680 device float4 * D4 = (device float4 *) D;
0 commit comments