Skip to content

Commit e4f77da

Browse files
metascroyfacebook-github-bot
authored andcommitted
torchao setup.py with cmake
Summary: Initial draft of using cmake in torchao's build process. Install torchao with: ``` USE_CPP=1 pip install . ``` If on an arm64 machine, it builds the dynamic library for torchao at site-packages/torchao/libtorchao_ops_aten.dylib. On import of torchao, if this library is found it is loaded. Differential Revision: D67777662
1 parent c59bce5 commit e4f77da

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

setup.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from datetime import datetime
99
import subprocess
1010

11-
from setuptools import find_packages, setup
11+
from setuptools import find_packages, setup, Extension
1212

1313
current_date = datetime.now().strftime("%Y%m%d")
1414

@@ -35,6 +35,12 @@ def read_version(file_path="version.txt"):
3535

3636
use_cpp = os.getenv('USE_CPP')
3737

38+
import platform
39+
build_torchao_experimental = (
40+
use_cpp == "1" and
41+
platform.machine().startswith("arm64")
42+
)
43+
3844
version_prefix = read_version()
3945
# Version is version.dev year month date if using nightlies and version if not
4046
version = f"{version_prefix}.dev{current_date}" if os.environ.get("TORCHAO_NIGHTLY") else version_prefix
@@ -50,6 +56,44 @@ def read_version(file_path="version.txt"):
5056
)
5157

5258

59+
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
60+
class TorchAOBuildExt(BuildExtension):
61+
def build_extensions(self):
62+
cmake_extensions = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
63+
other_extensions = [ext for ext in self.extensions if not isinstance(ext, CMakeExtension)]
64+
65+
for ext in cmake_extensions:
66+
self.build_cmake(ext)
67+
for ext in other_extensions:
68+
self.build_other(ext)
69+
70+
def build_cmake(self, ext):
71+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
72+
73+
from distutils.sysconfig import get_python_lib
74+
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
75+
76+
if not os.path.exists(self.build_temp):
77+
os.makedirs(self.build_temp)
78+
79+
subprocess.check_call(
80+
['cmake', ext.sourcedir, '-DCMAKE_BUILD_TYPE=Release', '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF', '-DTorch_DIR=' + torch_dir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
81+
cwd=self.build_temp
82+
)
83+
subprocess.check_call(
84+
['cmake', '--build', '.'],
85+
cwd=self.build_temp
86+
)
87+
88+
def build_other(self, ext):
89+
super().build_extension(ext)
90+
91+
class CMakeExtension(Extension):
92+
def __init__(self, name, sourcedir=''):
93+
Extension.__init__(self, name, sources=[])
94+
self.sourcedir = os.path.abspath(sourcedir)
95+
96+
5397
def get_extensions():
5498
debug_mode = os.getenv('DEBUG', '0') == '1'
5599
if debug_mode:
@@ -103,18 +147,25 @@ def get_extensions():
103147
if use_cuda:
104148
sources += cuda_sources
105149

106-
if len(sources) == 0:
107-
return None
108-
109-
ext_modules = [
150+
ext_modules = []
151+
if len(sources) > 0:
152+
ext_modules.append(
110153
extension(
111154
"torchao._C",
112155
sources,
113156
py_limited_api=True,
114157
extra_compile_args=extra_compile_args,
115158
extra_link_args=extra_link_args,
116159
)
117-
]
160+
)
161+
162+
if build_torchao_experimental:
163+
ext_modules.append(
164+
CMakeExtension(
165+
"torchao.experimental",
166+
sourcedir="torchao/experimental",
167+
)
168+
)
118169

119170
return ext_modules
120171

@@ -133,7 +184,7 @@ def get_extensions():
133184
long_description=open("README.md").read(),
134185
long_description_content_type="text/markdown",
135186
url="https://github.com/pytorch/ao",
136-
cmdclass={"build_ext": BuildExtension},
187+
cmdclass={"build_ext": TorchAOBuildExt},
137188
options={"bdist_wheel": {
138189
"py_limited_api": "cp39"
139190
}},

torchao/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
2828
torch.ops.load_library(so_files[0])
2929
from . import ops
30+
31+
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
32+
if len(experimental_lib) > 0:
33+
assert len(so_files) == 1, f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
34+
torch.ops.load_library(experimental_lib[0])
3035
except:
3136
logging.debug("Skipping import of cpp extensions")
3237

0 commit comments

Comments
 (0)