Skip to content

Commit 20a5673

Browse files
manuelcandales authored and facebook-github-bot committed
cmake torchao_ops_mps_linear_fp_act_xbit_weight (#1304)
Summary: Pull Request resolved: #1304. Move from setup.py to cmake for building custom torchao mps ops. Differential Revision: D66120124
1 parent b714026 commit 20a5673

File tree

7 files changed

+122
-45
lines changed

7 files changed

+122
-45
lines changed

torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from typing import Optional
28
import os
9+
import sys
310
import yaml
411

5-
torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
6-
assert torchao_root is not None, "TORCHAO_ROOT is not set"
12+
if len(sys.argv) != 2:
13+
print("Usage: gen_metal_shader_lib.py <output_file>")
14+
sys.exit(1)
15+
16+
# Output file where the generated code will be written
17+
OUTPUT_FILE = sys.argv[1]
718

8-
MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")
19+
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
920

1021
# Path to yaml file containing the list of .metal files to include
1122
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
@@ -21,9 +32,6 @@
2132
# Path to the folder containing the .metal files
2233
METAL_DIR = os.path.join(MPS_DIR, "metal")
2334

24-
# Output file where the generated code will be written
25-
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")
26-
2735
prefix = """/**
2836
* This file is generated by gen_metal_shader_lib.py
2937
*/
torchao/experimental/ops/mps/CMakeLists.txt

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
cmake_minimum_required(VERSION 3.19)
8+
9+
project(torchao_ops_mps_linear_fp_act_xbit_weight)
10+
11+
set(CMAKE_CXX_STANDARD 17)
12+
set(CMAKE_CXX_STANDARD_REQUIRED YES)
13+
14+
if (NOT CMAKE_BUILD_TYPE)
15+
set(CMAKE_BUILD_TYPE Release)
16+
endif()
17+
18+
if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
19+
if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
20+
# Abort configuration when the Darwin host processor is not arm64.
message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture")
21+
endif()
22+
else()
23+
message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
24+
endif()
25+
26+
find_package(Torch REQUIRED)
27+
28+
if(NOT TORCHAO_INCLUDE_DIRS)
29+
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
30+
endif()
31+
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
32+
33+
include_directories(${TORCHAO_INCLUDE_DIRS})
34+
include_directories(${CMAKE_INSTALL_PREFIX}/include)
35+
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
36+
37+
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
38+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
39+
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)
40+
41+
# Enable Metal support
42+
# REQUIRED (CMake >= 3.18) stops at configure time if the Metal framework is
# missing, instead of passing METAL_LIB-NOTFOUND on to target_link_libraries.
find_library(METAL_LIB Metal REQUIRED)
43+
# REQUIRED (CMake >= 3.18) stops at configure time if the Foundation framework
# is missing, instead of passing FOUNDATION_LIB-NOTFOUND on to target_link_libraries.
find_library(FOUNDATION_LIB Foundation REQUIRED)
44+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
45+
46+
install(
47+
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
48+
EXPORT _targets
49+
DESTINATION lib
50+
)

torchao/experimental/ops/mps/register.mm renamed to torchao/experimental/ops/mps/aten/register.mm

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// LICENSE file in the root directory of this source tree.
66

77
// clang-format off
8-
#include <torch/extension.h>
8+
#include <torch/library.h>
99
#include <ATen/native/mps/OperationUtils.h>
1010
#include <torchao/experimental/kernels/mps/src/lowbit.h>
1111
// clang-format on
@@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
147147
return B;
148148
}
149149

150-
// Registers _C as a Python extension module.
151-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
152-
153150
TORCH_LIBRARY(torchao, m) {
154151
m.def("_pack_weight_1bit(Tensor W) -> Tensor");
155152
m.def("_pack_weight_2bit(Tensor W) -> Tensor");
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/bin/bash -eu
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
cd "$(dirname "$BASH_SOURCE")"
9+
10+
# Point CMake at the active interpreter's site-packages directory.
# Uses sysconfig because distutils is deprecated (PEP 632) and removed in Python 3.12.
export CMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')
11+
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
12+
export CMAKE_OUT=$(python -c "import sys; print(sys.prefix)")/torchao_mps/cmake-out
13+
echo "CMAKE_OUT: ${CMAKE_OUT}"
14+
15+
export INCLUDE_PATH=${CMAKE_OUT}/include
16+
mkdir -p ${INCLUDE_PATH}/torchao/experimental/kernels/mps/src/
17+
export GENERATED_METAL_SHADER_LIB=${INCLUDE_PATH}/torchao/experimental/kernels/mps/src/metal_shader_lib.h
18+
python ../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
19+
echo "GENERATED_METAL_SHADER_LIB: ${GENERATED_METAL_SHADER_LIB}"
20+
21+
cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
22+
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
23+
-S . \
24+
-B ${CMAKE_OUT}
25+
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release

torchao/experimental/ops/mps/setup.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

torchao/experimental/ops/mps/test/test_lowbit.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,31 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
8+
import sys
79
import torch
8-
import torchao_mps_ops
910
import unittest
1011

12+
from parameterized import parameterized
1113

12-
def parameterized(test_cases):
13-
def decorator(func):
14-
def wrapper(self):
15-
for case in test_cases:
16-
with self.subTest(case=case):
17-
func(self, *case)
14+
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
15+
libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib/", libname)
1816

19-
return wrapper
17+
try:
18+
torch.ops.load_library(libpath)
19+
except:
20+
print(f"Failed to load library {libpath}")
21+
raise
2022

21-
return decorator
23+
for nbit in range(1, 8):
24+
op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
25+
assert op is not None
26+
op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
27+
assert op is not None
2228

2329

2430
class TestLowBitQuantWeightsLinear(unittest.TestCase):
25-
cases = [
31+
CASES = [
2632
(nbit, *param)
2733
for nbit in range(1, 8)
2834
for param in [
@@ -73,7 +79,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
7379
W = scales * W + zeros
7480
return torch.mm(A, W.t())
7581

76-
@parameterized(cases)
82+
@parameterized.expand(CASES)
7783
def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
7884
print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
7985
A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,27 @@
1111
import sys
1212

1313
import torch
14-
import torchao_mps_ops
1514
import unittest
1615

1716
from parameterized import parameterized
1817
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
1918
from torchao.experimental.quant_api import _quantize
2019

20+
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
21+
libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib/", libname)
22+
23+
try:
24+
torch.ops.load_library(libpath)
25+
except:
26+
print(f"Failed to load library {libpath}")
27+
raise
28+
29+
for nbit in range(1, 8):
30+
op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
31+
assert op is not None
32+
op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
33+
assert op is not None
34+
2135

2236
class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
2337
BITWIDTHS = range(1, 8)

0 commit comments

Comments
 (0)