@@ -6313,7 +6313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
63136313 simdgroup_T8x8 ma[4 ];
63146314 simdgroup_half8x8 mb[2 ];
63156315 simdgroup_half8x8 mc[8 ];
6316- for (int i = 0 ; i < 8 ; i++){
6316+ for (short i = 0 ; i < 8 ; i++){
63176317 mc[i] = make_filled_simdgroup_matrix<half, 8 >(0 .h );
63186318 }
63196319
@@ -6339,7 +6339,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
63396339 threadgroup_barrier (mem_flags::mem_threadgroup);
63406340
63416341 #pragma unroll(16)
6342- for (int i = 0 ; i < 16 ; i++) {
6342+ for (short i = 0 ; i < 16 ; i++) {
63436343 *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8 ) \
63446344 + (tiitg % THREAD_PER_ROW) * 16 + (i / 8 ) * 8 ) \
63456345 + (tiitg / THREAD_PER_ROW) % 8 + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
@@ -6358,22 +6358,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
63586358 threadgroup half * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
63596359
63606360 #pragma unroll(4)
6361- for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6361+ for (short ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
63626362 #pragma unroll(4)
6363- for (int i = 0 ; i < 4 ; i++) {
6363+ for (short i = 0 ; i < 4 ; i++) {
63646364 simdgroup_load (ma[i],lsma + SG_MAT_SIZE * i);
63656365 }
63666366 simdgroup_barrier (mem_flags::mem_none);
63676367 #pragma unroll(2)
6368- for (int i = 0 ; i < 2 ; i++) {
6368+ for (short i = 0 ; i < 2 ; i++) {
63696369 simdgroup_load (mb[i],lsmb + SG_MAT_SIZE * i);
63706370 }
63716371
63726372 lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
63736373 lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
63746374
63756375 #pragma unroll(8)
6376- for (int i = 0 ; i < 8 ; i++){
6376+ for (short i = 0 ; i < 8 ; i++){
63776377 simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
63786378 }
63796379 }
@@ -6382,7 +6382,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
63826382 if ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
63836383 device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1 )) \
63846384 + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6385- for (int i = 0 ; i < 8 ; i++) {
6385+ for (short i = 0 ; i < 8 ; i++) {
63866386 // cast to f32
63876387 simdgroup_float8x8 mc_f32 (1 .0f );
63886388 simdgroup_multiply (mc_f32, mc[i], mc_f32);
@@ -6394,7 +6394,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
63946394 threadgroup_barrier (mem_flags::mem_threadgroup);
63956395 threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
63966396 + 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
6397- for (int i = 0 ; i < 8 ; i++) {
6397+ for (short i = 0 ; i < 8 ; i++) {
63986398 simdgroup_float8x8 mc_f32 (1 .0f );
63996399 simdgroup_multiply (mc_f32, mc[i], mc_f32);
64006400 simdgroup_store (mc_f32, temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
0 commit comments