Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[submodule "3rdparty/tilelang"]
path = 3rdparty/tilelang
url = https://github.com/tile-ai/tilelang
branch = main
branch = v0.1.4
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
50 changes: 29 additions & 21 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,34 +116,39 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
Returns:
str: The path where the LLVM archive was extracted.
"""
ubuntu_version = "16.04"
ubuntu_version = "22.04"
if version >= "17.0.0":
ubuntu_version = "22.04"
if version >= "16.0.0":
ubuntu_version = "20.04"
elif version >= "13.0.0":
ubuntu_version = "18.04"
else:
ubuntu_version = "16.04"

base_url = (f"https://github.com/llvm/llvm-project/releases/download/llvmorg-{version}")
file_name = f"clang+llvm-{version}-{'aarch64-linux-gnu' if is_aarch64 else f'x86_64-linux-gnu-ubuntu-{ubuntu_version}'}.tar.xz"

download_url = f"{base_url}/{file_name}"

# Download the file
print(f"Downloading {file_name} from {download_url}")
with urllib.request.urlopen(download_url) as response:
if response.status != 200:
raise Exception(f"Download failed with status code {response.status}")
file_content = response.read()
# Ensure the extract path exists
os.makedirs(extract_path, exist_ok=True)

# if the file already exists, remove it
if os.path.exists(os.path.join(extract_path, file_name)):
os.remove(os.path.join(extract_path, file_name))

# Extract the file
print(f"Extracting {file_name} to {extract_path}")
with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
tar.extractall(path=extract_path)
# Download the file
print(f"Downloading {file_name} from {download_url}")
with urllib.request.urlopen(download_url) as response:
if response.status != 200:
raise Exception(f"Download failed with status code {response.status}")
file_content = response.read()
# Ensure the extract path exists
os.makedirs(extract_path, exist_ok=True)

# if the file already exists, remove it
if os.path.exists(os.path.join(extract_path, file_name)):
os.remove(os.path.join(extract_path, file_name))

# Extract the file
print(f"Extracting {file_name} to {extract_path}")
with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
tar.extractall(path=extract_path)

print("Download and extraction completed successfully.")
return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", "")))
Expand All @@ -153,11 +158,10 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
"bitblas": ["py.typed"],
}

LLVM_VERSION = "10.0.1"
LLVM_VERSION = "10.0.1"
IS_AARCH64 = False # Set to True if on an aarch64 platform
EXTRACT_PATH = "3rdparty" # Default extraction path


def update_submodules():
"""Updates git submodules."""
try:
Expand All @@ -178,7 +182,7 @@ def build_tvm(llvm_config_path):
# Set LLVM path and enable CUDA in config.cmake
with open("config.cmake", "a") as config_file:
config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
config_file.write("set(USE_CUDA /usr/local/cuda)\n")
config_file.write(f"set(USE_CUDA {os.environ.get("CUDA_HOME", "/usr/local/cuda")})\n")
# Run CMake and make
try:
subprocess.check_call(["cmake", ".."])
Expand Down Expand Up @@ -215,7 +219,9 @@ def build_tilelang(TVM_PREBUILD_PATH: str = "./3rdparty/tvm/build"):
def setup_llvm_for_tvm():
"""Downloads and extracts LLVM, then configures TVM to use it."""
# Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script

extract_path = download_and_extract_llvm(LLVM_VERSION, IS_AARCH64, EXTRACT_PATH)

llvm_config_path = os.path.join(extract_path, "bin", "llvm-config")
return extract_path, llvm_config_path

Expand All @@ -242,14 +248,16 @@ class BitBLASBuilPydCommand(build_py):
def run(self):
build_py.run(self)
# custom build tvm
update_submodules()
# update_submodules()
# Set up LLVM for TVM
_, llvm_path = setup_llvm_for_tvm()
# Build TVM
build_tvm(llvm_path)
# Build TILELANG
build_tilelang()

print("===== BUILD dependencies successfully ! =====")

# Copy the built TVM to the package directory
TVM_PREBUILD_ITEMS = [
"3rdparty/tvm/build/libtvm_runtime.so",
Expand Down