diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 521f2a5718..fdee217434 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -116,6 +116,8 @@ if(TORCHAO_BUILD_ATEN_OPS) ops/embedding_xbit/op_embedding_xbit_aten.cpp ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp + ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp + ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp ) list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten}) @@ -161,7 +163,9 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS) ops/embedding_xbit/op_embedding_xbit_executorch.cpp ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp - ) + ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp + ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp) + list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch}) target_link_torchao_parallel_backend(torchao_ops_executorch executorch) diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp b/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp new file mode 100644 index 0000000000..06046a4ce9 --- /dev/null +++ b/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp @@ -0,0 +1,80 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#define DEFINE_PACK_OP(weight_nbit) \ + m.def( \ + "_pack_groupwise_" #weight_nbit \ + "bit_weight_with_lut(Tensor weight_qval_idxs, Tensor luts, int scale_group_size, int lut_group_size, Tensor? weight_scales, Tensor? bias, str? target) -> Tensor"); + +#define DEFINE_LINEAR_OP(weight_nbit) \ + m.def( \ + "_linear_groupwise_" #weight_nbit \ + "bit_weight_with_lut(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k) -> Tensor"); \ + m.def( \ + "_linear_groupwise_" #weight_nbit \ + "bit_weight_with_lut.out(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k, *, Tensor(a!) out) -> Tensor(a!)"); + +#define DEFINE_PACK_CPU_IMPL(weight_nbit) \ + m.impl( \ + "_pack_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &pack_weights_with_lut_cpu); + +#define DEFINE_PACK_META_IMPL(weight_nbit) \ + m.impl( \ + "_pack_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &pack_weights_with_lut_meta); + +#define DEFINE_LINEAR_CPU_IMPL(weight_nbit) \ + m.impl( \ + "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &linear_cpu); \ + m.impl( \ + "_linear_groupwise_" #weight_nbit "bit_weight_with_lut.out", \ + &linear_out_cpu); + +#define DEFINE_LINEAR_META_IMPL(weight_nbit) \ + m.impl( \ + "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &linear_meta); \ + + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + DEFINE_PACK_OP(1); + DEFINE_PACK_OP(2); + DEFINE_PACK_OP(3); + DEFINE_PACK_OP(4); + + DEFINE_LINEAR_OP(1); + DEFINE_LINEAR_OP(2); + DEFINE_LINEAR_OP(3); + DEFINE_LINEAR_OP(4); +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + DEFINE_PACK_CPU_IMPL(1); + DEFINE_PACK_CPU_IMPL(2); + DEFINE_PACK_CPU_IMPL(3); + DEFINE_PACK_CPU_IMPL(4); + + DEFINE_LINEAR_CPU_IMPL(1); + DEFINE_LINEAR_CPU_IMPL(2); + DEFINE_LINEAR_CPU_IMPL(3); + DEFINE_LINEAR_CPU_IMPL(4); +} + +TORCH_LIBRARY_IMPL(torchao, Meta, m) { + DEFINE_PACK_META_IMPL(1); + DEFINE_PACK_META_IMPL(2); + DEFINE_PACK_META_IMPL(3); + DEFINE_PACK_META_IMPL(4); + + DEFINE_LINEAR_META_IMPL(1); + DEFINE_LINEAR_META_IMPL(2); + DEFINE_LINEAR_META_IMPL(3); + DEFINE_LINEAR_META_IMPL(4); +}