Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import os
import sys
import yaml

torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"
if len(sys.argv) != 2:
print("Usage: gen_metal_shader_lib.py <output_file>")
sys.exit(1)

# Output file where the generated code will be written
OUTPUT_FILE = sys.argv[1]

MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
Expand All @@ -21,9 +32,6 @@
# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Output file where the generated code will be written
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

prefix = """/**
* This file is generated by gen_metal_shader_lib.py
*/
Expand All @@ -48,6 +56,7 @@

"""

os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
with open(OUTPUT_FILE, "w") as outf:
outf.write(prefix)
for file in metal_files:
Expand Down
1 change: 1 addition & 0 deletions torchao/experimental/ops/mps/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cmake-out/
60 changes: 60 additions & 0 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

project(torchao_ops_mps_linear_fp_act_xbit_weight)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED YES)

# Default to an optimized build only on single-config generators and only when
# the user has not chosen a build type (multi-config generators ignore it).
get_property(_is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT _is_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

# These kernels rely on Apple Silicon's unified memory; refuse to configure
# on any other platform so failures surface at configure time, not runtime.
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
  if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
    message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture")
  endif()
else()
  message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
endif()

find_package(Torch REQUIRED)
find_package(Python3 COMPONENTS Interpreter REQUIRED)

# Generate metal_shader_lib.h (embeds the .metal sources as a C++ string)
# by running gen_metal_shader_lib.py.
set(METAL_SHADER_GEN_SCRIPT
    ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
set(GENERATED_METAL_SHADER_LIB
    ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
  OUTPUT ${GENERATED_METAL_SHADER_LIB}
  COMMAND Python3::Interpreter ${METAL_SHADER_GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
  # Re-run the generator whenever the script itself changes.
  DEPENDS ${METAL_SHADER_GEN_SCRIPT}
  COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
  VERBATIM
)
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})

# Allow the caller to point at a torchao checkout; default to the repo root
# relative to this directory.
if(NOT TORCHAO_INCLUDE_DIRS)
  set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")

add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
# add_dependencies only orders the build; the generated header is picked up
# through the include path below.
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)

# Target-scoped includes instead of directory-wide include_directories().
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten
  PRIVATE
    "${TORCHAO_INCLUDE_DIRS}"
    # Location of the generated metal_shader_lib.h.
    "${CMAKE_INSTALL_PREFIX}/include"
    "${TORCH_INCLUDE_DIRS}"
)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)

# Enable Metal support: fail at configure time if the frameworks are missing.
find_library(METAL_LIB Metal REQUIRED)
find_library(FOUNDATION_LIB Foundation REQUIRED)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten
  PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

install(
  TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
  EXPORT _targets
  DESTINATION lib
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// LICENSE file in the root directory of this source tree.

// clang-format off
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/lowbit.h>
// clang-format on
Expand Down Expand Up @@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
return B;
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

TORCH_LIBRARY(torchao, m) {
m.def("_pack_weight_1bit(Tensor W) -> Tensor");
m.def("_pack_weight_2bit(Tensor W) -> Tensor");
Expand Down
19 changes: 19 additions & 0 deletions torchao/experimental/ops/mps/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash -eu
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Run from the directory containing this script so relative paths resolve.
cd "$(dirname "${BASH_SOURCE[0]}")"

# Point CMAKE_PREFIX_PATH at the interpreter's site-packages so find_package
# can locate TorchConfig.cmake under site-packages/torch.
# NOTE: distutils was removed in Python 3.12; sysconfig is the supported way
# to query the purelib (site-packages) path.
export CMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why cmake prefix path points to site-packages?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's where the cmake stuff is. I am working on a conda environment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean thats where the cmake stuff is. prefix path is used for looking up packages https://cmake.org/cmake/help/latest/variable/CMAKE_PREFIX_PATH.html and other stuff

echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT="${PWD}/cmake-out"
echo "CMAKE_OUT: ${CMAKE_OUT}"

# Configure, then build and install into cmake-out. Quote every path
# expansion so the build works from a checkout path containing spaces.
cmake -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
      -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
      -S . \
      -B "${CMAKE_OUT}"
cmake --build "${CMAKE_OUT}" -j 16 --target install --config Release
23 changes: 0 additions & 23 deletions torchao/experimental/ops/mps/setup.py

This file was deleted.

37 changes: 25 additions & 12 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,38 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized

def parameterized(test_cases):
def decorator(func):
def wrapper(self):
for case in test_cases:
with self.subTest(case=case):
func(self, *case)
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)

return wrapper

return decorator
def _check_ops_registered() -> None:
    """Raise AttributeError if any expected torchao MPS op is unregistered."""
    for nbit in range(1, 8):
        getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
        getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")


# Ensure the lowbit MPS ops are registered; if not (e.g. running against a
# plain torch install), load the dylib produced by build.sh and verify again.
try:
    _check_ops_registered()
except AttributeError:
    try:
        torch.ops.load_library(libpath)
    except Exception as e:
        # Don't use a bare except: it would swallow KeyboardInterrupt/SystemExit.
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Failed to load library {libpath}") from e
    # Let AttributeError propagate if the library loaded but ops are missing.
    _check_ops_registered()


class TestLowBitQuantWeightsLinear(unittest.TestCase):
cases = [
CASES = [
(nbit, *param)
for nbit in range(1, 8)
for param in [
Expand Down Expand Up @@ -73,7 +86,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
W = scales * W + zeros
return torch.mm(A, W.t())

@parameterized(cases)
@parameterized.expand(CASES)
def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)
Expand Down
23 changes: 22 additions & 1 deletion torchao/experimental/ops/mps/test/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,34 @@
import sys

import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize

# Path to the op library built by torchao/experimental/ops/mps/build.sh.
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)


def _check_ops_registered() -> None:
    """Raise AttributeError if any expected torchao MPS op is unregistered."""
    for nbit in range(1, 8):
        getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
        getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")


# Ensure the lowbit MPS ops are registered; if not (e.g. running against a
# plain torch install), load the dylib produced by build.sh and verify again.
try:
    _check_ops_registered()
except AttributeError:
    try:
        torch.ops.load_library(libpath)
    except Exception as e:
        # Don't use a bare except: it would swallow KeyboardInterrupt/SystemExit.
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Failed to load library {libpath}") from e
    # Let AttributeError propagate if the library loaded but ops are missing.
    _check_ops_registered()


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
BITWIDTHS = range(1, 8)
Expand Down
Loading