Skip to content

Commit 16c85e8

Browse files
metascroyfacebook-github-bot
authored andcommitted
torchao setup.py with cmake (#1490)
Summary: Initial draft of using cmake in torchao's build process. Install torchao with: ``` USE_CPP=1 pip install . ``` If on an arm64 machine, it builds the dynamic library for torchao at site-packages/torchao/libtorchao_ops_aten.dylib. On import of torchao, if this library is found it is loaded. Reviewed By: kimishpatel, drisspg Differential Revision: D67777662
1 parent 982141b commit 16c85e8

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

setup.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import subprocess
99
from datetime import datetime
1010

11-
from setuptools import find_packages, setup
11+
from setuptools import find_packages, setup, Extension
1212

1313
current_date = datetime.now().strftime("%Y%m%d")
1414

@@ -41,6 +41,13 @@ def read_version(file_path="version.txt"):
4141

4242
use_cpp = os.getenv("USE_CPP")
4343

44+
import platform
45+
build_torchao_experimental = (
46+
use_cpp == "1" and
47+
platform.machine().startswith("arm64") and
48+
platform.system() == "Darwin"
49+
)
50+
4451
version_prefix = read_version()
4552
# Version is version.dev year month date if using nightlies and version if not
4653
version = (
@@ -49,6 +56,9 @@ def read_version(file_path="version.txt"):
4956
else version_prefix
5057
)
5158

59+
def use_debug_mode():
60+
return os.getenv('DEBUG', '0') == '1'
61+
5262
import torch
5363
from torch.utils.cpp_extension import (
5464
CUDA_HOME,
@@ -59,8 +69,51 @@ def read_version(file_path="version.txt"):
5969
)
6070

6171

72+
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
73+
class TorchAOBuildExt(BuildExtension):
74+
def __init__(self, *args, **kwargs) -> None:
75+
super().__init__(*args, **kwargs)
76+
77+
def build_extensions(self):
78+
cmake_extensions = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
79+
other_extensions = [ext for ext in self.extensions if not isinstance(ext, CMakeExtension)]
80+
for ext in cmake_extensions:
81+
self.build_cmake(ext)
82+
83+
# Use BuildExtension to build other extensions
84+
self.extensions = other_extensions
85+
super().build_extensions()
86+
87+
self.extensions = other_extensions + cmake_extensions
88+
89+
def build_cmake(self, ext):
90+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
91+
92+
build_type = "Debug" if use_debug_mode() else "Release"
93+
94+
from distutils.sysconfig import get_python_lib
95+
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
96+
97+
if not os.path.exists(self.build_temp):
98+
os.makedirs(self.build_temp)
99+
100+
subprocess.check_call(
101+
['cmake', ext.sourcedir, '-DCMAKE_BUILD_TYPE=' + build_type, '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF', '-DTorch_DIR=' + torch_dir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
102+
cwd=self.build_temp
103+
)
104+
subprocess.check_call(
105+
['cmake', '--build', '.'],
106+
cwd=self.build_temp
107+
)
108+
109+
class CMakeExtension(Extension):
110+
def __init__(self, name, sourcedir=''):
111+
Extension.__init__(self, name, sources=[])
112+
self.sourcedir = os.path.abspath(sourcedir)
113+
114+
62115
def get_extensions():
63-
debug_mode = os.getenv("DEBUG", "0") == "1"
116+
debug_mode = use_debug_mode()
64117
if debug_mode:
65118
print("Compiling in debug mode")
66119

@@ -129,18 +182,25 @@ def get_extensions():
129182
if use_cuda:
130183
sources += cuda_sources
131184

132-
if len(sources) == 0:
133-
return None
134-
135-
ext_modules = [
185+
ext_modules = []
186+
if len(sources) > 0:
187+
ext_modules.append(
136188
extension(
137189
"torchao._C",
138190
sources,
139191
py_limited_api=True,
140192
extra_compile_args=extra_compile_args,
141193
extra_link_args=extra_link_args,
142194
)
143-
]
195+
)
196+
197+
if build_torchao_experimental:
198+
ext_modules.append(
199+
CMakeExtension(
200+
"torchao.experimental",
201+
sourcedir="torchao/experimental",
202+
)
203+
)
144204

145205
return ext_modules
146206

@@ -159,6 +219,6 @@ def get_extensions():
159219
long_description=open("README.md").read(),
160220
long_description_content_type="text/markdown",
161221
url="https://github.com/pytorch/ao",
162-
cmdclass={"build_ext": BuildExtension},
222+
cmdclass={"build_ext": TorchAOBuildExt},
163223
options={"bdist_wheel": {"py_limited_api": "cp39"}},
164224
)

torchao/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@
3232
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
3333
torch.ops.load_library(so_files[0])
3434
from . import ops
35+
36+
# The following library contains CPU kernels from torchao/experimental
37+
# They are built automatically by ao/setup.py if on an ARM machine.
38+
# They can also be built outside of the torchao install process by
39+
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
40+
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
41+
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
42+
if len(experimental_lib) > 0:
43+
assert len(experimental_lib) == 1, f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
44+
torch.ops.load_library(experimental_lib[0])
3545
except:
3646
logging.debug("Skipping import of cpp extensions")
3747

0 commit comments

Comments
 (0)