88import subprocess
99from datetime import datetime
1010
11- from setuptools import find_packages , setup
11+ from setuptools import find_packages , setup , Extension
1212
1313current_date = datetime .now ().strftime ("%Y%m%d" )
1414
@@ -41,6 +41,13 @@ def read_version(file_path="version.txt"):
4141
4242use_cpp = os .getenv ("USE_CPP" )
4343
44+ import platform
45+ build_torchao_experimental = (
46+ use_cpp == "1" and
47+ platform .machine ().startswith ("arm64" ) and
48+ platform .system () == "Darwin"
49+ )
50+
4451version_prefix = read_version ()
4552# Version is version.dev year month date if using nightlies and version if not
4653version = (
@@ -49,6 +56,9 @@ def read_version(file_path="version.txt"):
4956 else version_prefix
5057)
5158
59+ def use_debug_mode ():
60+ return os .getenv ('DEBUG' , '0' ) == '1'
61+
5262import torch
5363from torch .utils .cpp_extension import (
5464 CUDA_HOME ,
@@ -59,8 +69,51 @@ def read_version(file_path="version.txt"):
5969)
6070
6171
72+ # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
73+ class TorchAOBuildExt (BuildExtension ):
74+ def __init__ (self , * args , ** kwargs ) -> None :
75+ super ().__init__ (* args , ** kwargs )
76+
77+ def build_extensions (self ):
78+ cmake_extensions = [ext for ext in self .extensions if isinstance (ext , CMakeExtension )]
79+ other_extensions = [ext for ext in self .extensions if not isinstance (ext , CMakeExtension )]
80+ for ext in cmake_extensions :
81+ self .build_cmake (ext )
82+
83+ # Use BuildExtension to build other extensions
84+ self .extensions = other_extensions
85+ super ().build_extensions ()
86+
87+ self .extensions = other_extensions + cmake_extensions
88+
89+ def build_cmake (self , ext ):
90+ extdir = os .path .abspath (os .path .dirname (self .get_ext_fullpath (ext .name )))
91+
92+ build_type = "Debug" if use_debug_mode () else "Release"
93+
94+ from distutils .sysconfig import get_python_lib
95+ torch_dir = get_python_lib () + "/torch/share/cmake/Torch"
96+
97+ if not os .path .exists (self .build_temp ):
98+ os .makedirs (self .build_temp )
99+
100+ subprocess .check_call (
101+ ['cmake' , ext .sourcedir , '-DCMAKE_BUILD_TYPE=' + build_type , '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF' , '-DTorch_DIR=' + torch_dir , '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir ],
102+ cwd = self .build_temp
103+ )
104+ subprocess .check_call (
105+ ['cmake' , '--build' , '.' ],
106+ cwd = self .build_temp
107+ )
108+
109+ class CMakeExtension (Extension ):
110+ def __init__ (self , name , sourcedir = '' ):
111+ Extension .__init__ (self , name , sources = [])
112+ self .sourcedir = os .path .abspath (sourcedir )
113+
114+
62115def get_extensions ():
63- debug_mode = os . getenv ( "DEBUG" , "0" ) == "1"
116+ debug_mode = use_debug_mode ()
64117 if debug_mode :
65118 print ("Compiling in debug mode" )
66119
@@ -129,18 +182,25 @@ def get_extensions():
129182 if use_cuda :
130183 sources += cuda_sources
131184
132- if len (sources ) == 0 :
133- return None
134-
135- ext_modules = [
185+ ext_modules = []
186+ if len (sources ) > 0 :
187+ ext_modules .append (
136188 extension (
137189 "torchao._C" ,
138190 sources ,
139191 py_limited_api = True ,
140192 extra_compile_args = extra_compile_args ,
141193 extra_link_args = extra_link_args ,
142194 )
143- ]
195+ )
196+
197+ if build_torchao_experimental :
198+ ext_modules .append (
199+ CMakeExtension (
200+ "torchao.experimental" ,
201+ sourcedir = "torchao/experimental" ,
202+ )
203+ )
144204
145205 return ext_modules
146206
@@ -159,6 +219,6 @@ def get_extensions():
159219 long_description = open ("README.md" ).read (),
160220 long_description_content_type = "text/markdown" ,
161221 url = "https://github.com/pytorch/ao" ,
162- cmdclass = {"build_ext" : BuildExtension },
222+ cmdclass = {"build_ext" : TorchAOBuildExt },
163223 options = {"bdist_wheel" : {"py_limited_api" : "cp39" }},
164224)
0 commit comments