88from datetime import datetime
99import subprocess
1010
11- from setuptools import find_packages , setup
11+ from setuptools import find_packages , setup , Extension
1212
1313current_date = datetime .now ().strftime ("%Y%m%d" )
1414
@@ -35,10 +35,19 @@ def read_version(file_path="version.txt"):
3535
3636use_cpp = os .getenv ('USE_CPP' )
3737
38+ import platform
39+ build_torchao_experimental = (
40+ use_cpp == "1" and
41+ platform .machine ().startswith ("arm64" )
42+ )
43+
3844version_prefix = read_version ()
3945# Version is version.dev year month date if using nightlies and version if not
4046version = f"{ version_prefix } .dev{ current_date } " if os .environ .get ("TORCHAO_NIGHTLY" ) else version_prefix
4147
48+ def use_debug_mode ():
49+ return os .getenv ('DEBUG' , '0' ) == '1'
50+
4251import torch
4352
4453from torch .utils .cpp_extension import (
@@ -50,8 +59,48 @@ def read_version(file_path="version.txt"):
5059)
5160
5261
62+ # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
63+ class TorchAOBuildExt (BuildExtension ):
64+ def build_extensions (self ):
65+ cmake_extensions = [ext for ext in self .extensions if isinstance (ext , CMakeExtension )]
66+ other_extensions = [ext for ext in self .extensions if not isinstance (ext , CMakeExtension )]
67+
68+ for ext in cmake_extensions :
69+ self .build_cmake (ext )
70+ for ext in other_extensions :
71+ self .build_other (ext )
72+
73+ def build_cmake (self , ext ):
74+ extdir = os .path .abspath (os .path .dirname (self .get_ext_fullpath (ext .name )))
75+
76+ build_type = "Debug" if use_debug_mode () else "Release"
77+
78+ from distutils .sysconfig import get_python_lib
79+ torch_dir = get_python_lib () + "/torch/share/cmake/Torch"
80+
81+ if not os .path .exists (self .build_temp ):
82+ os .makedirs (self .build_temp )
83+
84+ subprocess .check_call (
85+ ['cmake' , ext .sourcedir , '-DCMAKE_BUILD_TYPE=' + build_type , '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF' , '-DTorch_DIR=' + torch_dir , '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir ],
86+ cwd = self .build_temp
87+ )
88+ subprocess .check_call (
89+ ['cmake' , '--build' , '.' ],
90+ cwd = self .build_temp
91+ )
92+
93+ def build_other (self , ext ):
94+ super ().build_extension (ext )
95+
96+ class CMakeExtension (Extension ):
97+ def __init__ (self , name , sourcedir = '' ):
98+ Extension .__init__ (self , name , sources = [])
99+ self .sourcedir = os .path .abspath (sourcedir )
100+
101+
53102def get_extensions ():
54- debug_mode = os . getenv ( 'DEBUG' , '0' ) == '1'
103+ debug_mode = use_debug_mode ()
55104 if debug_mode :
56105 print ("Compiling in debug mode" )
57106
@@ -115,18 +164,25 @@ def get_extensions():
115164 if use_cuda :
116165 sources += cuda_sources
117166
118- if len (sources ) == 0 :
119- return None
120-
121- ext_modules = [
167+ ext_modules = []
168+ if len (sources ) > 0 :
169+ ext_modules .append (
122170 extension (
123171 "torchao._C" ,
124172 sources ,
125173 py_limited_api = True ,
126174 extra_compile_args = extra_compile_args ,
127175 extra_link_args = extra_link_args ,
128176 )
129- ]
177+ )
178+
179+ if build_torchao_experimental :
180+ ext_modules .append (
181+ CMakeExtension (
182+ "torchao.experimental" ,
183+ sourcedir = "torchao/experimental" ,
184+ )
185+ )
130186
131187 return ext_modules
132188
@@ -145,7 +201,7 @@ def get_extensions():
145201 long_description = open ("README.md" ).read (),
146202 long_description_content_type = "text/markdown" ,
147203 url = "https://github.com/pytorch/ao" ,
148- cmdclass = {"build_ext" : BuildExtension },
204+ cmdclass = {"build_ext" : TorchAOBuildExt },
149205 options = {"bdist_wheel" : {
150206 "py_limited_api" : "cp39"
151207 }},
0 commit comments