Commit 5a94188

metal lowbit kernels: optimized 2-bit, 3-bit and 4-bit shaders
1 parent 603d908 commit 5a94188

17 files changed: +865 -45 lines

torchao/experimental/kernels/mps/metal.yaml

Lines changed: 7 additions & 4 deletions
@@ -1,14 +1,17 @@
+- func: Vec4Type
+  file: common.metal
+
 - func: int1mm
-  file: divbit.metal
+  file: int1mm.metal
 
 - func: int2mm
-  file: divbit.metal
+  file: int2mm_opt.metal
 
 - func: int3mm
-  file: int3mm.metal
+  file: int3mm_opt.metal
 
 - func: int4mm
-  file: divbit.metal
+  file: int4mm_opt.metal
 
 - func: int5mm
   file: int5mm.metal
torchao/experimental/kernels/mps/metal/common.metal

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+template <typename T> struct Vec4Type {};
+
+template <> struct Vec4Type<float> {
+  using type = float4;
+};
+
+template <> struct Vec4Type<half> {
+  using type = half4;
+};
+
+#if __METAL_VERSION__ >= 310
+template <> struct Vec4Type<bfloat> {
+  using type = bfloat4;
+};
+#endif
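
Vec4Type is a small type trait that maps a scalar element type to its 4-wide Metal vector type (float to float4, half to half4, bfloat to bfloat4), which is what lets the optimized kernels below do vectorized loads and stores via `using vecT = typename Vec4Type<T>::type;`. A rough host-side C++ analogue of the same trait pattern, purely to illustrate the mechanism; the `vec4` struct here is a hypothetical stand-in, since float4/half4/bfloat4 are Metal built-ins:

#include <cstdio>

// Hypothetical 4-wide vector stand-in; in Metal, float4/half4/bfloat4 are built-ins.
template <typename T> struct vec4 { T x, y, z, w; };

// Same trait pattern as common.metal: map a scalar element type to its 4-wide vector type.
template <typename T> struct Vec4Type {};
template <> struct Vec4Type<float> { using type = vec4<float>; };
// half and bfloat specializations follow the same pattern on the device side.

int main() {
  using vecT = typename Vec4Type<float>::type; // resolves to vec4<float>
  vecT v{1.f, 2.f, 3.f, 4.f};
  std::printf("%g %g %g %g\n", v.x, v.y, v.z, v.w);
  return 0;
}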

torchao/experimental/kernels/mps/metal/divbit.metal

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ kernel void divbit_mm(
   constant T *A_ptr = A + m * K;
   constant uchar *B_ptr = B;
 
-  constexpr uint8_t zero_shift = 1 << (nbit - 1);
   constexpr uint8_t values_per_byte = 8 / nbit;
   constexpr uint8_t minimask = (1 << nbit) - 1;
 
torchao/experimental/kernels/mps/metal/int1mm.metal

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+#include <metal_stdlib>
+using namespace metal;
+
+/**
+ * 1-Bit Quantized Linear.
+ *
+ * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
+ * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 8)
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
+ * @param[outputData] M x N output tensor of floating point dtype (same as input)
+ * @param[sizes] The sizes involved in the order: M, K, N
+ *
+ * Dispatched threads: N x M x 1
+ */
+template<typename T, unsigned groupSize>
+kernel void int1pack_mm(
+    constant T * A [[buffer(0)]],
+    constant uchar * B [[buffer(1)]],
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
+    uint2 thread_index [[thread_position_in_grid]]) {
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  const uint m = thread_index.y; // 0..M-1
+  const uint n = thread_index.x; // 0..N-1
+  const uint32_t k_block = (K + groupSize - 1) / groupSize;
+  constant T *A_ptr = A + m * K;
+  constant uchar *B_ptr = B + n * K / 8;
+
+  float rc = 0.0;
+  uint k = 0;
+  for (uint32_t kb = 0; kb < k_block; kb++) {
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
+    for (uint idx = 0; idx < groupSize && k < K; idx += 8, k += 8) {
+      const auto a_val0 = float(A_ptr[k + 0]);
+      const auto a_val1 = float(A_ptr[k + 1]);
+      const auto a_val2 = float(A_ptr[k + 2]);
+      const auto a_val3 = float(A_ptr[k + 3]);
+      const auto a_val4 = float(A_ptr[k + 4]);
+      const auto a_val5 = float(A_ptr[k + 5]);
+      const auto a_val6 = float(A_ptr[k + 6]);
+      const auto a_val7 = float(A_ptr[k + 7]);
+
+      uchar b0 = B_ptr[(k / 8)];
+
+      uchar w_val0 = b0 & 0x01;
+      uchar w_val1 = (b0 & 0x02) >> 1;
+      uchar w_val2 = (b0 & 0x04) >> 2;
+      uchar w_val3 = (b0 & 0x08) >> 3;
+      uchar w_val4 = (b0 & 0x10) >> 4;
+      uchar w_val5 = (b0 & 0x20) >> 5;
+      uchar w_val6 = (b0 & 0x40) >> 6;
+      uchar w_val7 = (b0 & 0x80) >> 7;
+
+      rc += a_val0 * (scale * float(w_val0) + zero);
+      rc += a_val1 * (scale * float(w_val1) + zero);
+      rc += a_val2 * (scale * float(w_val2) + zero);
+      rc += a_val3 * (scale * float(w_val3) + zero);
+      rc += a_val4 * (scale * float(w_val4) + zero);
+      rc += a_val5 * (scale * float(w_val5) + zero);
+      rc += a_val6 * (scale * float(w_val6) + zero);
+      rc += a_val7 * (scale * float(w_val7) + zero);
+    }
+  }
+  outputData[m * N + n] = T(rc);
+}
+
+#define INSTANTIATE_INT1MM(DTYPE, GSIZE)              \
+  template                                            \
+  [[host_name("int1pack_mm_" #GSIZE "_" #DTYPE)]]     \
+  kernel void int1pack_mm<DTYPE, GSIZE>(              \
+      constant DTYPE * A [[buffer(0)]],               \
+      constant uchar * B [[buffer(1)]],               \
+      constant DTYPE * scales [[buffer(2)]],          \
+      constant DTYPE * zeros [[buffer(3)]],           \
+      device DTYPE * outputData [[buffer(4)]],        \
+      constant uint3 & sizes [[buffer(5)]],           \
+      uint2 thread_index [[thread_position_in_grid]])
+
+INSTANTIATE_INT1MM(float, 32);
+INSTANTIATE_INT1MM(half, 32);
+INSTANTIATE_INT1MM(float, 64);
+INSTANTIATE_INT1MM(half, 64);
+INSTANTIATE_INT1MM(float, 128);
+INSTANTIATE_INT1MM(half, 128);
+INSTANTIATE_INT1MM(float, 256);
+INSTANTIATE_INT1MM(half, 256);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_INT1MM(bfloat, 32);
+INSTANTIATE_INT1MM(bfloat, 64);
+INSTANTIATE_INT1MM(bfloat, 128);
+INSTANTIATE_INT1MM(bfloat, 256);
+#endif
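
This kernel reads row n of B as K/8 bytes where bit j (least significant first) of byte B[n][k/8] is the 1-bit weight for column k with k % 8 == j, and each output element accumulates a * (scale * w + zero) with one (scale, zero) pair per group along K. A host-side C++ sketch of that packing and of the same scalar accumulation, useful as a reference when preparing inputs; `pack_1bit` and `ref_dot_1bit` are illustrative helpers, not torchao APIs:

#include <cstdint>
#include <cstdio>
#include <vector>

// Pack row-major N x K 1-bit weights so that bit (k % 8) of byte B[n][k / 8]
// holds w[n][k] -- the same layout int1pack_mm unpacks with masks 0x01..0x80.
std::vector<uint8_t> pack_1bit(const std::vector<uint8_t>& w, int N, int K) {
  std::vector<uint8_t> packed(N * K / 8, 0);
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)
      packed[n * (K / 8) + k / 8] |= (w[n * K + k] & 1u) << (k % 8);
  return packed;
}

// Scalar version of the kernel's accumulation for one (m, n) output element:
// rc += a * (scale * w + zero), with one (scale, zero) per group of groupSize along K.
float ref_dot_1bit(const std::vector<float>& A, const std::vector<uint8_t>& packed,
                   const std::vector<float>& scales, const std::vector<float>& zeros,
                   int m, int n, int K, int N, int groupSize) {
  float rc = 0.f;
  for (int k = 0; k < K; ++k) {
    float w = (packed[n * (K / 8) + k / 8] >> (k % 8)) & 1u;
    int g = k / groupSize;
    rc += A[m * K + k] * (scales[g * N + n] * w + zeros[g * N + n]);
  }
  return rc;
}

int main() {
  const int M = 1, K = 32, N = 2, groupSize = 32;
  std::vector<uint8_t> w(N * K);
  std::vector<float> A(M * K);
  for (int i = 0; i < N * K; ++i) w[i] = i % 2;     // toy weights
  for (int i = 0; i < M * K; ++i) A[i] = 0.25f * i; // toy activations
  std::vector<float> scales((K / groupSize) * N, 0.5f);
  std::vector<float> zeros((K / groupSize) * N, -1.f);
  auto packed = pack_1bit(w, N, K);
  std::printf("out[0][0] = %f\n",
              ref_dot_1bit(A, packed, scales, zeros, 0, 0, K, N, groupSize));
  return 0;
}
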
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+#include <metal_stdlib>
+using namespace metal;
+
+/**
+ * 2-Bit Quantized Linear.
+ *
+ * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
+ * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (K / 4)
+ * @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
+ * @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
+ * @param[outputData] M x N output tensor of floating point dtype (same as input)
+ * @param[sizes] The sizes involved in the order: M, K, N
+ *
+ * Dispatched threads: N x M x 1
+ */
+template<typename T, unsigned groupSize>
+kernel void int2pack_mm(
+    constant T * A [[buffer(0)]],
+    constant uchar * B [[buffer(1)]],
+    constant T * scales [[buffer(2)]],
+    constant T * zeros [[buffer(3)]],
+    device T * outputData [[buffer(4)]],
+    constant uint3 & sizes [[buffer(5)]], // M, K, N
+    uint2 thread_index [[thread_position_in_grid]]) {
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  const uint m = thread_index.y; // 0..M-1
+  const uint n = thread_index.x; // 0..N-1
+  const uint32_t k_block = (K + groupSize - 1) / groupSize;
+  constant T *A_ptr = A + m * K;
+  constant uchar *B_ptr = B + n * 2 * K / 8;
+
+  float rc = 0.0;
+  uint k = 0;
+  for (uint32_t kb = 0; kb < k_block; kb++) {
+    const float scale = float(scales[kb * N + n]);
+    const float zero = float(zeros[kb * N + n]);
+    for (uint idx = 0; idx < groupSize && k < K; idx += 8, k += 8) {
+      const auto a_val0 = float(A_ptr[k + 0]);
+      const auto a_val1 = float(A_ptr[k + 1]);
+      const auto a_val2 = float(A_ptr[k + 2]);
+      const auto a_val3 = float(A_ptr[k + 3]);
+      const auto a_val4 = float(A_ptr[k + 4]);
+      const auto a_val5 = float(A_ptr[k + 5]);
+      const auto a_val6 = float(A_ptr[k + 6]);
+      const auto a_val7 = float(A_ptr[k + 7]);
+
+      uchar b0 = B_ptr[2 * (k / 8) + 0];
+      uchar b1 = B_ptr[2 * (k / 8) + 1];
+
+      uchar w_val0 = b0 & 0x03;
+      uchar w_val1 = (b0 & 0x0c) >> 2;
+      uchar w_val2 = (b0 & 0x30) >> 4;
+      uchar w_val3 = (b0 & 0xc0) >> 6;
+
+      uchar w_val4 = b1 & 0x03;
+      uchar w_val5 = (b1 & 0x0c) >> 2;
+      uchar w_val6 = (b1 & 0x30) >> 4;
+      uchar w_val7 = (b1 & 0xc0) >> 6;
+
+      rc += a_val0 * (scale * float(w_val0) + zero);
+      rc += a_val1 * (scale * float(w_val1) + zero);
+      rc += a_val2 * (scale * float(w_val2) + zero);
+      rc += a_val3 * (scale * float(w_val3) + zero);
+      rc += a_val4 * (scale * float(w_val4) + zero);
+      rc += a_val5 * (scale * float(w_val5) + zero);
+      rc += a_val6 * (scale * float(w_val6) + zero);
+      rc += a_val7 * (scale * float(w_val7) + zero);
+    }
+  }
+  outputData[m * N + n] = T(rc);
+}
+
+#define INSTANTIATE_INT2MM(DTYPE, GSIZE)              \
+  template                                            \
+  [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]]     \
+  kernel void int2pack_mm<DTYPE, GSIZE>(              \
+      constant DTYPE * A [[buffer(0)]],               \
+      constant uchar * B [[buffer(1)]],               \
+      constant DTYPE * scales [[buffer(2)]],          \
+      constant DTYPE * zeros [[buffer(3)]],           \
+      device DTYPE * outputData [[buffer(4)]],        \
+      constant uint3 & sizes [[buffer(5)]],           \
+      uint2 thread_index [[thread_position_in_grid]])
+
+INSTANTIATE_INT2MM(float, 32);
+INSTANTIATE_INT2MM(half, 32);
+INSTANTIATE_INT2MM(float, 64);
+INSTANTIATE_INT2MM(half, 64);
+INSTANTIATE_INT2MM(float, 128);
+INSTANTIATE_INT2MM(half, 128);
+INSTANTIATE_INT2MM(float, 256);
+INSTANTIATE_INT2MM(half, 256);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_INT2MM(bfloat, 32);
+INSTANTIATE_INT2MM(bfloat, 64);
+INSTANTIATE_INT2MM(bfloat, 128);
+INSTANTIATE_INT2MM(bfloat, 256);
+#endif
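
The 2-bit variant differs from the 1-bit one only in its packing: each run of 8 values along K occupies two bytes, with byte 2*(k/8) holding k..k+3 in its four 2-bit lanes (least-significant lane first) and byte 2*(k/8)+1 holding k+4..k+7, which is what the 0x03 / 0x0c>>2 / 0x30>>4 / 0xc0>>6 unpacking above reads back. A host-side packing sketch under that assumption; `pack_2bit` is an illustrative helper, not a torchao API:

#include <cstdint>
#include <vector>

// Pack row-major N x K 2-bit weights (values 0..3) into N x (K / 4) bytes in the
// order int2pack_mm expects: two bytes per 8 K-values, low 2-bit lanes first.
std::vector<uint8_t> pack_2bit(const std::vector<uint8_t>& w, int N, int K) {
  std::vector<uint8_t> packed(N * K / 4, 0);
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k) {
      int byte_idx = n * (K / 4) + 2 * (k / 8) + (k % 8) / 4; // b0 or b1 of the pair
      int shift = 2 * (k % 4);                                // 2-bit lane within the byte
      packed[byte_idx] |= (w[n * K + k] & 0x3u) << shift;
    }
  return packed;
}
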
torchao/experimental/kernels/mps/metal/int2mm_opt.metal

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+#include <metal_simdgroup>
+#include <metal_stdlib>
+using namespace metal;
+
+/*
+   This code takes heavy inspiration from MLX:
+   https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h
+   Specifically:
+     - Multiplying activation by inverse scaling factor to reduce compute
+       boundedness
+     - Handling zero point by accumulating act in separate sum term. Needed with
+       optimization done above. MLX MIT License:
+   https://github.com/ml-explore/mlx/blob/main/LICENSE
+*/
+
+/*
+   @brief This shader implements 2-bit matrix-vector multiplication where A
+   matrix is fp16, bfloat or float and B matrix is a 2-bit groupwise-quantized weight
+   matrix.
+   @param [in] A is activation matrix of size M x K.
+   @param [in] B is weight matrix of size N x K. Each byte contains 4 2-bit
+   values, along K dim, packed together.
+   @param [in] scales_ptr is scales ptr corresponding to each
+   output channel x group. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+   channels.
+   @param [in] zeros_ptr is zero points corresponding to each
+   output channel x group. These are packed as [num_groups = ceil(K / group_size), N]. N = output
+   channels.
+   @param [out] output_data is output matrix of size M x N.
+   @param [in] sizes array contains values of M, K and N.
+   @param [in] thread_index is global thread id.
+   @param [in] tid_in_simdgroup is thread id in simdgroup. e.g. in simdgroup of size 32 it can be in [0-31].
+*/
+template <typename T, unsigned group_size>
+kernel void int2pack_mm(constant T *A [[buffer(0)]],
+                        constant uchar *B [[buffer(1)]],
+                        constant T *scales_ptr [[buffer(2)]],
+                        constant T *zeros_ptr [[buffer(3)]],
+                        device T *output_data [[buffer(4)]],
+                        constant uint3 &sizes [[buffer(5)]], // M, K, N
+                        uint3 thread_index [[thread_position_in_grid]],
+                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+  constexpr uint threads_per_channel = 32;
+  constexpr uint ks_per_thread = 4;
+  constexpr uint k_pack_factor = 4;
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  uint n = thread_index.x; // 0..N/4-1
+  uint m = thread_index.z; // 0..M-1
+  n = n / threads_per_channel;
+  n = n * 4;
+  // This is the starting k for each thread, e.g. for thread 1 within a simdgroup
+  // this value will be 4.
+  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+  constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+  using vecT = typename Vec4Type<T>::type;
+  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+  thread float4 result = float4(0.0);
+  // We multiply a group of 4 channels with these scales because the corresponding
+  // values from the weight matrix are effectively left-shifted. This avoids doing
+  // a right shift on those values, which ends up affecting performance.
+  // This is the trick applied in MLX kernels.
+  float4 act_div_scales = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};
+
+  for (; k < K; k += k_jump) {
+    // Find the specific group to which the channels handled by this thread
+    // belong.
+    uint k_block_index = k / group_size;
+    uint scales_group_offset = (k_block_index * N + n);
+
+    vecT scales =
+        (reinterpret_cast<constant vecT *>(scales_ptr + scales_group_offset))[0];
+    // Adding zero point results in 10% perf penalty.
+    vecT zeros =
+        (reinterpret_cast<constant vecT *>(zeros_ptr + scales_group_offset))[0];
+    float4 zeros_float = float4(zeros);
+
+    float4 a_val = float4(A_ptr[k / 4]);
+    // We skip right-shifts of the weights and hence divide by the corresponding factor.
+    float4 a_vec = a_val * act_div_scales;
+    float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];
+
+    float4x4 b_mat;
+    ushort b_val0 = (B_ptr + (k + 0 * K) / k_pack_factor)[0];
+    ushort b_val1 = (B_ptr + (k + 1 * K) / k_pack_factor)[0];
+    ushort b_val2 = (B_ptr + (k + 2 * K) / k_pack_factor)[0];
+    ushort b_val3 = (B_ptr + (k + 3 * K) / k_pack_factor)[0];
+    b_mat[0] = scales[0] * float4(float(b_val0 & 0x03), float(b_val0 & 0x0c),
+                                  float(b_val0 & 0x30), float(b_val0 & 0xc0));
+    b_mat[1] = scales[1] * float4(float(b_val1 & 0x03), float(b_val1 & 0x0c),
+                                  float(b_val1 & 0x30), float(b_val1 & 0xc0));
+    b_mat[2] = scales[2] * float4(float(b_val2 & 0x03), float(b_val2 & 0x0c),
+                                  float(b_val2 & 0x30), float(b_val2 & 0xc0));
+    b_mat[3] = scales[3] * float4(float(b_val3 & 0x03), float(b_val3 & 0x0c),
+                                  float(b_val3 & 0x30), float(b_val3 & 0xc0));
+
+    result += a_vec * b_mat;
+    result += a_val_sum * zeros_float;
+  }
+  result += simd_shuffle_down(result, 1);
+  result += simd_shuffle_down(result, 2);
+  result += simd_shuffle_down(result, 4);
+  result += simd_shuffle_down(result, 8);
+  result += simd_shuffle_down(result, 16);
+  if (tid_in_simdgroup % threads_per_channel == 0) {
+    reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
+  }
+}
+
+#define INSTANTIATE_INT2MM(DTYPE, GSIZE)                                      \
+  template [[host_name("int2pack_mm_" #GSIZE "_" #DTYPE)]] kernel void        \
+  int2pack_mm<DTYPE, GSIZE>(                                                  \
+      constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]],     \
+      constant DTYPE * scales_ptr [[buffer(2)]],                              \
+      constant DTYPE * zeros_ptr [[buffer(3)]],                               \
+      device DTYPE * output_data [[buffer(4)]],                               \
+      constant uint3 & sizes [[buffer(5)]],                                   \
+      uint3 thread_index [[thread_position_in_grid]],                         \
+      uint tid_in_simdgroup [[thread_index_in_simdgroup]])
+
+INSTANTIATE_INT2MM(float, 32);
+INSTANTIATE_INT2MM(half, 32);
+INSTANTIATE_INT2MM(float, 64);
+INSTANTIATE_INT2MM(half, 64);
+INSTANTIATE_INT2MM(float, 128);
+INSTANTIATE_INT2MM(half, 128);
+INSTANTIATE_INT2MM(float, 256);
+INSTANTIATE_INT2MM(half, 256);
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_INT2MM(bfloat, 32);
+INSTANTIATE_INT2MM(bfloat, 64);
+INSTANTIATE_INT2MM(bfloat, 128);
+INSTANTIATE_INT2MM(bfloat, 256);
+#endif
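
Two things make this kernel cheaper than the naive 2-bit loop above. First, the masked 2-bit lanes are never right-shifted: since (b & 0x0c) equals 4*w1, (b & 0x30) equals 16*w2 and (b & 0xc0) equals 64*w3, the activations are pre-multiplied by {1, 1/4, 1/16, 1/64} instead. Second, the zero point is factored out of the per-weight term using sum_k a_k*(s*w_k + z) = s*sum_k a_k*w_k + z*sum_k a_k, so it costs one extra multiply-add per float4 rather than one per weight. The simd_shuffle_down chain then reduces the 32 per-thread partials within a simdgroup, and only lane 0 of each channel group writes its float4 of outputs. Below is a standalone scalar C++ check of that algebra for a single packed byte; it is not torchao code, just a sanity check under the layout described above:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t b = 0b10011100;               // 2-bit lanes (w3, w2, w1, w0) = 2, 1, 3, 0
  const float a[4] = {0.5f, -1.25f, 2.0f, 3.0f};
  const float scale = 0.02f, zero = -0.7f;

  // Straightforward dequantize-then-multiply, as in the non-optimized kernel.
  float expected = 0.f;
  for (int j = 0; j < 4; ++j) {
    float w = float((b >> (2 * j)) & 0x3);
    expected += a[j] * (scale * w + zero);
  }

  // Shift-free version: (b & mask_j) == 4^j * w_j, so divide a[j] by 4^j instead,
  // and fold the zero point into a separate activation-sum term.
  const float inv_lane_scale[4] = {1.f, 1 / 4.f, 1 / 16.f, 1 / 64.f};
  const uint8_t mask[4] = {0x03, 0x0c, 0x30, 0xc0};
  float dot = 0.f, a_sum = 0.f;
  for (int j = 0; j < 4; ++j) {
    dot += (a[j] * inv_lane_scale[j]) * float(b & mask[j]);
    a_sum += a[j];
  }
  float optimized = scale * dot + zero * a_sum;

  assert(std::fabs(expected - optimized) < 1e-5f);
  std::printf("expected=%f optimized=%f\n", expected, optimized);
  return 0;
}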
