88from datetime import datetime
99import subprocess
1010
11- from setuptools import find_packages , setup
11+ from setuptools import find_packages , setup , Extension
1212
1313current_date = datetime .now ().strftime ("%Y%m%d" )
1414
@@ -35,10 +35,19 @@ def read_version(file_path="version.txt"):
3535
3636use_cpp = os .getenv ('USE_CPP' )
3737
38+ import platform
39+ build_torchao_experimental = (
40+ use_cpp == "1" and
41+ platform .machine ().startswith ("arm64" )
42+ )
43+
3844version_prefix = read_version ()
3945# Version is version.dev year month date if using nightlies and version if not
4046version = f"{ version_prefix } .dev{ current_date } " if os .environ .get ("TORCHAO_NIGHTLY" ) else version_prefix
4147
48+ def use_debug_mode ():
49+ return os .getenv ('DEBUG' , '0' ) == '1'
50+
4251import torch
4352
4453from torch .utils .cpp_extension import (
@@ -50,8 +59,51 @@ def read_version(file_path="version.txt"):
5059)
5160
5261
62+ # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
63+ class TorchAOBuildExt (BuildExtension ):
64+ def __init__ (self , * args , ** kwargs ) -> None :
65+ super ().__init__ (* args , ** kwargs )
66+
67+ def build_extensions (self ):
68+ cmake_extensions = [ext for ext in self .extensions if isinstance (ext , CMakeExtension )]
69+ other_extensions = [ext for ext in self .extensions if not isinstance (ext , CMakeExtension )]
70+ for ext in cmake_extensions :
71+ self .build_cmake (ext )
72+
73+ # Use BuildExtension to build other extensions
74+ self .extensions = other_extensions
75+ super ().build_extensions ()
76+
77+ self .extensions = other_extensions + cmake_extensions
78+
79+ def build_cmake (self , ext ):
80+ extdir = os .path .abspath (os .path .dirname (self .get_ext_fullpath (ext .name )))
81+
82+ build_type = "Debug" if use_debug_mode () else "Release"
83+
84+ from distutils .sysconfig import get_python_lib
85+ torch_dir = get_python_lib () + "/torch/share/cmake/Torch"
86+
87+ if not os .path .exists (self .build_temp ):
88+ os .makedirs (self .build_temp )
89+
90+ subprocess .check_call (
91+ ['cmake' , ext .sourcedir , '-DCMAKE_BUILD_TYPE=' + build_type , '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF' , '-DTorch_DIR=' + torch_dir , '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir ],
92+ cwd = self .build_temp
93+ )
94+ subprocess .check_call (
95+ ['cmake' , '--build' , '.' ],
96+ cwd = self .build_temp
97+ )
98+
99+ class CMakeExtension (Extension ):
100+ def __init__ (self , name , sourcedir = '' ):
101+ Extension .__init__ (self , name , sources = [])
102+ self .sourcedir = os .path .abspath (sourcedir )
103+
104+
53105def get_extensions ():
54- debug_mode = os . getenv ( 'DEBUG' , '0' ) == '1'
106+ debug_mode = use_debug_mode ()
55107 if debug_mode :
56108 print ("Compiling in debug mode" )
57109
@@ -115,18 +167,25 @@ def get_extensions():
115167 if use_cuda :
116168 sources += cuda_sources
117169
118- if len (sources ) == 0 :
119- return None
120-
121- ext_modules = [
170+ ext_modules = []
171+ if len (sources ) > 0 :
172+ ext_modules .append (
122173 extension (
123174 "torchao._C" ,
124175 sources ,
125176 py_limited_api = True ,
126177 extra_compile_args = extra_compile_args ,
127178 extra_link_args = extra_link_args ,
128179 )
129- ]
180+ )
181+
182+ if build_torchao_experimental :
183+ ext_modules .append (
184+ CMakeExtension (
185+ "torchao.experimental" ,
186+ sourcedir = "torchao/experimental" ,
187+ )
188+ )
130189
131190 return ext_modules
132191
@@ -145,7 +204,7 @@ def get_extensions():
145204 long_description = open ("README.md" ).read (),
146205 long_description_content_type = "text/markdown" ,
147206 url = "https://github.com/pytorch/ao" ,
148- cmdclass = {"build_ext" : BuildExtension },
207+ cmdclass = {"build_ext" : TorchAOBuildExt },
149208 options = {"bdist_wheel" : {
150209 "py_limited_api" : "cp39"
151210 }},
0 commit comments