@@ -351,7 +351,7 @@ kernel void kernel_rms_norm(
351351
352352 threadgroup_barrier (mem_flags::mem_threadgroup);
353353 // broadcast, simd group number is ntg / 32
354- for (int i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
354+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
355355 if (tpitg < i) {
356356 sum[tpitg] += sum[tpitg + i];
357357 }
@@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32(
13391339 }
13401340}
13411341
1342+ #if QK_K == 256
13421343kernel void kernel_mul_mat_q3_K_f32 (
13431344 device const void * src0,
13441345 device const float * src1,
@@ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
13471348 constant int64_t & ne10,
13481349 constant int64_t & ne0,
13491350 constant int64_t & ne1,
1350- threadgroup float * sum [[threadgroup(0 )]],
13511351 uint2 tgpig[[threadgroup_position_in_grid]],
1352- uint2 tpitg[[thread_position_in_threadgroup ]],
1353- uint2 tptg[[threads_per_threadgroup ]]) {
1352+ uint tiisg[[thread_index_in_simdgroup ]],
1353+ uint sgitg[[simdgroup_index_in_threadgroup ]]) {
13541354
13551355 const int nb = ne00/QK_K;
13561356
13571357 const int64_t r0 = tgpig.x ;
13581358 const int64_t r1 = tgpig.y ;
13591359
1360- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1361- device const float * yy = (device const float *) src1 + r1*ne10;
1362-
1363- const int nth = tptg.x *tptg.y ;
1364- const int ith = tptg.y *tpitg.x + tpitg.y ;
1360+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
13651361
1366- #if QK_K == 256
1362+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363+ device const float * yy = (device const float *) src1 + r1*ne10;
13671364
1368- const uint8_t m3 = 3 ;
1369- const int8_t m4 = 4 ;
1365+ float yl[16 ];
13701366
13711367 const uint16_t kmask1 = 0x0303 ;
13721368 const uint16_t kmask2 = 0x0f0f ;
13731369
1374- const int tid = tpitg.y ; // expecting 16
1370+ const int tid = tiisg/2 ;
1371+ const int ix = tiisg%2 ;
13751372 const int ip = tid/8 ; // 0 or 1
13761373 const int il = tid/2 - 4 *ip; // 0...3
13771374 const int ir = tid%2 ;
13781375 const int n = 8 ;
13791376 const int l0 = n*ir;
13801377
1381- const uint8_t m = 1 << (4 *ip + il);
1378+ const uint16_t m1 = 1 << (4 *ip + il);
1379+ const uint16_t m2 = m1 << 8 ;
13821380
13831381 const int shift = 2 *il;
1382+ const uint16_t qm1 = 0x0003 << shift;
1383+ const uint16_t qm2 = 0x0300 << shift;
1384+ const int32_t v1 = 4 << shift;
1385+ const int32_t v2 = 1024 << shift;
13841386
13851387 const uint16_t s_shift1 = 4 *ip;
13861388 const uint16_t s_shift2 = s_shift1 + 2 *(il/2 );
@@ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32(
13891391 const int q_offset = 32 *ip + l0;
13901392 const int y_offset = 128 *ip + 32 *il + l0;
13911393
1392- // float sumf = 0;
1393- float sumf1 = 0 , sumf2 = 0 ;
1394- for (int i = tpitg.x ; i < nb; i += tptg.x ) {
1394+ const int step = sizeof (block_q3_K) * nb / 2 ;
13951395
1396- const float d_all = (float )(x[i].d );
1397-
1398- device const uint8_t * q = x[i].qs + q_offset;
1399- device const uint8_t * h = x[i].hmask + l0;
1400- device const float * y = yy + i * QK_K + y_offset;
1396+ device const float * y1 = yy + ix*QK_K + y_offset;
14011397
1402- device const uint16_t * a = (device const uint16_t *)x[i]. scales ;
1403- const char2 scales = as_type<char2>(( uint16_t )(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4 )));
1398+ float sumf1[ 2 ] = { 0 . f }, sumf2[ 2 ] = { 0 . f } ;
1399+ for ( int i = ix; i < nb; i += 2 ) {
14041400
1405- float s = 0 ;
1406- for ( int l = 0 ; l < n; ++l) {
1407- s += y [l+ 0 ] * (( int8_t )((q [l+ 0 ] >> shift) & m3) - ((h[l+ 0 ] & m) ? 0 : m4)) ;
1401+ for ( int l = 0 ; l < 8 ; ++l) {
1402+ yl[l+ 0 ] = y1[l+ 0 ];
1403+ yl [l+8 ] = y1 [l+16 ] ;
14081404 }
1409- float d = d_all * s;
1410- sumf1 += d * scales[0 ];
1411- sumf2 += d;
1412- // sumf += d_all * s * (scales[0] - 32);
14131405
1414- s = 0 ;
1415- for (int l = 0 ; l < n; ++l) {
1416- s += y[l+16 ] * ((int8_t )((q[l+16 ] >> shift) & m3) - ((h[l+16 ] & m) ? 0 : m4));
1406+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408+ device const uint16_t * a = (device const uint16_t *)(x[i].scales );
1409+ device const half * dh = &x[i].d ;
1410+
1411+ for (int row = 0 ; row < 2 ; ++row) {
1412+
1413+ const float d_all = (float )dh[0 ];
1414+ const char2 scales = as_type<char2>((uint16_t )(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4 )));
1415+
1416+ float s1 = 0 , s2 = 0 ;
1417+ for (int l = 0 ; l < n; l += 2 ) {
1418+ const uint16_t qs = q[l/2 ];
1419+ s1 += yl[l+0 ] * ((int32_t )(qs & qm1) - ((h[l/2 ] & m1) ? 0 : v1));
1420+ s2 += yl[l+1 ] * ((int32_t )(qs & qm2) - ((h[l/2 ] & m2) ? 0 : v2));
1421+ }
1422+ float d = d_all * (s1 + 1 .f /256 .f * s2);
1423+ sumf1[row] += d * scales[0 ];
1424+ sumf2[row] += d;
1425+
1426+ s1 = s2 = 0 ;
1427+ for (int l = 0 ; l < n; l += 2 ) {
1428+ const uint16_t qs = q[l/2 +8 ];
1429+ s1 += yl[l+8 ] * ((int32_t )(qs & qm1) - ((h[l/2 +8 ] & m1) ? 0 : v1));
1430+ s2 += yl[l+9 ] * ((int32_t )(qs & qm2) - ((h[l/2 +8 ] & m2) ? 0 : v2));
1431+ }
1432+ d = d_all * (s1 + 1 .f /256 .f * s2);
1433+ sumf1[row] += d * scales[1 ];
1434+ sumf2[row] += d;
1435+
1436+ q += step;
1437+ h += step;
1438+ a += step;
1439+ dh += step;
1440+
14171441 }
1418- d = d_all * s;
1419- sumf1 += d * scales[1 ];
1420- sumf2 += d;
1421- // sumf += d_all * s * (scales[1] - 32);
1442+
1443+ y1 += 2 * QK_K;
14221444
14231445 }
14241446
1425- // sum[ith] = sumf;
1426- sum[ith] = sumf1 - 32 .f *sumf2;
1447+ for (int row = 0 ; row < 2 ; ++row) {
1448+ const float sumf = (sumf1[row] - 32 .f *sumf2[row]) / (1 << shift);
1449+ const float tot = simd_sum (sumf);
1450+ if (tiisg == 0 ) {
1451+ dst[r1*ne0 + first_row + row] = tot;
1452+ }
1453+ }
1454+ }
14271455#else
1428- const int il = 4 * tpitg.x ; // 0, 4, 8, 12
1456+ kernel void kernel_mul_mat_q3_K_f32 (
1457+ device const void * src0,
1458+ device const float * src1,
1459+ device float * dst,
1460+ constant int64_t & ne00,
1461+ constant int64_t & ne10,
1462+ constant int64_t & ne0,
1463+ constant int64_t & ne1,
1464+ uint2 tgpig[[threadgroup_position_in_grid]],
1465+ uint tiisg[[thread_index_in_simdgroup]],
1466+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1467+
1468+ const int nb = ne00/QK_K;
1469+
1470+ const int64_t r0 = tgpig.x ;
1471+ const int64_t r1 = tgpig.y ;
1472+
1473+ const int row = 2 * r0 + sgitg;
1474+
1475+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1476+ device const float * yy = (device const float *) src1 + r1*ne10;
1477+ const int ix = tiisg/4 ;
1478+ const int il = 4 * (tiisg%4 );// 0, 4, 8, 12
14291479 const int im = il/8 ; // 0, 0, 1, 1
14301480 const int in = il%8 ; // 0, 4, 0, 4
14311481
1432- float sumf = 0 ;
1482+ float2 sum = { 0 . f , 0 . f } ;
14331483
1434- for (int i = tpitg. y ; i < nb; i += tptg. y ) {
1484+ for (int i = ix ; i < nb; i += 8 ) {
14351485
14361486 const float d_all = (float )(x[i].d );
14371487
1438- device const uint8_t * q = x[i].qs + il;
1439- device const uint8_t * h = x[i].hmask + in;
1440- device const float * y = yy + i * QK_K + il;
1441-
1442- const float d1 = d_all * ((x[i].scales [0 ] & 0xF ) - 8 );
1443- const float d2 = d_all * ((x[i].scales [0 ] >> 4 ) - 8 );
1444- const float d3 = d_all * ((x[i].scales [1 ] & 0xF ) - 8 );
1445- const float d4 = d_all * ((x[i].scales [1 ] >> 4 ) - 8 );
1446-
1447- for (int l = 0 ; l < 4 ; ++l) {
1448- const uint8_t hm = h[l] >> im;
1449- sumf += y[l+ 0 ] * d1 * ((int8_t )((q[l+0 ] >> 0 ) & 3 ) - ((hm & 0x01 ) ? 0 : 4 ))
1450- + y[l+16 ] * d2 * ((int8_t )((q[l+0 ] >> 2 ) & 3 ) - ((hm & 0x04 ) ? 0 : 4 ))
1451- + y[l+32 ] * d3 * ((int8_t )((q[l+0 ] >> 4 ) & 3 ) - ((hm & 0x10 ) ? 0 : 4 ))
1452- + y[l+48 ] * d4 * ((int8_t )((q[l+0 ] >> 6 ) & 3 ) - ((hm & 0x40 ) ? 0 : 4 ));
1488+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490+ device const uint16_t * s = (device const uint16_t *)(x[i].scales );
1491+ device const float * y = yy + i * QK_K + il;
1492+
1493+ const float d1 = d_all * ((int32_t )(s[0 ] & 0x000F ) - 8 );
1494+ const float d2 = d_all * ((int32_t )(s[0 ] & 0x00F0 ) - 128 ) * 1 .f /64 .f ;
1495+ const float d3 = d_all * ((int32_t )(s[0 ] & 0x0F00 ) - 2048 ) * 1 .f /4096 .f ;
1496+ const float d4 = d_all * ((int32_t )(s[0 ] & 0xF000 ) - 32768 ) * 1 .f /262144 .f ;
1497+
1498+ for (int l = 0 ; l < 4 ; l += 2 ) {
1499+ const uint16_t hm = h[l/2 ] >> im;
1500+ sum[0 ] += y[l+ 0 ] * d1 * ((int32_t )(q[l/2 ] & 0x0003 ) - ((hm & 0x0001 ) ? 0 : 4 ))
1501+ + y[l+16 ] * d2 * ((int32_t )(q[l/2 ] & 0x000c ) - ((hm & 0x0004 ) ? 0 : 16 ))
1502+ + y[l+32 ] * d3 * ((int32_t )(q[l/2 ] & 0x0030 ) - ((hm & 0x0010 ) ? 0 : 64 ))
1503+ + y[l+48 ] * d4 * ((int32_t )(q[l/2 ] & 0x00c0 ) - ((hm & 0x0040 ) ? 0 : 256 ));
1504+ sum[1 ] += y[l+ 1 ] * d1 * ((int32_t )(q[l/2 ] & 0x0300 ) - ((hm & 0x0100 ) ? 0 : 1024 ))
1505+ + y[l+17 ] * d2 * ((int32_t )(q[l/2 ] & 0x0c00 ) - ((hm & 0x0400 ) ? 0 : 4096 ))
1506+ + y[l+33 ] * d3 * ((int32_t )(q[l/2 ] & 0x3000 ) - ((hm & 0x1000 ) ? 0 : 16384 ))
1507+ + y[l+49 ] * d4 * ((int32_t )(q[l/2 ] & 0xc000 ) - ((hm & 0x4000 ) ? 0 : 65536 ));
14531508 }
14541509
14551510 }
1511+ const float sumf = sum[0 ] + sum[1 ] * 1 .f /256 .f ;
14561512
1457- sum[ith] = sumf;
1458-
1459- #endif
1460-
1461- //
1462- // Accumulate the sum from all threads in the threadgroup
1463- //
1464- threadgroup_barrier (mem_flags::mem_threadgroup);
1465- if (ith%4 == 0 ) {
1466- for (int i = 1 ; i < 4 ; ++i) sum[ith] += sum[ith + i];
1467- }
1468- threadgroup_barrier (mem_flags::mem_threadgroup);
1469- if (ith%16 == 0 ) {
1470- for (int i = 4 ; i < 16 ; i += 4 ) sum[ith] += sum[ith + i];
1471- }
1472- threadgroup_barrier (mem_flags::mem_threadgroup);
1473- if (ith == 0 ) {
1474- for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
1475- dst[r1*ne0 + r0] = sum[0 ];
1513+ const float tot = simd_sum (sumf);
1514+ if (tiisg == 0 ) {
1515+ dst[r1*ne0 + row] = tot;
14761516 }
14771517
14781518}
1519+ #endif
14791520
14801521#if QK_K == 256
14811522kernel void kernel_mul_mat_q4_K_f32 (
@@ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32(
17731814
17741815 for (int i = ix; i < nb; i += 8 ) {
17751816
1776- float4 sumy = {0 .f , 0 .f , 0 .f , 0 .f };
17771817 for (int l = 0 ; l < 4 ; ++l) {
17781818 yl[l+0 ] = y[l+ 0 ];
17791819 yl[l+4 ] = y[l+16 ];
0 commit comments