Skip to content

Commit 8e46608

Browse files
metascroyfacebook-github-bot
authored andcommitted
torchao setup.py with cmake (#1490)
Summary: Initial draft of using cmake in torchao's build process. Install torchao with: ``` USE_CPP=1 pip install . ``` If on an arm64 machine, it builds the dynamic library for torchao at site-packages/torchao/libtorchao_ops_aten.dylib. On import of torchao, if this library is found it is loaded. Reviewed By: kimishpatel, drisspg Differential Revision: D67777662
1 parent cedadc7 commit 8e46608

File tree

2 files changed

+98
-13
lines changed

2 files changed

+98
-13
lines changed

setup.py

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import subprocess
99
from datetime import datetime
1010

11-
from setuptools import find_packages, setup
11+
from setuptools import Extension, find_packages, setup
1212

1313
current_date = datetime.now().strftime("%Y%m%d")
1414

@@ -41,6 +41,14 @@ def read_version(file_path="version.txt"):
4141

4242
use_cpp = os.getenv("USE_CPP")
4343

44+
import platform
45+
46+
build_torchao_experimental = (
47+
use_cpp == "1"
48+
and platform.machine().startswith("arm64")
49+
and platform.system() == "Darwin"
50+
)
51+
4452
version_prefix = read_version()
4553
# Version is version.dev year month date if using nightlies and version if not
4654
version = (
@@ -49,6 +57,11 @@ def read_version(file_path="version.txt"):
4957
else version_prefix
5058
)
5159

60+
61+
def use_debug_mode():
62+
return os.getenv("DEBUG", "0") == "1"
63+
64+
5265
import torch
5366
from torch.utils.cpp_extension import (
5467
CUDA_HOME,
@@ -59,8 +72,61 @@ def read_version(file_path="version.txt"):
5972
)
6073

6174

75+
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
76+
class TorchAOBuildExt(BuildExtension):
77+
def __init__(self, *args, **kwargs) -> None:
78+
super().__init__(*args, **kwargs)
79+
80+
def build_extensions(self):
81+
cmake_extensions = [
82+
ext for ext in self.extensions if isinstance(ext, CMakeExtension)
83+
]
84+
other_extensions = [
85+
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
86+
]
87+
for ext in cmake_extensions:
88+
self.build_cmake(ext)
89+
90+
# Use BuildExtension to build other extensions
91+
self.extensions = other_extensions
92+
super().build_extensions()
93+
94+
self.extensions = other_extensions + cmake_extensions
95+
96+
def build_cmake(self, ext):
97+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
98+
99+
build_type = "Debug" if use_debug_mode() else "Release"
100+
101+
from distutils.sysconfig import get_python_lib
102+
103+
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
104+
105+
if not os.path.exists(self.build_temp):
106+
os.makedirs(self.build_temp)
107+
108+
subprocess.check_call(
109+
[
110+
"cmake",
111+
ext.sourcedir,
112+
"-DCMAKE_BUILD_TYPE=" + build_type,
113+
"-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF",
114+
"-DTorch_DIR=" + torch_dir,
115+
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
116+
],
117+
cwd=self.build_temp,
118+
)
119+
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
120+
121+
122+
class CMakeExtension(Extension):
123+
def __init__(self, name, sourcedir=""):
124+
Extension.__init__(self, name, sources=[])
125+
self.sourcedir = os.path.abspath(sourcedir)
126+
127+
62128
def get_extensions():
63-
debug_mode = os.getenv("DEBUG", "0") == "1"
129+
debug_mode = use_debug_mode()
64130
if debug_mode:
65131
print("Compiling in debug mode")
66132

@@ -129,18 +195,25 @@ def get_extensions():
129195
if use_cuda:
130196
sources += cuda_sources
131197

132-
if len(sources) == 0:
133-
return None
198+
ext_modules = []
199+
if len(sources) > 0:
200+
ext_modules.append(
201+
extension(
202+
"torchao._C",
203+
sources,
204+
py_limited_api=True,
205+
extra_compile_args=extra_compile_args,
206+
extra_link_args=extra_link_args,
207+
)
208+
)
134209

135-
ext_modules = [
136-
extension(
137-
"torchao._C",
138-
sources,
139-
py_limited_api=True,
140-
extra_compile_args=extra_compile_args,
141-
extra_link_args=extra_link_args,
210+
if build_torchao_experimental:
211+
ext_modules.append(
212+
CMakeExtension(
213+
"torchao.experimental",
214+
sourcedir="torchao/experimental",
215+
)
142216
)
143-
]
144217

145218
return ext_modules
146219

@@ -159,6 +232,6 @@ def get_extensions():
159232
long_description=open("README.md").read(),
160233
long_description_content_type="text/markdown",
161234
url="https://github.com/pytorch/ao",
162-
cmdclass={"build_ext": BuildExtension},
235+
cmdclass={"build_ext": TorchAOBuildExt},
163236
options={"bdist_wheel": {"py_limited_api": "cp39"}},
164237
)

torchao/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@
3232
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
3333
torch.ops.load_library(so_files[0])
3434
from . import ops
35+
36+
# The following library contains CPU kernels from torchao/experimental
37+
# They are built automatically by ao/setup.py if on an ARM machine.
38+
# They can also be built outside of the torchao install process by
39+
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
40+
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
41+
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
42+
if len(experimental_lib) > 0:
43+
assert (
44+
len(experimental_lib) == 1
45+
), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
46+
torch.ops.load_library(experimental_lib[0])
3547
except:
3648
logging.debug("Skipping import of cpp extensions")
3749

0 commit comments

Comments
 (0)