1111#include < torchao/experimental/kernels/cpu/parallel.h>
1212
1313template <int weight_nbit>
14- at::Tensor pack_weights_cpu (
14+ at::Tensor pack_weights_without_zeros_cpu (
1515 const at::Tensor& weight_qvals,
1616 const at::Tensor& weight_scales,
1717 // TODO(T200095131): convert to int64_t when supported by AOTI
@@ -54,9 +54,8 @@ at::Tensor pack_weights_cpu(
5454
5555 auto packed_weight_data_size =
5656 get_packed_weight_data_size (ukernel_config, n, k, group_size);
57- auto options = torch::TensorOptions ().dtype (torch::kInt8 );
58-
59- at::Tensor packed_weights = torch::empty ({packed_weight_data_size}, options);
57+ at::Tensor packed_weights =
58+ torch::empty ({packed_weight_data_size}, torch::kInt8 );
6059 pack_weight_data_operator (
6160 ukernel_config,
6261 pack_weight_tiling_params,
@@ -72,7 +71,74 @@ at::Tensor pack_weights_cpu(
7271}
7372
7473template <int weight_nbit>
75- at::Tensor pack_weights_meta (
74+ at::Tensor pack_weights_with_zeros_cpu (
75+ const at::Tensor& weight_qvals,
76+ const at::Tensor& weight_scales,
77+ const at::Tensor& weight_zeros,
78+ // TODO(T200095131): convert to int64_t when supported by AOTI
79+ // group_size is a meta tensor with size (group_size)
80+ const at::Tensor& group_size_tensor) {
81+ int64_t group_size = group_size_tensor.size (0 );
82+
83+ TORCH_CHECK (
84+ weight_qvals.dtype () == torch::kInt8 , " weight_qvals must be int8" );
85+ TORCH_CHECK (weight_qvals.dim () == 2 , " weight_qvals must be 2D" );
86+
87+ // In PyTorch, weights are nxk in row-major format (with activations being
88+ // right-multiplied).
89+ // In kernel, activations are left-multiplied by kxn transposed
90+ // weights in column-major format.
91+ // Note the underlying data is the same in both cases
92+ int n = weight_qvals.size (0 );
93+ int k = weight_qvals.size (1 );
94+
95+ TORCH_CHECK (
96+ weight_scales.dtype () == torch::kFloat32 ,
97+ " weight_scales must be float32" );
98+ TORCH_CHECK (weight_scales.dim () == 1 , " weight_scales must be 1D" );
99+ TORCH_CHECK (
100+ weight_scales.size (0 ) == ((n * k) / group_size),
101+ " expected 1 scale per group" );
102+ TORCH_CHECK (
103+ weight_zeros.dtype () == torch::kInt8 , " weight_zeros must be int8" );
104+ TORCH_CHECK (weight_zeros.dim () == 1 , " weight_zeros must be 1D" );
105+ TORCH_CHECK (
106+ weight_zeros.size (0 ) == ((n * k) / group_size),
107+ " expected 1 zero per group" );
108+
109+ using namespace torchao ::operators::cpu::linear::
110+ channelwise_8bit_activation_groupwise_lowbit_weight;
111+
112+ auto ukernel_config = get_ukernel_config<
113+ weight_nbit,
114+ true /* has_weight_zeros*/ ,
115+ false /* has_bias*/ ,
116+ false /* has_clamp*/ >();
117+ auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params (
118+ ukernel_config, n, /* target_panels_per_thread=*/ 1 );
119+
120+ torchao::set_num_threads (torch::get_num_threads ());
121+
122+ auto packed_weight_data_size =
123+ get_packed_weight_data_size (ukernel_config, n, k, group_size);
124+ at::Tensor packed_weights =
125+ torch::empty ({packed_weight_data_size}, torch::kInt8 );
126+ pack_weight_data_operator (
127+ ukernel_config,
128+ pack_weight_tiling_params,
129+ packed_weights.data_ptr <int8_t >(),
130+ n,
131+ k,
132+ group_size,
133+ weight_qvals.const_data_ptr <int8_t >(),
134+ weight_scales.const_data_ptr <float >(),
135+ weight_zeros.const_data_ptr <int8_t >());
136+
137+ return packed_weights;
138+ }
139+
140+ template <int weight_nbit>
141+ at::Tensor pack_weights_without_zeros_meta (
76142 const at::Tensor& weight_qvals,
77143 const at::Tensor& weight_scales,
78144 // TODO(T200095131): convert to int64_t when supported by AOTI
@@ -98,6 +164,33 @@ at::Tensor pack_weights_meta(
98164}
99165
100166template <int weight_nbit>
167+ at::Tensor pack_weights_with_zeros_meta (
168+ const at::Tensor& weight_qvals,
169+ const at::Tensor& weight_scales,
170+ const at::Tensor& weight_zeros,
171+ // TODO(T200095131): convert to int64_t when supported by AOTI
172+ // group_size is a meta tensor with size (group_size)
173+ const at::Tensor& group_size_tensor) {
174+ int64_t group_size = group_size_tensor.size (0 );
175+
176+ int n = weight_qvals.size (0 );
177+ int k = weight_qvals.size (1 );
178+
179+ using namespace torchao ::operators::cpu::linear::
180+ channelwise_8bit_activation_groupwise_lowbit_weight;
181+
182+ auto ukernel_config = get_ukernel_config<
183+ weight_nbit,
184+ true /* has_weight_zeros*/ ,
185+ false /* has_bias*/ ,
186+ false /* has_clamp*/ >();
187+
188+ auto packed_weight_data_size =
189+ get_packed_weight_data_size (ukernel_config, n, k, group_size);
190+ return torch::empty ({packed_weight_data_size}).to (" meta" );
191+ }
192+
193+ template <int weight_nbit, bool has_weight_zeros>
101194at::Tensor linear_cpu (
102195 const at::Tensor& packed_weights,
103196 // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -123,7 +216,7 @@ at::Tensor linear_cpu(
123216
124217 auto ukernel_config = get_ukernel_config<
125218 weight_nbit,
126- false /* has_weight_zeros*/ ,
219+ has_weight_zeros /* has_weight_zeros*/ ,
127220 false /* has_bias*/ ,
128221 false /* has_clamp*/ >();
129222 auto linear_tiling_params = get_default_linear_tiling_params (
@@ -167,7 +260,7 @@ at::Tensor linear_cpu(
167260 return output_tensor;
168261}
169262
170- template <int weight_nbit>
263+ template <int weight_nbit, bool has_weight_zeros >
171264at::Tensor linear_meta (
172265 const at::Tensor& packed_weights,
173266 // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -187,26 +280,78 @@ at::Tensor linear_meta(
187280}
188281
// Declares the operator schemas of the torchao library.
// Naming used below: _pack_weights_a8sz_w<N>s / _linear_a8sz_w<N>s are the
// variants without weight zeros, and the ...w<N>sz names are the variants
// with weight zeros, for N = 2, 3, 4, 5.
// Pack schemas take (weight_qvals, weight_scales[, weight_zeros], group_size);
// linear schemas take (packed_weights, n, k, group_size, activations).
// Every schema returns a single Tensor.
// NOTE(review): a8sz presumably stands for 8-bit activations with scales and
// zeros - confirm against the kernel naming.
// Lines marked with - are the earlier _3bit / _4bit schemas that the lines
// marked with + replace.
189282TORCH_LIBRARY (torchao, m) {
283+ // Pack weights without zeros
284+ m.def (
285+ " _pack_weights_a8sz_w2s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor" );
286+ m.def (
287+ " _pack_weights_a8sz_w3s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor" );
288+ m.def (
289+ " _pack_weights_a8sz_w4s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor" );
290+ m.def (
291+ " _pack_weights_a8sz_w5s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor" );
292+ // Pack weights with zeros
293+ m.def (
294+ " _pack_weights_a8sz_w2sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor" );
295+ m.def (
296+ " _pack_weights_a8sz_w3sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor" );
297+ m.def (
298+ " _pack_weights_a8sz_w4sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor" );
299+ m.def (
300+ " _pack_weights_a8sz_w5sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor" );
301+ // Linear weights without zeros
302+ m.def (
303+ " _linear_a8sz_w2s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
304+ m.def (
305+ " _linear_a8sz_w3s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
306+ m.def (
307+ " _linear_a8sz_w4s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
308+ m.def (
309+ " _linear_a8sz_w5s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
310+ // Linear weights with zeros
190311 m.def (
191- " _pack_weights_3bit (Tensor weight_qvals , Tensor weight_scales , Tensor group_size) -> Tensor" );
312+ " _linear_a8sz_w2sz (Tensor packed_weights , Tensor n , Tensor k, Tensor group_size, Tensor activations ) -> Tensor" );
192313 m.def (
193- " _linear_3bit (Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
314+ " _linear_a8sz_w3sz (Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
194315 m.def (
195- " _pack_weights_4bit (Tensor weight_qvals , Tensor weight_scales , Tensor group_size) -> Tensor" );
316+ " _linear_a8sz_w4sz (Tensor packed_weights , Tensor n , Tensor k, Tensor group_size, Tensor activations ) -> Tensor" );
196317 m.def (
197- " _linear_4bit (Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
318+ " _linear_a8sz_w5sz (Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor" );
198319}
199320
// Binds the CPU implementations to the torchao operator names.
// Pack ops map to pack_weights_without_zeros_cpu<N> (names ending in s) or
// pack_weights_with_zeros_cpu<N> (names ending in sz).
// Linear ops map to linear_cpu<N, has_weight_zeros>, where the second
// template argument is false for the s names and true for the sz names.
// N is the weight bit width encoded in the operator name (2 through 5).
// Lines marked with - are the earlier _3bit / _4bit bindings that the lines
// marked with + replace.
200321TORCH_LIBRARY_IMPL (torchao, CPU, m) {
201- m.impl (" _pack_weights_3bit" , &pack_weights_cpu<3 >);
202- m.impl (" _linear_3bit" , &linear_cpu<3 >);
203- m.impl (" _pack_weights_4bit" , &pack_weights_cpu<4 >);
204- m.impl (" _linear_4bit" , &linear_cpu<4 >);
322+ m.impl (" _pack_weights_a8sz_w2s" , &pack_weights_without_zeros_cpu<2 >);
323+ m.impl (" _pack_weights_a8sz_w3s" , &pack_weights_without_zeros_cpu<3 >);
324+ m.impl (" _pack_weights_a8sz_w4s" , &pack_weights_without_zeros_cpu<4 >);
325+ m.impl (" _pack_weights_a8sz_w5s" , &pack_weights_without_zeros_cpu<5 >);
326+ m.impl (" _pack_weights_a8sz_w2sz" , &pack_weights_with_zeros_cpu<2 >);
327+ m.impl (" _pack_weights_a8sz_w3sz" , &pack_weights_with_zeros_cpu<3 >);
328+ m.impl (" _pack_weights_a8sz_w4sz" , &pack_weights_with_zeros_cpu<4 >);
329+ m.impl (" _pack_weights_a8sz_w5sz" , &pack_weights_with_zeros_cpu<5 >);
330+ m.impl (" _linear_a8sz_w2s" , &linear_cpu<2 , false >);
331+ m.impl (" _linear_a8sz_w3s" , &linear_cpu<3 , false >);
332+ m.impl (" _linear_a8sz_w4s" , &linear_cpu<4 , false >);
333+ m.impl (" _linear_a8sz_w5s" , &linear_cpu<5 , false >);
334+ m.impl (" _linear_a8sz_w2sz" , &linear_cpu<2 , true >);
335+ m.impl (" _linear_a8sz_w3sz" , &linear_cpu<3 , true >);
336+ m.impl (" _linear_a8sz_w4sz" , &linear_cpu<4 , true >);
337+ m.impl (" _linear_a8sz_w5sz" , &linear_cpu<5 , true >);
216205338}
206339
// Binds the Meta-dispatch implementations to the torchao operator names.
// The mapping mirrors the CPU registration: pack ops map to
// pack_weights_without_zeros_meta<N> (names ending in s) or
// pack_weights_with_zeros_meta<N> (names ending in sz), and linear ops map to
// linear_meta<N, has_weight_zeros> with false for s names and true for sz
// names. N is the weight bit width encoded in the operator name (2 through 5).
// Lines marked with - are the earlier _3bit / _4bit bindings that the lines
// marked with + replace.
207340TORCH_LIBRARY_IMPL (torchao, Meta, m) {
208- m.impl (" _pack_weights_3bit" , &pack_weights_meta<3 >);
209- m.impl (" _linear_3bit" , &linear_meta<3 >);
210- m.impl (" _pack_weights_4bit" , &pack_weights_meta<4 >);
211- m.impl (" _linear_4bit" , &linear_meta<4 >);
341+ m.impl (" _pack_weights_a8sz_w2s" , &pack_weights_without_zeros_meta<2 >);
342+ m.impl (" _pack_weights_a8sz_w3s" , &pack_weights_without_zeros_meta<3 >);
343+ m.impl (" _pack_weights_a8sz_w4s" , &pack_weights_without_zeros_meta<4 >);
344+ m.impl (" _pack_weights_a8sz_w5s" , &pack_weights_without_zeros_meta<5 >);
345+ m.impl (" _pack_weights_a8sz_w2sz" , &pack_weights_with_zeros_meta<2 >);
346+ m.impl (" _pack_weights_a8sz_w3sz" , &pack_weights_with_zeros_meta<3 >);
347+ m.impl (" _pack_weights_a8sz_w4sz" , &pack_weights_with_zeros_meta<4 >);
348+ m.impl (" _pack_weights_a8sz_w5sz" , &pack_weights_with_zeros_meta<5 >);
349+ m.impl (" _linear_a8sz_w2s" , &linear_meta<2 , false >);
350+ m.impl (" _linear_a8sz_w3s" , &linear_meta<3 , false >);
351+ m.impl (" _linear_a8sz_w4s" , &linear_meta<4 , false >);
352+ m.impl (" _linear_a8sz_w5s" , &linear_meta<5 , false >);
353+ m.impl (" _linear_a8sz_w2sz" , &linear_meta<2 , true >);
354+ m.impl (" _linear_a8sz_w3sz" , &linear_meta<3 , true >);
355+ m.impl (" _linear_a8sz_w4sz" , &linear_meta<4 , true >);
356+ m.impl (" _linear_a8sz_w5sz" , &linear_meta<5 , true >);
239212357}
0 commit comments