Skip to content

Commit 9e2bda1

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
cmake torchao_ops_mps_linear_fp_act_xbit_weight
Summary: Move from setup.py to cmake for building custom torchao mps ops Differential Revision: D66120124
1 parent 20b08ee commit 9e2bda1

File tree

5 files changed

+74
-28
lines changed

5 files changed

+74
-28
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
cmake_minimum_required(VERSION 3.19)
8+
9+
project(torchao_ops_mps_linear_fp_act_xbit_weight)
10+
11+
set(CMAKE_CXX_STANDARD 17)
12+
set(CMAKE_CXX_STANDARD_REQUIRED YES)
13+
14+
if (NOT CMAKE_BUILD_TYPE)
15+
set(CMAKE_BUILD_TYPE Release)
16+
endif()
17+
18+
find_package(Torch REQUIRED)
19+
20+
if(NOT TORCHAO_INCLUDE_DIRS)
21+
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
22+
endif()
23+
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
24+
25+
include_directories(${TORCHAO_INCLUDE_DIRS})
26+
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
27+
28+
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
29+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
30+
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)
31+
32+
# Enable Metal support
33+
if (APPLE)
34+
find_library(METAL_LIB Metal)
35+
find_library(FOUNDATION_LIB Foundation)
36+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
37+
else()
38+
message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
39+
endif()
40+
41+
install(
42+
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
43+
EXPORT _targets
44+
DESTINATION lib
45+
)

torchao/experimental/ops/mps/register.mm renamed to torchao/experimental/ops/mps/aten/register.mm

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// LICENSE file in the root directory of this source tree.
66

77
// clang-format off
8-
#include <torch/extension.h>
8+
#include <torch/library.h>
99
#include <ATen/native/mps/OperationUtils.h>
1010
#include <torchao/experimental/kernels/mps/src/lowbit.h>
1111
// clang-format on
@@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
147147
return B;
148148
}
149149

150-
// Registers _C as a Python extension module.
151-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
152-
153150
TORCH_LIBRARY(torchao, m) {
154151
m.def("_pack_weight_1bit(Tensor W) -> Tensor");
155152
m.def("_pack_weight_2bit(Tensor W) -> Tensor");
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/bin/bash -eu
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
python ../../kernels/mps/codegen/gen_metal_shader_lib.py
9+
10+
export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
11+
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
12+
export CMAKE_OUT=cmake-out
13+
14+
cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
15+
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
16+
-S . \
17+
-B ${CMAKE_OUT}
18+
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release
19+
20+
rm ../../kernels/mps/src/metal_shader_lib.h

torchao/experimental/ops/mps/setup.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

torchao/experimental/ops/mps/test/test_lowbit.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
78
import torch
8-
import torchao_mps_ops
99
import unittest
1010

11+
path_libtorchao_ops_mps_aten = os.path.abspath(
12+
os.path.join(
13+
os.path.dirname(__file__), "../cmake-out/lib/libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
14+
)
15+
)
16+
torch.ops.load_library(path_libtorchao_ops_mps_aten)
17+
1118

1219
def parameterized(test_cases):
1320
def decorator(func):

0 commit comments

Comments
 (0)