Skip to content

Commit 6db1789

Browse files
metascroyfacebook-github-bot
authored andcommitted
torchao setup.py with cmake (#1490)
Summary: Initial draft of using cmake in torchao's build process. Install torchao with: ``` USE_CPP=1 pip install . ``` If on an arm64 machine, it builds the dynamic library for torchao at site-packages/torchao/libtorchao_ops_aten.dylib. On import of torchao, if this library is found it is loaded. Differential Revision: D67777662
1 parent 1be4307 commit 6db1789

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

setup.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from datetime import datetime
99
import subprocess
1010

11-
from setuptools import find_packages, setup
11+
from setuptools import find_packages, setup, Extension
1212

1313
current_date = datetime.now().strftime("%Y%m%d")
1414

@@ -35,10 +35,19 @@ def read_version(file_path="version.txt"):
3535

3636
use_cpp = os.getenv('USE_CPP')
3737

38+
import platform
39+
build_torchao_experimental = (
40+
use_cpp == "1" and
41+
platform.machine().startswith("arm64")
42+
)
43+
3844
version_prefix = read_version()
3945
# Version is version.dev year month date if using nightlies and version if not
4046
version = f"{version_prefix}.dev{current_date}" if os.environ.get("TORCHAO_NIGHTLY") else version_prefix
4147

48+
def use_debug_mode():
49+
return os.getenv('DEBUG', '0') == '1'
50+
4251
import torch
4352

4453
from torch.utils.cpp_extension import (
@@ -50,8 +59,48 @@ def read_version(file_path="version.txt"):
5059
)
5160

5261

62+
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
63+
class TorchAOBuildExt(BuildExtension):
64+
def build_extensions(self):
65+
cmake_extensions = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
66+
other_extensions = [ext for ext in self.extensions if not isinstance(ext, CMakeExtension)]
67+
68+
for ext in cmake_extensions:
69+
self.build_cmake(ext)
70+
for ext in other_extensions:
71+
self.build_other(ext)
72+
73+
def build_cmake(self, ext):
74+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
75+
76+
build_type = "Debug" if use_debug_mode() else "Release"
77+
78+
from distutils.sysconfig import get_python_lib
79+
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
80+
81+
if not os.path.exists(self.build_temp):
82+
os.makedirs(self.build_temp)
83+
84+
subprocess.check_call(
85+
['cmake', ext.sourcedir, '-DCMAKE_BUILD_TYPE=' + build_type, '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF', '-DTorch_DIR=' + torch_dir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
86+
cwd=self.build_temp
87+
)
88+
subprocess.check_call(
89+
['cmake', '--build', '.'],
90+
cwd=self.build_temp
91+
)
92+
93+
def build_other(self, ext):
94+
super().build_extension(ext)
95+
96+
class CMakeExtension(Extension):
97+
def __init__(self, name, sourcedir=''):
98+
Extension.__init__(self, name, sources=[])
99+
self.sourcedir = os.path.abspath(sourcedir)
100+
101+
53102
def get_extensions():
54-
debug_mode = os.getenv('DEBUG', '0') == '1'
103+
debug_mode = use_debug_mode()
55104
if debug_mode:
56105
print("Compiling in debug mode")
57106

@@ -115,18 +164,25 @@ def get_extensions():
115164
if use_cuda:
116165
sources += cuda_sources
117166

118-
if len(sources) == 0:
119-
return None
120-
121-
ext_modules = [
167+
ext_modules = []
168+
if len(sources) > 0:
169+
ext_modules.append(
122170
extension(
123171
"torchao._C",
124172
sources,
125173
py_limited_api=True,
126174
extra_compile_args=extra_compile_args,
127175
extra_link_args=extra_link_args,
128176
)
129-
]
177+
)
178+
179+
if build_torchao_experimental:
180+
ext_modules.append(
181+
CMakeExtension(
182+
"torchao.experimental",
183+
sourcedir="torchao/experimental",
184+
)
185+
)
130186

131187
return ext_modules
132188

@@ -145,7 +201,7 @@ def get_extensions():
145201
long_description=open("README.md").read(),
146202
long_description_content_type="text/markdown",
147203
url="https://github.com/pytorch/ao",
148-
cmdclass={"build_ext": BuildExtension},
204+
cmdclass={"build_ext": TorchAOBuildExt},
149205
options={"bdist_wheel": {
150206
"py_limited_api": "cp39"
151207
}},

torchao/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
2828
torch.ops.load_library(so_files[0])
2929
from . import ops
30+
31+
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
32+
if len(experimental_lib) > 0:
33+
assert len(experimental_lib) == 1, f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
34+
torch.ops.load_library(experimental_lib[0])
3035
except:
3136
logging.debug("Skipping import of cpp extensions")
3237

0 commit comments

Comments
 (0)