From c200b29eeff4d8143a67cc8bce18e335586a4c53 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 2 May 2025 22:34:40 -0400 Subject: [PATCH] metal lowbit kernels: qmv_fast optimization --- torchao/experimental/kernels/mps/metal.yaml | 3 + .../kernels/mps/metal/int1mm.metal | 9 +- .../kernels/mps/metal/int2mm_opt.metal | 19 +- .../kernels/mps/metal/int3mm_opt.metal | 34 +- .../kernels/mps/metal/int4mm_opt.metal | 18 +- .../kernels/mps/metal/int5mm.metal | 26 +- .../kernels/mps/metal/int6mm.metal | 25 +- .../kernels/mps/metal/int7mm.metal | 26 +- .../kernels/mps/metal/qmv_fast.metal | 364 ++++++++++++++++++ .../experimental/kernels/mps/src/dispatch.h | 14 + torchao/experimental/kernels/mps/src/lowbit.h | 28 +- .../experimental/kernels/mps/src/packing.h | 141 +++---- .../kernels/mps/test/test_lowbit.mm | 6 +- .../ops/mps/linear_fp_act_xbit_weight_aten.mm | 12 +- .../linear_fp_act_xbit_weight_executorch.mm | 8 +- torchao/experimental/ops/mps/mps_op_lib.py | 4 +- .../experimental/ops/mps/test/test_lowbit.py | 8 +- .../ops/mps/test/test_quantizer.py | 8 +- torchao/experimental/quant_api.py | 4 +- 19 files changed, 569 insertions(+), 188 deletions(-) create mode 100644 torchao/experimental/kernels/mps/metal/qmv_fast.metal diff --git a/torchao/experimental/kernels/mps/metal.yaml b/torchao/experimental/kernels/mps/metal.yaml index eb837432c7..dfad7ad715 100644 --- a/torchao/experimental/kernels/mps/metal.yaml +++ b/torchao/experimental/kernels/mps/metal.yaml @@ -21,3 +21,6 @@ - func: int7mm file: int7mm.metal + +- func: qmv_fast + file: qmv_fast.metal diff --git a/torchao/experimental/kernels/mps/metal/int1mm.metal b/torchao/experimental/kernels/mps/metal/int1mm.metal index a76d66041b..51e8558e9c 100644 --- a/torchao/experimental/kernels/mps/metal/int1mm.metal +++ b/torchao/experimental/kernels/mps/metal/int1mm.metal @@ -11,8 +11,8 @@ using namespace metal; * * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8) - * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N - * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N + * @param[scales] 2D tensor containg the scales for each group. Expected shape is N x #groups + * @param[zeros] 2D tensor containg the zero points for each group. 
Expected shape is N x #groups * @param[outputData] M x N output tensor of floating point dtype (same as input) * @param[sizes] The sizes involved in the order: M, K, N * @@ -29,6 +29,7 @@ kernel void int1pack_mm( uint2 thread_index [[thread_position_in_grid]]) { const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + groupSize - 1) / groupSize; const uint m = thread_index.y; // 0..M-1 const uint n = thread_index.x; // 0..N-1 const uint32_t k_block = (K + groupSize - 1) / groupSize; @@ -38,8 +39,8 @@ kernel void int1pack_mm( float rc = 0.0; uint k = 0; for (uint32_t kb = 0; kb < k_block ; kb ++) { - const float scale = float(scales[kb * N + n]); - const float zero = float(zeros[kb * N + n]); + const float scale = float(scales[n * num_groups + kb]); + const float zero = float(zeros[n * num_groups + kb]); for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { const auto a_val0 = float(A_ptr[k + 0]); const auto a_val1 = float(A_ptr[k + 1]); diff --git a/torchao/experimental/kernels/mps/metal/int2mm_opt.metal b/torchao/experimental/kernels/mps/metal/int2mm_opt.metal index 6008de6730..f42a6e44e9 100644 --- a/torchao/experimental/kernels/mps/metal/int2mm_opt.metal +++ b/torchao/experimental/kernels/mps/metal/int2mm_opt.metal @@ -26,12 +26,11 @@ using namespace metal; @param [in] B is weight matrix of size M x K. Each byte contains 4 2-bit values, along K dim, packed together. @param [in] scales_ptr is scales ptr corresponding each - output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output + output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output channels. @param [in] zeros_ptr is zero points corresponding each - output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output + output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output channels. - output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output @param [out] output_data is output matrix of size M x N. @param [in] sizes array contains values of M, K and N. @param [in] thread_index is global thread id. @@ -51,6 +50,7 @@ kernel void int2pack_mm(constant T *A [[buffer(0)]], constexpr uint k_pack_factor = 4; const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + group_size - 1) / group_size; uint n = thread_index.x; // 0..N/4-1 uint m = thread_index.z; // 0..M n = n / threads_per_channel; @@ -75,13 +75,18 @@ kernel void int2pack_mm(constant T *A [[buffer(0)]], // Find specific group to which channels handled by this thread // belong. uint k_block_index = k / group_size; - uint scales_group_offset = (k_block_index * N + n); + uint scales_group_offset = (n * num_groups + k_block_index); vecT scales = - (reinterpret_cast(scales_ptr + scales_group_offset))[0]; - // Adding zero point results in 10% perf penalty. 
+ vecT(scales_ptr[scales_group_offset], + scales_ptr[scales_group_offset + num_groups], + scales_ptr[scales_group_offset + 2 * num_groups], + scales_ptr[scales_group_offset + 3 * num_groups]); vecT zeros = - (reinterpret_cast(zeros_ptr + scales_group_offset))[0]; + vecT(zeros_ptr[scales_group_offset], + zeros_ptr[scales_group_offset + num_groups], + zeros_ptr[scales_group_offset + 2 * num_groups], + zeros_ptr[scales_group_offset + 3 * num_groups]); float4 zeros_float = float4(zeros); float4 a_val = float4(A_ptr[k / 4]); diff --git a/torchao/experimental/kernels/mps/metal/int3mm_opt.metal b/torchao/experimental/kernels/mps/metal/int3mm_opt.metal index 8ab9862d03..69bd142cea 100644 --- a/torchao/experimental/kernels/mps/metal/int3mm_opt.metal +++ b/torchao/experimental/kernels/mps/metal/int3mm_opt.metal @@ -8,15 +8,14 @@ using namespace metal; inline void unpack_3bit(const uchar3 b, thread float* w) { - w[0] = float(((b[0] & 1) << 2) | (b[1] & 3)); - w[1] = float(((b[0] & 2) << 1) | ((b[1] & 12) >> 2)); - w[2] = float((b[0] & 4) | ((b[1] & 48) >> 4)); - w[3] = float(((b[0] & 8) >> 1) | ((b[1] & 192) >> 6)); - - w[4] = float(((b[0] & 16) >> 2) | (b[2] & 3)); - w[5] = float(((b[0] & 32) >> 3) | ((b[2] & 12) >> 2)); - w[6] = float(((b[0] & 64) >> 4) | ((b[2] & 48) >> 4)); - w[7] = float(((b[0] & 128) >> 5) | ((b[2] & 192) >> 6)); + w[0] = float(b[0] & 0x07); + w[1] = float((b[0] & 0x38) >> 3); + w[2] = float(((b[0] & 0xc0) >> 6) | ((b[1] & 0x01) << 2)); + w[3] = float((b[1] & 0x0e) >> 1); + w[4] = float((b[1] & 0x70) >> 4); + w[5] = float(((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1)); + w[6] = float((b[2] & 0x1c) >> 2); + w[7] = float((b[2] & 0xe0) >> 5); } /** @@ -24,8 +23,8 @@ inline void unpack_3bit(const uchar3 b, thread float* w) { * * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8) - * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N - * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N + * @param[scales] 2D tensor containg the scales for each group. Expected shape is N x #groups + * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is N x #groups * @param[outputData] M x N output tensor of floating point dtype (same as input) * @param[sizes] The sizes involved in the order: M, K, N * @@ -45,6 +44,7 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]], constexpr uint k_pack_factor = 8; const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + group_size - 1) / group_size; uint n = thread_index.x; // 0..N/4-1 uint m = thread_index.z; // 0..M n = n / threads_per_channel; @@ -64,12 +64,18 @@ kernel void int3pack_mm(constant T *A [[buffer(0)]], // Find specific group to which channels handled by this thread // belong. 
uint k_block_index = k / group_size; - uint scales_group_offset = (k_block_index * N + n); + uint scales_group_offset = (n * num_groups + k_block_index); vecT scales = - (reinterpret_cast(scales_ptr + scales_group_offset))[0]; + vecT(scales_ptr[scales_group_offset], + scales_ptr[scales_group_offset + num_groups], + scales_ptr[scales_group_offset + 2 * num_groups], + scales_ptr[scales_group_offset + 3 * num_groups]); vecT zeros = - (reinterpret_cast(zeros_ptr + scales_group_offset))[0]; + vecT(zeros_ptr[scales_group_offset], + zeros_ptr[scales_group_offset + num_groups], + zeros_ptr[scales_group_offset + 2 * num_groups], + zeros_ptr[scales_group_offset + 3 * num_groups]); float4 zeros_float = float4(zeros); float4 a_val[2]; diff --git a/torchao/experimental/kernels/mps/metal/int4mm_opt.metal b/torchao/experimental/kernels/mps/metal/int4mm_opt.metal index edee43ec14..f6d0b4935b 100644 --- a/torchao/experimental/kernels/mps/metal/int4mm_opt.metal +++ b/torchao/experimental/kernels/mps/metal/int4mm_opt.metal @@ -64,12 +64,11 @@ using namespace metal; @param [in] B is weight matrix of size M x K. Each byte contains 2 4-bit values, along K dim, packed together. @param [in] scales_ptr is scales ptr corresponding each - output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output + output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output channels. @param [in] zeros_ptr is zero points corresponding each - output channel x groups. These are packed as [num_groups = ceil(K / group_size), N]. N = output + output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output channels. - output channel x groups. These are packed as [num_groups = ceil(K / group_size), N, 2]. N = output @param [out] output_data is output matrix of size M x N. @param [in] sizes array contains values of M, K and N. @param [in] thread_index is global thread id. @@ -89,6 +88,7 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]], constexpr uint k_pack_factor = 2; const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + group_size - 1) / group_size; uint n = thread_index.x; // 0..N/4-1 uint m = thread_index.z; // 0..M n = n / threads_per_channel; @@ -113,13 +113,19 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]], // Find specific group to which channels handled by this thread // belong. uint k_block_index = k / group_size; - uint scales_group_offset = (k_block_index * N + n); + uint scales_group_offset = (n * num_groups + k_block_index); vecT scales = - (reinterpret_cast(scales_ptr + scales_group_offset))[0]; + vecT(scales_ptr[scales_group_offset], + scales_ptr[scales_group_offset + num_groups], + scales_ptr[scales_group_offset + 2 * num_groups], + scales_ptr[scales_group_offset + 3 * num_groups]); // Adding zero point results in 10% perf penalty. 
vecT zeros = - (reinterpret_cast(zeros_ptr + scales_group_offset))[0]; + vecT(zeros_ptr[scales_group_offset], + zeros_ptr[scales_group_offset + num_groups], + zeros_ptr[scales_group_offset + 2 * num_groups], + zeros_ptr[scales_group_offset + 3 * num_groups]); float4 zeros_float = float4(zeros); float4 a_val = float4(A_ptr[k / 4]); diff --git a/torchao/experimental/kernels/mps/metal/int5mm.metal b/torchao/experimental/kernels/mps/metal/int5mm.metal index 206786b038..c8be33911a 100644 --- a/torchao/experimental/kernels/mps/metal/int5mm.metal +++ b/torchao/experimental/kernels/mps/metal/int5mm.metal @@ -11,8 +11,8 @@ using namespace metal; * * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8) - * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N - * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N + * @param[scales] 2D tensor containg the scales for each group. Expected shape is N x #groups + * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is N x #groups * @param[outputData] M x N output tensor of floating point dtype (same as input) * @param[sizes] The sizes involved in the order: M, K, N * @@ -29,6 +29,7 @@ kernel void int5pack_mm( uint2 thread_index [[thread_position_in_grid]]) { const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + groupSize - 1) / groupSize; const uint m = thread_index.y; // 0..M-1 const uint n = thread_index.x; // 0..N-1 const uint32_t k_block = (K + groupSize - 1) / groupSize; @@ -38,8 +39,8 @@ kernel void int5pack_mm( float rc = 0.0; uint k = 0; for (uint32_t kb = 0; kb < k_block ; kb ++) { - const float scale = float(scales[kb * N + n]); - const float zero = float(zeros[kb * N + n]); + const float scale = float(scales[n * num_groups + kb]); + const float zero = float(zeros[n * num_groups + kb]); for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { const auto a_val0 = float(A_ptr[k + 0]); const auto a_val1 = float(A_ptr[k + 1]); @@ -56,15 +57,14 @@ kernel void int5pack_mm( uchar b3 = B_ptr[5 * (k / 8) + 3]; uchar b4 = B_ptr[5 * (k / 8) + 4]; - uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15); - uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4); - uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15); - uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4); - - uchar w_val4 = ((b0 & 16)) | (b3 & 15); - uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4); - uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15); - uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4); + uchar w_val0 = (b0 & 0x1f); + uchar w_val1 = ((b0 & 0xe0) >> 5) | ((b1 & 0x03) << 3); + uchar w_val2 = ((b1 & 0x7c) >> 2); + uchar w_val3 = ((b1 & 0x80) >> 7) | ((b2 & 0x0f) << 1); + uchar w_val4 = ((b2 & 0xf0) >> 4) | ((b3 & 0x01) << 4); + uchar w_val5 = ((b3 & 0x3e) >> 1); + uchar w_val6 = ((b3 & 0xc0) >> 6) | ((b4 & 0x07) << 2); + uchar w_val7 = ((b4 & 0xf8) >> 3); rc += a_val0 * (scale * float(w_val0) + zero); rc += a_val1 * (scale * float(w_val1) + zero); diff --git a/torchao/experimental/kernels/mps/metal/int6mm.metal b/torchao/experimental/kernels/mps/metal/int6mm.metal index 55d359a6ba..45f03d9cef 100644 --- a/torchao/experimental/kernels/mps/metal/int6mm.metal +++ b/torchao/experimental/kernels/mps/metal/int6mm.metal @@ -11,8 +11,8 @@ using namespace metal; * * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed 
& quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8) - * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N - * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N + * @param[scales] 2D tensor containg the scales for each group. Expected shape is N x #groups + * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is N x #groups * @param[outputData] M x N output tensor of floating point dtype (same as input) * @param[sizes] The sizes involved in the order: M, K, N * @@ -29,6 +29,7 @@ kernel void int6pack_mm( uint2 thread_index [[thread_position_in_grid]]) { const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + groupSize - 1) / groupSize; const uint m = thread_index.y; // 0..M-1 const uint n = thread_index.x; // 0..N-1 const uint32_t k_block = (K + groupSize - 1) / groupSize; @@ -38,8 +39,8 @@ kernel void int6pack_mm( float rc = 0.0; uint k = 0; for (uint32_t kb = 0; kb < k_block ; kb ++) { - const float scale = float(scales[kb * N + n]); - const float zero = float(zeros[kb * N + n]); + const float scale = float(scales[n * num_groups + kb]); + const float zero = float(zeros[n * num_groups + kb]); for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { const auto a_val0 = float(A_ptr[k + 0]); const auto a_val1 = float(A_ptr[k + 1]); @@ -59,15 +60,15 @@ kernel void int6pack_mm( uchar b4 = B_ptr[3 * (k / 4) + 4]; uchar b5 = B_ptr[3 * (k / 4) + 5]; - uchar w_val0 = ((b0 & 3) << 4) | (b1 & 15); - uchar w_val1 = ((b0 & 12) << 2) | ((b1 & 240) >> 4); - uchar w_val2 = ((b0 & 48)) | (b2 & 15); - uchar w_val3 = ((b0 & 192) >> 2) | ((b2 & 240) >> 4); + uchar w_val0 = (b0 & 0x3f); + uchar w_val1 = ((b0 & 0xc0) >> 6) | ((b1 & 0x0f) << 2); + uchar w_val2 = ((b1 & 0xf0) >> 4) | ((b2 & 0x03) << 4); + uchar w_val3 = (b2 & 0xfc) >> 2; - uchar w_val4 = ((b3 & 3) << 4) | (b4 & 15); - uchar w_val5 = ((b3 & 12) << 2) | ((b4 & 240) >> 4); - uchar w_val6 = ((b3 & 48)) | (b5 & 15); - uchar w_val7 = ((b3 & 192) >> 2) | ((b5 & 240) >> 4); + uchar w_val4 = (b3 & 0x3f); + uchar w_val5 = ((b3 & 0xc0) >> 6) | ((b4 & 0x0f) << 2); + uchar w_val6 = ((b4 & 0xf0) >> 4) | ((b5 & 0x03) << 4); + uchar w_val7 = (b5 & 0xfc) >> 2; rc += a_val0 * (scale * float(w_val0) + zero); rc += a_val1 * (scale * float(w_val1) + zero); diff --git a/torchao/experimental/kernels/mps/metal/int7mm.metal b/torchao/experimental/kernels/mps/metal/int7mm.metal index b97800b448..ce4e5a51d0 100644 --- a/torchao/experimental/kernels/mps/metal/int7mm.metal +++ b/torchao/experimental/kernels/mps/metal/int7mm.metal @@ -11,8 +11,8 @@ using namespace metal; * * @param[A] M x K input tensor of floating point dtype (Float, Half, BFloat16) * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8) - * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N - * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N + * @param[scales] 2D tensor containg the scales for each group. Expected shape is N x #groups + * @param[zeros] 2D tensor containg the zero points for each group. 
Expected shape is N x #groups * @param[outputData] M x N output tensor of floating point dtype (same as input) * @param[sizes] The sizes involved in the order: M, K, N * @@ -29,6 +29,7 @@ kernel void int7pack_mm( uint2 thread_index [[thread_position_in_grid]]) { const uint K = sizes.y; const uint N = sizes.z; + const uint num_groups = (K + groupSize - 1) / groupSize; const uint m = thread_index.y; // 0..M-1 const uint n = thread_index.x; // 0..N-1 const uint32_t k_block = (K + groupSize - 1) / groupSize; @@ -38,8 +39,8 @@ kernel void int7pack_mm( float rc = 0.0; uint k = 0; for (uint32_t kb = 0; kb < k_block ; kb ++) { - const float scale = float(scales[kb * N + n]); - const float zero = float(zeros[kb * N + n]); + const float scale = float(scales[n * num_groups + kb]); + const float zero = float(zeros[n * num_groups + kb]); for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) { const auto a_val0 = float(A_ptr[k + 0]); const auto a_val1 = float(A_ptr[k + 1]); @@ -58,15 +59,14 @@ kernel void int7pack_mm( uchar b5 = B_ptr[7 * (k / 8) + 5]; uchar b6 = B_ptr[7 * (k / 8) + 6]; - uchar w_val0 = b0 & 127; - uchar w_val1 = b1 & 127; - uchar w_val2 = b2 & 127; - uchar w_val3 = b3 & 127; - uchar w_val4 = b4 & 127; - uchar w_val5 = b5 & 127; - uchar w_val6 = b6 & 127; - uchar w_val7 = ((b0 & 128) >> 7) | ((b1 & 128) >> 6) | ((b2 & 128) >> 5) | ((b3 & 128) >> 4) - | ((b4 & 128) >> 3) | ((b5 & 128) >> 2) | ((b6 & 128) >> 1); + uchar w_val0 = (b0 & 0x7f); + uchar w_val1 = (b0 >> 7) | ((b1 & 0x3f) << 1); + uchar w_val2 = (b1 >> 6) | ((b2 & 0x1f) << 2); + uchar w_val3 = (b2 >> 5) | ((b3 & 0x0f) << 3); + uchar w_val4 = (b3 >> 4) | ((b4 & 0x07) << 4); + uchar w_val5 = (b4 >> 3) | ((b5 & 0x03) << 5); + uchar w_val6 = (b5 >> 2) | ((b6 & 0x01) << 6); + uchar w_val7 = (b6 >> 1); rc += a_val0 * (scale * float(w_val0) + zero); rc += a_val1 * (scale * float(w_val1) + zero); diff --git a/torchao/experimental/kernels/mps/metal/qmv_fast.metal b/torchao/experimental/kernels/mps/metal/qmv_fast.metal new file mode 100644 index 0000000000..190b122d15 --- /dev/null +++ b/torchao/experimental/kernels/mps/metal/qmv_fast.metal @@ -0,0 +1,364 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +/* + This code was taken from MLX, and modified to add support for 1, 5 & 7 bit packing. + The original code is Copyright © 2023-2024 Apple Inc. 
+ https://github.com/ml-explore/mlx/blob/481349495b8c3d094eb699e678077bbe1406392d/mlx/backend/metal/kernels/quantized.h#L1 + MLX MIT License: https://github.com/ml-explore/mlx/blob/main/LICENSE +*/ + +#include +#include + +static constant constexpr const int SIMD_SIZE = 32; + +template +inline U load_vector(constant T* x, thread U* x_thread) { + static_assert( + 1 <= bits && bits <= 7, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}"); + + U sum = 0; + + if (bits == 1) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 7) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 128.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 32.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 8.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 2.0f; + } + } + + return sum; +} + +template +inline U qdot( + constant uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + 1 <= bits && bits <= 7, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}"); + + U accum = 0; + + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + + accum += + (x_thread[0] * (w[i] & 0x01) + + x_thread[1] * (w[i] & 0x02) 
+ + x_thread[2] * (w[i] & 0x04) + + x_thread[3] * (w[i] & 0x08) + + x_thread[4] * (w[i] & 0x10) + + x_thread[5] * (w[i] & 0x20) + + x_thread[6] * (w[i] & 0x40) + + x_thread[7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + constant uint16_t* ws = (constant uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + + accum += (w[1] & 0x03) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + + accum += (w[2] & 0x0f) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + + accum += (w[3] & 0x01) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + + accum += (w[4] & 0x07) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 7) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 7 * i; + + accum += (w[0] & 0x7f) * x_thread[0]; + accum += (w[0] & 0x80) * x_thread[1]; + + accum += (w[1] & 0x3f) * (x_thread[1] * 256.0f); + accum += (w[1] & 0xc0) * x_thread[2]; + + accum += (w[2] & 0x1f) * (x_thread[2] * 256.0f); + accum += (w[2] & 0xe0) * x_thread[3]; + + accum += (w[3] & 0x0f) * (x_thread[3] * 256.0f); + accum += (w[3] & 0xf0) * x_thread[4]; + + accum += (w[4] & 0x07) * (x_thread[4] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[5]; + + accum += (w[5] & 0x03) * (x_thread[5] * 256.0f); + accum += (w[5] & 0xfc) * x_thread[6]; + + accum += (w[6] & 0x01) * (x_thread[6] * 256.0f); + accum += (w[6] & 0xfe) * x_thread[7]; + } + } + + return scale * accum + sum * bias; +} + +template +[[kernel]] void qmv_fast( + constant T* x [[buffer(0)]], + constant uchar* w [[buffer(1)]], + constant T* scales [[buffer(2)]], + constant T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + constant uint3 &sizes [[buffer(5)]], // M, K, N + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid 
[[thread_index_in_simdgroup]]) { + const int in_vec_size = static_cast(sizes.y); // K + const int out_vec_size = static_cast(sizes.z); // N + + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = bits == 1 ? 16 : power_of_2_bits ? 32 / bits : bits == 6 ? 4 : 8; + constexpr int bytes_per_pack = bits == 1 ? 2 : power_of_2_bits ? 4 : bits == 6 ? 3 : bits; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + constant uint8_t* ws = (constant uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (constant uint8_t*)(ws + row * in_vec_size_w); + constant T* sl = scales + row * in_vec_size_g; + constant T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +#define INSTANTIATE_QMV_FAST(DTYPE, GSIZE, NBIT) \ + template [[host_name("qmv_fast_" #NBIT "bit_" #GSIZE "_" #DTYPE)]] kernel void \ + qmv_fast( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_ptr [[buffer(2)]], \ + constant DTYPE * zeros_ptr [[buffer(3)]], \ + device DTYPE * output_data [[buffer(4)]], \ + constant uint3 & sizes [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + +#define INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, GSIZE) \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 1); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 2); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 3); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 4); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 5); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 6); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 7); + +#define INSTANTIATE_QMV_FAST_DTYPE(DTYPE) \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 32); \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 64); \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 128); \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 256); + +INSTANTIATE_QMV_FAST_DTYPE(float); +INSTANTIATE_QMV_FAST_DTYPE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_QMV_FAST_DTYPE(bfloat); +#endif diff --git 
a/torchao/experimental/kernels/mps/src/dispatch.h b/torchao/experimental/kernels/mps/src/dispatch.h index 39acd8d1f0..a04452cece 100644 --- a/torchao/experimental/kernels/mps/src/dispatch.h +++ b/torchao/experimental/kernels/mps/src/dispatch.h @@ -34,4 +34,18 @@ inline void dispatch_mm_Mr1xNr4_per_TG( threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } +inline void dispatch_qmv_fast( + id encoder, + int32_t maxThreadsPerGroup, + int32_t M, + int32_t N, + int32_t K) { + (void)K; + if (maxThreadsPerGroup < 64) { + throw std::runtime_error("Can't dispatch!"); + } + [encoder dispatchThreadgroups:MTLSizeMake(M, (N + 7) / 8, 1) + threadsPerThreadgroup:MTLSizeMake(32, 2, 1)]; +} + } // namespace torchao::kernels::mps::lowbit::dispatch diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index 9b2d539761..370c6d400c 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -111,6 +111,25 @@ inline void linear_lowbit_quant_weights_mps_impl( }); } +template +std::tuple get_shader_func_and_dispatch( + int64_t qGroupSize, + const std::string_view type_str, + int32_t M, + int32_t N, + int32_t K) { + if (M == 1 && N % 8 == 0 && K % 512 == 0) { + return std::make_tuple( + std::string("qmv_fast_") + std::to_string(nbit) + "bit_" + + std::to_string(qGroupSize) + "_" + std::string(type_str), + dispatch::dispatch_qmv_fast); + } + return std::make_tuple( + std::string(LowBitConfig::func_prefix) + std::to_string(qGroupSize) + + "_" + std::string(type_str), + LowBitConfig::dispatch_fn); +} + // LowBit Quantized Weights Linear on Metal template void linear_lowbit_quant_weights_mps( @@ -129,8 +148,11 @@ void linear_lowbit_quant_weights_mps( assert( qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || qGroupSize == 256); - const std::string shader_func = std::string(LowBitConfig::func_prefix) + - std::to_string(qGroupSize) + "_" + std::string(type_str); + std::tuple shader_func_and_dispatch = + get_shader_func_and_dispatch(qGroupSize, type_str, M, N, K); + const std::string shader_func = std::get<0>(shader_func_and_dispatch); + const DispatchFn dispatch_fn = std::get<1>(shader_func_and_dispatch); + return linear_lowbit_quant_weights_mps_impl( a_buf, b_buf, @@ -141,7 +163,7 @@ void linear_lowbit_quant_weights_mps( K, N, shader_func, - LowBitConfig::dispatch_fn); + dispatch_fn); } } // namespace diff --git a/torchao/experimental/kernels/mps/src/packing.h b/torchao/experimental/kernels/mps/src/packing.h index 09a248da5e..5412c04a12 100644 --- a/torchao/experimental/kernels/mps/src/packing.h +++ b/torchao/experimental/kernels/mps/src/packing.h @@ -70,9 +70,7 @@ pack<2>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { /** * 3-bit packing. Each weight is 3 bits. We can't pack them into a byte, so we - * pack 8 weights into 3 bytes. But we can't nicely pack the 8 weights - * continuously. Instead, we pack the upper bits of all weights into the first - * byte, then the 2 lower bits of all weights into the other 2 bytes. + * pack 8 weights into 3 bytes. 
*/ template <> inline void @@ -80,28 +78,18 @@ pack<3>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { for (int32_t n = 0; n < N; n++) { int32_t row_base = (n * (K / 8)) * 3; for (int32_t k8 = 0; k8 < K / 8; k8++) { - uint8_t src_0ab = w_ptr[n * K + k8 * 8 + 0]; - uint8_t src_1cd = w_ptr[n * K + k8 * 8 + 1]; - uint8_t src_2ef = w_ptr[n * K + k8 * 8 + 2]; - uint8_t src_3gh = w_ptr[n * K + k8 * 8 + 3]; - uint8_t src_4ij = w_ptr[n * K + k8 * 8 + 4]; - uint8_t src_5kl = w_ptr[n * K + k8 * 8 + 5]; - uint8_t src_6mn = w_ptr[n * K + k8 * 8 + 6]; - uint8_t src_7op = w_ptr[n * K + k8 * 8 + 7]; - - // b0: 7|6|5|4|3|2|1|0 (upper bits for all values) - b_ptr[row_base + 3 * k8 + 0] = ((src_0ab & 4) >> 2) | - ((src_1cd & 4) >> 1) | ((src_2ef & 4)) | ((src_3gh & 4) << 1) | - ((src_4ij & 4) << 2) | ((src_5kl & 4) << 3) | ((src_6mn & 4) << 4) | - ((src_7op & 4) << 5); - - // b1: gh|ef|cd|ab (lower 2 bits for first 4 values) - b_ptr[row_base + 3 * k8 + 1] = (src_0ab & 3) | ((src_1cd & 3) << 2) | - ((src_2ef & 3) << 4) | ((src_3gh & 3) << 6); + uint8_t src_val0 = w_ptr[n * K + k8 * 8]; + uint8_t src_val1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_val2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_val3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_val4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_val5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_val6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_val7 = w_ptr[n * K + k8 * 8 + 7]; - // b2: op|mn|kl|ij (lower 2 bits for last 4 values) - b_ptr[row_base + 3 * k8 + 2] = (src_4ij & 3) | ((src_5kl & 3) << 2) | - ((src_6mn & 3) << 4) | ((src_7op & 3) << 6); + b_ptr[row_base + 3 * k8 + 0] = src_val0 | (src_val1 << 3) | (src_val2 << 6); + b_ptr[row_base + 3 * k8 + 1] = (src_val2 >> 2) | (src_val3 << 1) | (src_val4 << 4) | (src_val5 << 7); + b_ptr[row_base + 3 * k8 + 2] = (src_val5 >> 1) | (src_val6 << 2) | (src_val7 << 5); } } } @@ -123,9 +111,7 @@ pack<4>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { } /** - * 5-bit packing. Each weight is 5 bits. So we pack 8 weights into 5 bytes. We - * pack the upper bits of all weights into the first byte, then the 4 lower - * bits of all weights into the other 4 bytes. + * 5-bit packing. Each weight is 5 bits. We pack 8 weights into 5 bytes. 
*/ template <> inline void @@ -133,41 +119,26 @@ pack<5>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { for (int32_t n = 0; n < N; n++) { int32_t row_base = (n * (K / 8)) * 5; for (int32_t k8 = 0; k8 < K / 8; k8++) { - uint8_t src_0abAB = w_ptr[n * K + k8 * 8 + 0]; - uint8_t src_1cdCD = w_ptr[n * K + k8 * 8 + 1]; - uint8_t src_2efEF = w_ptr[n * K + k8 * 8 + 2]; - uint8_t src_3ghGH = w_ptr[n * K + k8 * 8 + 3]; - uint8_t src_4ijIJ = w_ptr[n * K + k8 * 8 + 4]; - uint8_t src_5klKL = w_ptr[n * K + k8 * 8 + 5]; - uint8_t src_6mnMN = w_ptr[n * K + k8 * 8 + 6]; - uint8_t src_7opOP = w_ptr[n * K + k8 * 8 + 7]; - - // b0: 7|6|5|4|3|2|1|0 (upper bits for all values) - b_ptr[row_base + 5 * k8 + 0] = ((src_0abAB & 16) >> 4) | - ((src_1cdCD & 16) >> 3) | ((src_2efEF & 16) >> 2) | - ((src_3ghGH & 16) >> 1) | ((src_4ijIJ & 16)) | - ((src_5klKL & 16) << 1) | ((src_6mnMN & 16) << 2) | - ((src_7opOP & 16) << 3); - - // b1: cdCD|abAB (lower 4 bits for first 2 values) - b_ptr[row_base + 5 * k8 + 1] = (src_0abAB & 15) | ((src_1cdCD & 15) << 4); - - // b2: ghGH|efEF (lower 4 bits for second 2 values) - b_ptr[row_base + 5 * k8 + 2] = (src_2efEF & 15) | ((src_3ghGH & 15) << 4); - - // b3: klKL|ijIJ (lower 4 bits for third 2 values) - b_ptr[row_base + 5 * k8 + 3] = (src_4ijIJ & 15) | ((src_5klKL & 15) << 4); + uint8_t src_val0 = w_ptr[n * K + k8 * 8]; + uint8_t src_val1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_val2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_val3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_val4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_val5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_val6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_val7 = w_ptr[n * K + k8 * 8 + 7]; - // b4: opOP|mnMN (lower 4 bits for last 2 values) - b_ptr[row_base + 5 * k8 + 4] = (src_6mnMN & 15) | ((src_7opOP & 15) << 4); + b_ptr[row_base + 5 * k8 + 0] = src_val0 | (src_val1 << 5); + b_ptr[row_base + 5 * k8 + 1] = (src_val1 >> 3) | (src_val2 << 2) | (src_val3 << 7); + b_ptr[row_base + 5 * k8 + 2] = (src_val3 >> 1) | (src_val4 << 4); + b_ptr[row_base + 5 * k8 + 3] = (src_val4 >> 4) | (src_val5 << 1) | (src_val6 << 6); + b_ptr[row_base + 5 * k8 + 4] = (src_val6 >> 2) | (src_val7 << 3); } } } /** - * 6-bit packing. Each weight is 6 bits. So we pack 4 weights into 3 bytes. We - * pack the upper 2 bits of all 4 weights into the first 2 bytes, then the 4 - * lower bits of all weights into the other 4 bytes. + * 6-bit packing. Each weight is 6 bits. We pack 4 weights into 3 bytes. 
*/ template <> inline void @@ -175,32 +146,20 @@ pack<6>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { for (int32_t n = 0; n < N; n++) { int32_t row_base = (n * (K / 4)) * 3; for (int32_t k4 = 0; k4 < K / 4; k4++) { - uint8_t src_10abcd = w_ptr[n * K + k4 * 4 + 0]; - uint8_t src_32efgh = w_ptr[n * K + k4 * 4 + 1]; - uint8_t src_54ijkl = w_ptr[n * K + k4 * 4 + 2]; - uint8_t src_76mnop = w_ptr[n * K + k4 * 4 + 3]; - - // b0: 76|54|32|10 (upper 2 bits for all values) - b_ptr[row_base + 3 * k4 + 0] = ((src_10abcd & 48) >> 4) | - ((src_32efgh & 48) >> 2) | ((src_54ijkl & 48)) | - ((src_76mnop & 48) << 2); - - // b1: efgh|abcd (lower 4 bits for first 2 values) - b_ptr[row_base + 3 * k4 + 1] = - (src_10abcd & 15) | ((src_32efgh & 15) << 4); + uint8_t src_val0 = w_ptr[n * K + k4 * 4]; + uint8_t src_val1 = w_ptr[n * K + k4 * 4 + 1]; + uint8_t src_val2 = w_ptr[n * K + k4 * 4 + 2]; + uint8_t src_val3 = w_ptr[n * K + k4 * 4 + 3]; - // b2: mnop|ijkl (lower 4 bits for last 2 values) - b_ptr[row_base + 3 * k4 + 2] = - (src_54ijkl & 15) | ((src_76mnop & 15) << 4); + b_ptr[row_base + 3 * k4 + 0] = src_val0 | (src_val1 << 6); + b_ptr[row_base + 3 * k4 + 1] = (src_val1 >> 2) | (src_val2 << 4); + b_ptr[row_base + 3 * k4 + 2] = (src_val2 >> 4) | (src_val3 << 2); } } } /** - * 7-bit packing. Each weight is 7 bits. So we pack 8 weights into 7 bytes. - * Each of the 7 bytes contains 1 weight, plus 1 bit from the 8th weight. So, - * this packing spreads the 8th weight across all 7 bytes. The upper bit of - * each byte is the bit from the 8th weight. + * 7-bit packing. Each weight is 7 bits. We pack 8 weights into 7 bytes. */ template <> inline void @@ -208,22 +167,22 @@ pack<7>(const uint8_t* w_ptr, uint8_t* b_ptr, int32_t N, int32_t K) { for (int32_t n = 0; n < N; n++) { int32_t row_base = (n * (K / 8)) * 7; for (int32_t k8 = 0; k8 < K / 8; k8++) { - uint8_t src_0 = w_ptr[n * K + k8 * 8 + 0]; - uint8_t src_1 = w_ptr[n * K + k8 * 8 + 1]; - uint8_t src_2 = w_ptr[n * K + k8 * 8 + 2]; - uint8_t src_3 = w_ptr[n * K + k8 * 8 + 3]; - uint8_t src_4 = w_ptr[n * K + k8 * 8 + 4]; - uint8_t src_5 = w_ptr[n * K + k8 * 8 + 5]; - uint8_t src_6 = w_ptr[n * K + k8 * 8 + 6]; - uint8_t src_7 = w_ptr[n * K + k8 * 8 + 7]; + uint8_t src_val0 = w_ptr[n * K + k8 * 8 + 0]; + uint8_t src_val1 = w_ptr[n * K + k8 * 8 + 1]; + uint8_t src_val2 = w_ptr[n * K + k8 * 8 + 2]; + uint8_t src_val3 = w_ptr[n * K + k8 * 8 + 3]; + uint8_t src_val4 = w_ptr[n * K + k8 * 8 + 4]; + uint8_t src_val5 = w_ptr[n * K + k8 * 8 + 5]; + uint8_t src_val6 = w_ptr[n * K + k8 * 8 + 6]; + uint8_t src_val7 = w_ptr[n * K + k8 * 8 + 7]; - b_ptr[row_base + 7 * k8 + 0] = src_0 | ((src_7 & 1) << 7); - b_ptr[row_base + 7 * k8 + 1] = src_1 | ((src_7 & 2) << 6); - b_ptr[row_base + 7 * k8 + 2] = src_2 | ((src_7 & 4) << 5); - b_ptr[row_base + 7 * k8 + 3] = src_3 | ((src_7 & 8) << 4); - b_ptr[row_base + 7 * k8 + 4] = src_4 | ((src_7 & 16) << 3); - b_ptr[row_base + 7 * k8 + 5] = src_5 | ((src_7 & 32) << 2); - b_ptr[row_base + 7 * k8 + 6] = src_6 | ((src_7 & 64) << 1); + b_ptr[row_base + 7 * k8 + 0] = src_val0 | (src_val1 << 7); + b_ptr[row_base + 7 * k8 + 1] = (src_val1 >> 1) | (src_val2 << 6); + b_ptr[row_base + 7 * k8 + 2] = (src_val2 >> 2) | (src_val3 << 5); + b_ptr[row_base + 7 * k8 + 3] = (src_val3 >> 3) | (src_val4 << 4); + b_ptr[row_base + 7 * k8 + 4] = (src_val4 >> 4) | (src_val5 << 3); + b_ptr[row_base + 7 * k8 + 5] = (src_val5 >> 5) | (src_val6 << 2); + b_ptr[row_base + 7 * k8 + 6] = (src_val6 >> 6) | (src_val7 << 1); } } } diff --git 
a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 8a1e0fdb9e..524aee738d 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -51,6 +51,7 @@ void reference_linear_lowbit_quant_weights_cpu( int32_t M, int32_t K, int32_t N) { + int32_t ceil_K_group_size = (K + group_size - 1) / group_size; for (int32_t m = 0; m < M; m++) { for (int32_t n = 0; n < N; n++) { const int32_t k_block = (K + group_size - 1) / group_size; @@ -59,8 +60,8 @@ void reference_linear_lowbit_quant_weights_cpu( float rc = 0.0; int32_t k = 0; for (int32_t kb = 0; kb < k_block; kb++) { - const float scale = float(s_ptr[kb * N + n]); - const float zero = float(z_ptr[kb * N + n]); + const float scale = float(s_ptr[n * ceil_K_group_size + kb]); + const float zero = float(z_ptr[n * ceil_K_group_size + kb]); for (int32_t idx = 0; idx < group_size && k < K; idx++, k++) { const auto a_val = float(A_ptr[k]); uint8_t w_val = w_ptr[n * K + k]; @@ -217,6 +218,7 @@ void run_test_battery() { run_test(19, 256, 28, 256); run_test(1, 1000, 28, 256); run_test(19, 8, 36, 256); + run_test(1, 1024, 1024, 64); } int main() { diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm index 2aeb7f4460..972caa039a 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm @@ -55,19 +55,19 @@ void check_linear_mps_args( group_size); TORCH_CHECK( - S.dim() == 2 && S.size(1) == N, + S.dim() == 2 && S.size(0) == N, __func__, - ": expect S to be 2d tensor with shape [:, ", + ": expect S to be 2d tensor with shape [", N, - "]"); + ",:]"); TORCH_CHECK(S.is_contiguous(), __func__, " : expect S to be contiguous."); TORCH_CHECK( - Z.dim() == 2 && Z.size(1) == N, + Z.dim() == 2 && Z.size(0) == N, __func__, - ": expect Z to be 2d tensor with shape [:, ", + ": expect Z to be 2d tensor with shape [", N, - "]"); + ",:]"); TORCH_CHECK(Z.is_contiguous(), __func__, " : expect Z to be contiguous."); } diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm index a6f417b17d..f8a8ffdae9 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm @@ -65,12 +65,12 @@ bool check_linear_mps_args( "Expect group_size to be 32, 64, 128 or 256"); ET_LOG_MSG_AND_RETURN_IF_FALSE( - S.dim() == 2 && S.size(1) == N, - "Expect S to be 2d tensor with shape [:, N]"); + S.dim() == 2 && S.size(0) == N, + "Expect S to be 2d tensor with shape [N, :]"); ET_LOG_MSG_AND_RETURN_IF_FALSE( - Z.dim() == 2 && Z.size(1) == N, - "Expect Z to be 2d tensor with shape [:, N]"); + Z.dim() == 2 && Z.size(0) == N, + "Expect Z to be 2d tensor with shape [N, :]"); return true; } diff --git a/torchao/experimental/ops/mps/mps_op_lib.py b/torchao/experimental/ops/mps/mps_op_lib.py index 145c77c3de..bee038ce19 100644 --- a/torchao/experimental/ops/mps/mps_op_lib.py +++ b/torchao/experimental/ops/mps/mps_op_lib.py @@ -37,10 +37,10 @@ def _( assert scales.is_contiguous() assert scales.dim() == 2 - assert scales.size(1) == n + assert scales.size(0) == n assert zeros.is_contiguous() assert zeros.dim() == 2 - assert zeros.size(1) == n + assert zeros.size(0) == n return torch.empty(m, n, dtype=activations.dtype, 
device="meta") diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index a3ac7a6431..dc2460110e 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -64,11 +64,11 @@ def _init_tensors(self, group_size, M, K, N, nbit, device="mps"): ceil_K_group_size = (K + group_size - 1) // group_size A = torch.rand(M, K, dtype=torch.float32, device=device) W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8, device=device) - S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01 + S = torch.rand(N, ceil_K_group_size, dtype=torch.float32, device=device) + 0.01 Z = torch.randint( 0, 1 << nbit, - (ceil_K_group_size, N), + (N, ceil_K_group_size), dtype=torch.float32, device=device, ) @@ -83,8 +83,8 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit): N = W.shape[0] K = W.shape[1] W = W.to(torch.float32) - scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] - zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + scales = S.unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + zeros = Z.unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] W = scales * W + zeros return torch.mm(A, W.t()) diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 7afa91183e..04273fb1af 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -146,13 +146,14 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z): N = W.shape[0] K = W.shape[1] W = W.to(torch.float32) - scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] - zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + scales = S.unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + zeros = Z.unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] W = scales * W + zeros return torch.mm(A, W.t()) @parameterized.expand(BITWIDTHS) def test_accuracy(self, nbit): + print(f"nbit: {nbit}") group_size = 32 m = 3 n = 12 @@ -170,8 +171,7 @@ def test_accuracy(self, nbit): weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize( weight_cpu, group_size, nbit, True, torch.uint8 ) - weight_scales_cpu = weight_scales_cpu.t() - weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu + weight_zeros_cpu = -weight_zeros_cpu * weight_scales_cpu expected = self._reference_linear_lowbit_quant_weights( activations.cpu(), weight_qvals_cpu, diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index b7630cada3..2e50587c2a 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -529,8 +529,6 @@ def quantize_and_pack_weights(self, weights, nbit, group_size): weight_qvals, weight_scales, weight_zeros = _quantize( weights, self.group_size, self.nbit, has_weight_zeros=True, signed=False ) - weight_scales = torch.transpose_copy(weight_scales, 1, 0) - weight_zeros = torch.transpose_copy(weight_zeros, 1, 0) weight_zeros = -weight_zeros * weight_scales self.weight_scales = nn.Parameter(weight_scales, requires_grad=False) self.weight_zeros = nn.Parameter(weight_zeros, requires_grad=False) @@ -550,7 +548,7 @@ def forward(self, x): lead_shape = x.shape[0:-1] k = x.shape[-1] - n = self.weight_scales.shape[1] + n = self.weight_scales.shape[0] return self._linear_op( x.reshape(-1, k), self.packed_weights,
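
The common thread in the kernel changes above is the scales/zeros layout: every shader now indexes scales and zero points as [N, num_groups = ceil(K / group_size)] instead of [num_groups, N], so each output channel's group parameters are contiguous in memory. A minimal PyTorch sketch of the reference computation under the new layout, following the updated _reference_linear_lowbit_quant_weights in test_lowbit.py (here Z is assumed to already hold the pre-scaled zero points, i.e. -zero_point * scale, as prepared in quant_api.py):

    import torch

    def reference_linear_lowbit(A, W_q, S, Z, group_size):
        # A: [M, K] activations; W_q: [N, K] unsigned quantized weights.
        # S, Z: [N, ceil(K / group_size)] scales and pre-scaled zero points.
        N, K = W_q.shape
        W = W_q.to(torch.float32)
        scales = S.unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        zeros = Z.unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        return torch.mm(A, (scales * W + zeros).t())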
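
The repacking routines (pack<3>, pack<5>, pack<6>, pack<7> in packing.h) all switch to the same convention: each group of weights is written as one contiguous little-endian bit stream, low bits first, instead of collecting the high bits of all values into a separate byte. A sketch of the 3-bit case (8 values in [0, 7] into 3 bytes), with an unpacker that mirrors unpack_3bit in int3mm_opt.metal; the function names here are illustrative only:

    def pack_3bit(v):
        # 8 weights -> 3 bytes, bits laid out consecutively, low bits first.
        b0 = (v[0] | (v[1] << 3) | (v[2] << 6)) & 0xFF
        b1 = ((v[2] >> 2) | (v[3] << 1) | (v[4] << 4) | (v[5] << 7)) & 0xFF
        b2 = ((v[5] >> 1) | (v[6] << 2) | (v[7] << 5)) & 0xFF
        return [b0, b1, b2]

    def unpack_3bit(b):
        # Mirrors unpack_3bit in int3mm_opt.metal.
        return [
            b[0] & 0x07,
            (b[0] & 0x38) >> 3,
            ((b[0] & 0xC0) >> 6) | ((b[1] & 0x01) << 2),
            (b[1] & 0x0E) >> 1,
            (b[1] & 0x70) >> 4,
            ((b[1] & 0x80) >> 7) | ((b[2] & 0x03) << 1),
            (b[2] & 0x1C) >> 2,
            (b[2] & 0xE0) >> 5,
        ]

    assert unpack_3bit(pack_3bit(list(range(8)))) == list(range(8))

This contiguous layout matches what the qmv_fast kernel expects when it reads whole packs per thread, without reassembling split high bits.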
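
qmv_fast itself relies on the MLX trick of pre-scaling activations instead of shifting weights: load_vector divides each activation by the bit weight of the packed value it will multiply, so qdot can use masked-but-unshifted weight bits and still recover the correct magnitudes, keeping shifts out of the inner loop. A small Python sketch of the idea for the 2-bit case, illustrative only (the real kernel works on simdgroup-sized tiles and accumulates in float):

    def qmv_dot_2bit(w_bytes, x, scale, bias):
        # load_vector step: keep the plain sum for the zero-point term and
        # pre-divide each activation by 4**j, the bit weight of the 2-bit
        # value it will be multiplied with.
        sum_x = sum(x)
        x_pre = [xi / (4 ** (j % 4)) for j, xi in enumerate(x)]
        # qdot step: multiply by masked-but-unshifted weight bits;
        # (a / 4**j) * (w << 2*j) == a * w, so no shifts are needed.
        accum = 0.0
        for i, b in enumerate(w_bytes):
            for j, mask in enumerate((0x03, 0x0C, 0x30, 0xC0)):
                accum += x_pre[4 * i + j] * (b & mask)
        return scale * accum + sum_x * bias

Because the zero point is folded into bias on the host (already multiplied by -scale), the per-group result is simply scale * sum(a * w) + bias * sum(a).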
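
Finally, kernel selection: get_shader_func_and_dispatch in lowbit.h only picks the new kernel for the matrix-vector case with sizes that fit its tiling, and falls back to the existing per-bitwidth shaders otherwise. A sketch of the decision; the fallback name below is an assumption for illustration, since the real prefix comes from LowBitConfig<nbit>::func_prefix:

    def select_shader(nbit, group_size, dtype_str, M, N, K):
        # qmv_fast requires a single activation row, N divisible by 8
        # (two simdgroups x four output rows per threadgroup) and K
        # divisible by 512.
        if M == 1 and N % 8 == 0 and K % 512 == 0:
            return f"qmv_fast_{nbit}bit_{group_size}_{dtype_str}"
        # Assumed fallback name, for illustration only.
        return f"int{nbit}pack_mm_{group_size}_{dtype_str}"

When qmv_fast is chosen, dispatch_qmv_fast launches an M by (N + 7) / 8 grid of threadgroups with 32 x 2 threads each, so every simdgroup produces four output channels.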