88using namespace metal ;
99
/**
 * Unpacks eight 3-bit unsigned values from three packed bytes.
 *
 * The three bytes are read as one contiguous 24-bit field, least-significant
 * bit first: value i occupies bits [3*i, 3*i + 3) of
 * (b[0] | b[1] << 8 | b[2] << 16). Values 2 and 5 therefore straddle a byte
 * boundary and are assembled from two bytes.
 *
 * @param[b] Three packed bytes holding eight 3-bit values.
 * @param[w] Output array; elements w[0]..w[7] are written, each in [0, 7].
 */
inline void unpack_3bit(const uchar3 b, thread float* w) {
  // Byte 0: values 0 and 1 in full, plus the low two bits of value 2.
  w[0] = float(b[0] & 0x07);
  w[1] = float((b[0] & 0x38) >> 3);
  // Value 2 straddles bytes 0 and 1: bits 6-7 of b[0] and bit 0 of b[1].
  w[2] = float(((b[0] & 0xc0) >> 6) | ((b[1] & 0x01) << 2));
  // Byte 1: values 3 and 4 in full, plus the low bit of value 5.
  w[3] = float((b[1] & 0x0e) >> 1);
  w[4] = float((b[1] & 0x70) >> 4);
  // Value 5 straddles bytes 1 and 2: bit 7 of b[1] and bits 0-1 of b[2].
  w[5] = float(((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1));
  // Byte 2: values 6 and 7 in full.
  w[6] = float((b[2] & 0x1c) >> 2);
  w[7] = float((b[2] & 0xe0) >> 5);
}
2120
2221/* *
2322 * 3-Bit Quantized Linear.
2423 *
2524 * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16)
2625 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
27- * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
28- * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
26+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is N x #groups
27+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is N x #groups
2928 * @param[outputData] M x N output tensor of floating point dtype (same as input)
3029 * @param[sizes] The sizes involved in the order: M, K, N
3130 *
@@ -45,6 +44,7 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]],
4544 constexpr uint k_pack_factor = 8 ;
4645 const uint K = sizes.y ;
4746 const uint N = sizes.z ;
47+ const uint num_groups = (K + group_size - 1 ) / group_size;
4848 uint n = thread_index.x ; // 0..N/4-1
4949 uint m = thread_index.z ; // 0..M
5050 n = n / threads_per_channel;
@@ -64,12 +64,18 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]],
6464 // Find specific group to which channels handled by this thread
6565 // belong.
6666 uint k_block_index = k / group_size;
67- uint scales_group_offset = (k_block_index * N + n );
67+ uint scales_group_offset = (n * num_groups + k_block_index );
6868
6969 vecT scales =
70- (reinterpret_cast <constant vecT *>(scales_ptr + scales_group_offset))[0 ];
70+ vecT (scales_ptr[scales_group_offset],
71+ scales_ptr[scales_group_offset + num_groups],
72+ scales_ptr[scales_group_offset + 2 * num_groups],
73+ scales_ptr[scales_group_offset + 3 * num_groups]);
7174 vecT zeros =
72- (reinterpret_cast <constant vecT *>(zeros_ptr + scales_group_offset))[0 ];
75+ vecT (zeros_ptr[scales_group_offset],
76+ zeros_ptr[scales_group_offset + num_groups],
77+ zeros_ptr[scales_group_offset + 2 * num_groups],
78+ zeros_ptr[scales_group_offset + 3 * num_groups]);
7379 float4 zeros_float = float4 (zeros);
7480
7581 float4 a_val[2 ];
0 commit comments