Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ if(TORCHAO_BUILD_ATEN_OPS)
ops/embedding_xbit/op_embedding_xbit_aten.cpp
ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp
ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp
)
list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten})
Expand Down Expand Up @@ -161,7 +163,9 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
ops/embedding_xbit/op_embedding_xbit_executorch.cpp
ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp
)
ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp
ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp)

list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch})
target_link_torchao_parallel_backend(torchao_ops_executorch executorch)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h>

// Declares the operator schema `_pack_groupwise_<nbit>bit_weight_with_lut`
// by calling `m.def(...)` with the schema string.
//
// `nbit` is stringified into the operator name, so it must be passed as a
// literal token (e.g. DEFINE_PACK_OP(4)). A variable named `m` that provides
// `def(schema)` must be in scope at the point of expansion.
//
// The schema is written as adjacent string literals, one argument per line;
// the compiler concatenates them into a single schema string.
#define DEFINE_PACK_OP(nbit)                             \
  m.def("_pack_groupwise_" #nbit "bit_weight_with_lut("  \
        "Tensor weight_qval_idxs, "                      \
        "Tensor luts, "                                  \
        "int scale_group_size, "                         \
        "int lut_group_size, "                           \
        "Tensor? weight_scales, "                        \
        "Tensor? bias, "                                 \
        "str? target) -> Tensor");

// Declares two operator schemas for one weight bit width by calling
// `m.def(...)` twice:
//   * `_linear_groupwise_<weight_nbit>bit_weight_with_lut`      (returns a Tensor)
//   * `_linear_groupwise_<weight_nbit>bit_weight_with_lut.out`  (writes into the
//     keyword-only `out` argument and returns it, per the `Tensor(a!)` alias
//     annotation in the schema)
//
// `weight_nbit` is stringified into the operator name, so pass a literal
// token (e.g. DEFINE_LINEAR_OP(4)). A variable named `m` providing
// `def(schema)` must be in scope at the point of expansion.
//
// The two calls are wrapped in `do { ... } while (0)` so the macro expands to
// exactly one statement. Without the wrapper, using the macro as the body of
// an unbraced `if`/`else`/loop would guard only the first `m.def` and run the
// second unconditionally. The invocation supplies the terminating `;`.
#define DEFINE_LINEAR_OP(weight_nbit) \
  do { \
    m.def( \
        "_linear_groupwise_" #weight_nbit \
        "bit_weight_with_lut(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k) -> Tensor"); \
    m.def( \
        "_linear_groupwise_" #weight_nbit \
        "bit_weight_with_lut.out(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k, *, Tensor(a!) out) -> Tensor(a!)"); \
  } while (0)

// Registers `pack_weights_with_lut_cpu<nbit>` as the implementation of
// `_pack_groupwise_<nbit>bit_weight_with_lut` via `m.impl(...)`.
//
// `nbit` is used twice: stringified into the operator name and passed as the
// template argument, so it must be a literal constant (e.g. 4). A variable
// named `m` providing `impl(name, fn)` must be in scope where this expands;
// which backend the registration applies to is determined by that enclosing
// registration block, not by this macro.
#define DEFINE_PACK_CPU_IMPL(nbit)                        \
  m.impl("_pack_groupwise_" #nbit "bit_weight_with_lut",  \
         &pack_weights_with_lut_cpu<nbit>);

// Registers `pack_weights_with_lut_meta<nbit>` as the implementation of
// `_pack_groupwise_<nbit>bit_weight_with_lut` via `m.impl(...)`.
//
// `nbit` is used twice: stringified into the operator name and passed as the
// template argument, so it must be a literal constant (e.g. 4). A variable
// named `m` providing `impl(name, fn)` must be in scope where this expands;
// which backend the registration applies to is determined by that enclosing
// registration block, not by this macro.
#define DEFINE_PACK_META_IMPL(nbit)                       \
  m.impl("_pack_groupwise_" #nbit "bit_weight_with_lut",  \
         &pack_weights_with_lut_meta<nbit>);

