88import subprocess
99from datetime import datetime
1010
11- from setuptools import find_packages , setup
11+ from setuptools import Extension , find_packages , setup
1212
1313current_date = datetime .now ().strftime ("%Y%m%d" )
1414
@@ -41,6 +41,14 @@ def read_version(file_path="version.txt"):
4141
4242use_cpp = os .getenv ("USE_CPP" )
4343
44+ import platform
45+
46+ build_torchao_experimental = (
47+ use_cpp == "1"
48+ and platform .machine ().startswith ("arm64" )
49+ and platform .system () == "Darwin"
50+ )
51+
4452version_prefix = read_version ()
4553# Version is version.dev year month date if using nightlies and version if not
4654version = (
@@ -49,6 +57,11 @@ def read_version(file_path="version.txt"):
4957 else version_prefix
5058)
5159
60+
61+ def use_debug_mode ():
62+ return os .getenv ("DEBUG" , "0" ) == "1"
63+
64+
5265import torch
5366from torch .utils .cpp_extension import (
5467 CUDA_HOME ,
@@ -59,8 +72,61 @@ def read_version(file_path="version.txt"):
5972)
6073
6174
75+ # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
76+ class TorchAOBuildExt (BuildExtension ):
77+ def __init__ (self , * args , ** kwargs ) -> None :
78+ super ().__init__ (* args , ** kwargs )
79+
80+ def build_extensions (self ):
81+ cmake_extensions = [
82+ ext for ext in self .extensions if isinstance (ext , CMakeExtension )
83+ ]
84+ other_extensions = [
85+ ext for ext in self .extensions if not isinstance (ext , CMakeExtension )
86+ ]
87+ for ext in cmake_extensions :
88+ self .build_cmake (ext )
89+
90+ # Use BuildExtension to build other extensions
91+ self .extensions = other_extensions
92+ super ().build_extensions ()
93+
94+ self .extensions = other_extensions + cmake_extensions
95+
96+ def build_cmake (self , ext ):
97+ extdir = os .path .abspath (os .path .dirname (self .get_ext_fullpath (ext .name )))
98+
99+ build_type = "Debug" if use_debug_mode () else "Release"
100+
101+ from distutils .sysconfig import get_python_lib
102+
103+ torch_dir = get_python_lib () + "/torch/share/cmake/Torch"
104+
105+ if not os .path .exists (self .build_temp ):
106+ os .makedirs (self .build_temp )
107+
108+ subprocess .check_call (
109+ [
110+ "cmake" ,
111+ ext .sourcedir ,
112+ "-DCMAKE_BUILD_TYPE=" + build_type ,
113+ "-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF" ,
114+ "-DTorch_DIR=" + torch_dir ,
115+ "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir ,
116+ ],
117+ cwd = self .build_temp ,
118+ )
119+ subprocess .check_call (["cmake" , "--build" , "." ], cwd = self .build_temp )
120+
121+
122+ class CMakeExtension (Extension ):
123+ def __init__ (self , name , sourcedir = "" ):
124+ Extension .__init__ (self , name , sources = [])
125+ self .sourcedir = os .path .abspath (sourcedir )
126+
127+
62128def get_extensions ():
63- debug_mode = os . getenv ( "DEBUG" , "0" ) == "1"
129+ debug_mode = use_debug_mode ()
64130 if debug_mode :
65131 print ("Compiling in debug mode" )
66132
@@ -129,18 +195,25 @@ def get_extensions():
129195 if use_cuda :
130196 sources += cuda_sources
131197
132- if len (sources ) == 0 :
133- return None
198+ ext_modules = []
199+ if len (sources ) > 0 :
200+ ext_modules .append (
201+ extension (
202+ "torchao._C" ,
203+ sources ,
204+ py_limited_api = True ,
205+ extra_compile_args = extra_compile_args ,
206+ extra_link_args = extra_link_args ,
207+ )
208+ )
134209
135- ext_modules = [
136- extension (
137- "torchao._C" ,
138- sources ,
139- py_limited_api = True ,
140- extra_compile_args = extra_compile_args ,
141- extra_link_args = extra_link_args ,
210+ if build_torchao_experimental :
211+ ext_modules .append (
212+ CMakeExtension (
213+ "torchao.experimental" ,
214+ sourcedir = "torchao/experimental" ,
215+ )
142216 )
143- ]
144217
145218 return ext_modules
146219
@@ -159,6 +232,6 @@ def get_extensions():
159232 long_description = open ("README.md" ).read (),
160233 long_description_content_type = "text/markdown" ,
161234 url = "https://github.com/pytorch/ao" ,
162- cmdclass = {"build_ext" : BuildExtension },
235+ cmdclass = {"build_ext" : TorchAOBuildExt },
163236 options = {"bdist_wheel" : {"py_limited_api" : "cp39" }},
164237)
0 commit comments