88from datetime import datetime
99import subprocess
1010
11- from setuptools import find_packages , setup
11+ from setuptools import find_packages , setup , Extension
1212
1313current_date = datetime .now ().strftime ("%Y%m%d" )
1414
@@ -35,6 +35,12 @@ def read_version(file_path="version.txt"):
3535
3636use_cpp = os .getenv ('USE_CPP' )
3737
38+ import platform
39+ build_torchao_experimental = (
40+ use_cpp == "1" and
41+ platform .machine ().startswith ("arm64" )
42+ )
43+
3844version_prefix = read_version ()
3945# Version is version.dev year month date if using nightlies and version if not
4046version = f"{ version_prefix } .dev{ current_date } " if os .environ .get ("TORCHAO_NIGHTLY" ) else version_prefix
@@ -50,6 +56,44 @@ def read_version(file_path="version.txt"):
5056)
5157
5258
59+ # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
60+ class TorchAOBuildExt (BuildExtension ):
61+ def build_extensions (self ):
62+ cmake_extensions = [ext for ext in self .extensions if isinstance (ext , CMakeExtension )]
63+ other_extensions = [ext for ext in self .extensions if not isinstance (ext , CMakeExtension )]
64+
65+ for ext in cmake_extensions :
66+ self .build_cmake (ext )
67+ for ext in other_extensions :
68+ self .build_other (ext )
69+
70+ def build_cmake (self , ext ):
71+ extdir = os .path .abspath (os .path .dirname (self .get_ext_fullpath (ext .name )))
72+
73+ from distutils .sysconfig import get_python_lib
74+ torch_dir = get_python_lib () + "/torch/share/cmake/Torch"
75+
76+ if not os .path .exists (self .build_temp ):
77+ os .makedirs (self .build_temp )
78+
79+ subprocess .check_call (
80+ ['cmake' , ext .sourcedir , '-DCMAKE_BUILD_TYPE=Release' , '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF' , '-DTorch_DIR=' + torch_dir , '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir ],
81+ cwd = self .build_temp
82+ )
83+ subprocess .check_call (
84+ ['cmake' , '--build' , '.' ],
85+ cwd = self .build_temp
86+ )
87+
88+ def build_other (self , ext ):
89+ super ().build_extension (ext )
90+
91+ class CMakeExtension (Extension ):
92+ def __init__ (self , name , sourcedir = '' ):
93+ Extension .__init__ (self , name , sources = [])
94+ self .sourcedir = os .path .abspath (sourcedir )
95+
96+
5397def get_extensions ():
5498 debug_mode = os .getenv ('DEBUG' , '0' ) == '1'
5599 if debug_mode :
@@ -103,18 +147,25 @@ def get_extensions():
103147 if use_cuda :
104148 sources += cuda_sources
105149
106- if len (sources ) == 0 :
107- return None
108-
109- ext_modules = [
150+ ext_modules = []
151+ if len (sources ) > 0 :
152+ ext_modules .append (
110153 extension (
111154 "torchao._C" ,
112155 sources ,
113156 py_limited_api = True ,
114157 extra_compile_args = extra_compile_args ,
115158 extra_link_args = extra_link_args ,
116159 )
117- ]
160+ )
161+
162+ if build_torchao_experimental :
163+ ext_modules .append (
164+ CMakeExtension (
165+ "torchao.experimental" ,
166+ sourcedir = "torchao/experimental" ,
167+ )
168+ )
118169
119170 return ext_modules
120171
@@ -133,7 +184,7 @@ def get_extensions():
133184 long_description = open ("README.md" ).read (),
134185 long_description_content_type = "text/markdown" ,
135186 url = "https://github.com/pytorch/ao" ,
136- cmdclass = {"build_ext" : BuildExtension },
187+ cmdclass = {"build_ext" : TorchAOBuildExt },
137188 options = {"bdist_wheel" : {
138189 "py_limited_api" : "cp39"
139190 }},
0 commit comments