// Registers the implementations of both linear operator overloads for one
// weight bit width via `m.impl(...)`:
//   * `_linear_groupwise_<weight_nbit>bit_weight_with_lut`
//       -> `linear_cpu<weight_nbit>`
//   * `_linear_groupwise_<weight_nbit>bit_weight_with_lut.out`
//       -> `linear_out_cpu<weight_nbit>`
//
// `weight_nbit` is both stringified into the operator name and used as a
// template argument, so it must be a literal constant (e.g. 4). A variable
// named `m` providing `impl(name, fn)` must be in scope where this expands.
//
// The two registrations are wrapped in `do { ... } while (0)` so the macro
// expands to exactly one statement. Without the wrapper, using the macro as
// the body of an unbraced `if`/`else`/loop would guard only the first
// `m.impl` and run the second unconditionally. The invocation supplies the
// terminating `;`.
#define DEFINE_LINEAR_CPU_IMPL(weight_nbit) \
  do { \
    m.impl( \
        "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \
        &linear_cpu<weight_nbit>); \
    m.impl( \
        "_linear_groupwise_" #weight_nbit "bit_weight_with_lut.out", \
        &linear_out_cpu<weight_nbit>); \
  } while (0)

// Registers `linear_meta<weight_nbit>` as the implementation of
// `_linear_groupwise_<weight_nbit>bit_weight_with_lut` via `m.impl(...)`.
//
// `weight_nbit` is both stringified into the operator name and used as a
// template argument, so it must be a literal constant (e.g. 4). A variable
// named `m` providing `impl(name, fn)` must be in scope where this expands.
//
// The last line of the definition deliberately has no trailing backslash: a
// dangling line continuation would splice whatever line follows the macro
// into its replacement text.
//
// NOTE(review): unlike DEFINE_LINEAR_CPU_IMPL, this does not register the
// `.out` overload — confirm that is intentional.
#define DEFINE_LINEAR_META_IMPL(weight_nbit) \
  m.impl( \
      "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \
      &linear_meta<weight_nbit>);

// Declares the operator schemas in the `torchao` library namespace for weight
// bit widths 1 through 4. This block only declares schemas (`m.def`); the
// implementations are attached in the TORCH_LIBRARY_IMPL blocks below.
TORCH_LIBRARY_FRAGMENT(torchao, m) {
// Weight-packing op: _pack_groupwise_{1,2,3,4}bit_weight_with_lut
DEFINE_PACK_OP(1);
DEFINE_PACK_OP(2);
DEFINE_PACK_OP(3);
DEFINE_PACK_OP(4);

// Linear op and its `.out` overload:
// _linear_groupwise_{1,2,3,4}bit_weight_with_lut[.out]
DEFINE_LINEAR_OP(1);
DEFINE_LINEAR_OP(2);
DEFINE_LINEAR_OP(3);
DEFINE_LINEAR_OP(4);
}

// Attaches the CPU implementations to the `torchao` operators declared above,
// for weight bit widths 1 through 4.
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
// Packing: pack_weights_with_lut_cpu<N>
DEFINE_PACK_CPU_IMPL(1);
DEFINE_PACK_CPU_IMPL(2);
DEFINE_PACK_CPU_IMPL(3);
DEFINE_PACK_CPU_IMPL(4);

// Linear: linear_cpu<N> for the functional overload and linear_out_cpu<N>
// for the `.out` overload.
DEFINE_LINEAR_CPU_IMPL(1);
DEFINE_LINEAR_CPU_IMPL(2);
DEFINE_LINEAR_CPU_IMPL(3);
DEFINE_LINEAR_CPU_IMPL(4);
}

// Attaches the Meta-key implementations to the `torchao` operators declared
// above, for weight bit widths 1 through 4.
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
// Packing: pack_weights_with_lut_meta<N>
DEFINE_PACK_META_IMPL(1);
DEFINE_PACK_META_IMPL(2);
DEFINE_PACK_META_IMPL(3);
DEFINE_PACK_META_IMPL(4);

// Linear: linear_meta<N> for the functional overload only.
// NOTE(review): no Meta registration is made for the `.out` overload here —
// confirm that is intentional.
DEFINE_LINEAR_META_IMPL(1);
DEFINE_LINEAR_META_IMPL(2);
DEFINE_LINEAR_META_IMPL(3);
DEFINE_LINEAR_META_IMPL(4);
}
Loading