From 6d92e407b63afed56e94f8ef6c5e86c280e12886 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Wed, 16 Oct 2024 05:18:03 +0000 Subject: [PATCH 01/15] enable build for rocm for fp6_llm --- setup.py | 53 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 229e18eec6..eb56f863ab 100644 --- a/setup.py +++ b/setup.py @@ -46,9 +46,11 @@ def read_version(file_path="version.txt"): CUDAExtension, BuildExtension, CUDA_HOME, + ROCM_HOME, IS_WINDOWS ) +IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) def get_extensions(): debug_mode = os.getenv('DEBUG', '0') == '1' @@ -57,11 +59,11 @@ def get_extensions(): if not torch.cuda.is_available(): print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions") - if CUDA_HOME is None and torch.cuda.is_available(): - print("CUDA toolkit is not available. Skipping compilation of CUDA extensions") + if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available(): + print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions") print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit") - use_cuda = torch.cuda.is_available() and CUDA_HOME is not None + use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None) extension = CUDAExtension if use_cuda else CppExtension if not IS_WINDOWS: @@ -71,15 +73,14 @@ def get_extensions(): "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", ], - "nvcc": [ - "-O3" if not debug_mode else "-O0", - "-t=0", - ] } + if use_cuda and not IS_ROCM: + extra_compile_args["nvcc"] = ["-O3" if not debug_mode else "-O0", "-t=0",] if debug_mode: extra_compile_args["cxx"].append("-g") - extra_compile_args["nvcc"].append("-g") + if "nvcc" in extra_compile_args: + extra_compile_args["nvcc"].append("-g") extra_link_args.extend(["-O0", "-g"]) else: @@ -107,17 +108,35 @@ def get_extensions(): extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) - if use_cuda: + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "fp6_llm") + hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) + + if not IS_ROCM and use_cuda: sources += cuda_sources - ext_modules = [ - extension( - "torchao._C", - sources, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ] + # TOOD: Remove this and use what CUDA has once we fix all the builds. + if IS_ROCM and use_cuda: + sources += hip_sources + + ## TODO: remove this condition and use what we have in CUDA once we fix the individual builds. 
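# For reference, the detection logic this patch introduces reduces to the stand-alone sketch
# below (names mirror the diff; illustration only, not part of the patch). torch.version.hip
# is only set on ROCm builds of PyTorch, and CUDA_HOME / ROCM_HOME come from
# torch.utils.cpp_extension.
import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME, CppExtension, CUDAExtension

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
extension = CUDAExtension if use_cuda else CppExtension

extra_compile_args = {"cxx": ["-O3"]}
if use_cuda and not IS_ROCM:
    # nvcc-specific flags such as "-t=0" are only added on the CUDA path,
    # never on the ROCm/hipcc path
    extra_compile_args["nvcc"] = ["-O3", "-t=0"]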
+ if not IS_ROCM: + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] + else: + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] return ext_modules From f1a22cf227ea1ce6757b1db13d3af5f6cf1d51f5 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Wed, 23 Oct 2024 07:58:49 +0000 Subject: [PATCH 02/15] enable tiled layout extension --- setup.py | 28 +++------ .../tensor_core_tiled_layout.cu | 59 +++++++++++++++++-- 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/setup.py b/setup.py index eb56f863ab..d3710a46da 100644 --- a/setup.py +++ b/setup.py @@ -108,7 +108,7 @@ def get_extensions(): extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) - extensions_hip_dir = os.path.join(extensions_dir, "cuda", "fp6_llm") + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout") hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) if not IS_ROCM and use_cuda: @@ -119,24 +119,14 @@ def get_extensions(): sources += hip_sources ## TODO: remove this condition and use what we have in CUDA once we fix the individual builds. - if not IS_ROCM: - ext_modules = [ - extension( - "torchao._C", - sources, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ] - else: - ext_modules = [ - extension( - "torchao._C", - sources, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ] + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] return ext_modules diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 7af29caac9..fef650ba88 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 #include #include @@ -7,13 +7,24 @@ #include #include +#if defined(USE_ROCM) +#include +#include +#include +#endif + template constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { static_assert(std::is_integral::value && std::is_integral::value, ""); const uint64_t blocks = a / b + (a % b != 0); return blocks; } + +#if defined(USE_ROCM) +constexpr int32_t kWarpSize = 64; +#else constexpr int32_t kWarpSize = 32; +#endif //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 @@ -30,38 +41,68 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { uint32_t const source_i4s = source; // First, we extract the i4s and construct an intermediate fp16 number. 
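// For reference, a portable host-side sketch (illustration only; the helper name below is
// illustrative, not from the kernel) of what both inline-asm paths compute. The LOP3 immLut
// value (0xf0 & 0xcc) | 0xaa encodes the boolean function f(a, b, c) = (a & b) | c, and the
// ROCm v_and_or_b32 instruction fuses the same AND + OR, so both branches evaluate:
#include <cstdint>

static inline uint32_t pack_two_nibbles_into_bf16(uint32_t i4s) {
  constexpr uint32_t MASK = 0x000f000f;                    // keep the low 4 bits of each 16-bit lane
  constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;  // bf16 bit pattern of 128.0 in each lane
  return (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM;            // each lane now reads as bf16(128 + n)
}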
+#if !defined(USE_ROCM) static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; +#endif static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. uint32_t i4s = source_i4s; + +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[0]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + #pragma unroll for (int ii = 1; ii < kElements / 2; ++ii) { i4s >>= 4; // or is it 8? // (i4s & 0x000f000f) | 0x43004300 +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[ii]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else asm volatile( "lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[ii]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif } // This is the BF16 {-136, -136} represented as an integer. +#if defined(USE_ROCM) +#if ROCM_VERSION >= 60200 + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); +#else + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); +#endif +#else static constexpr uint32_t BF16_BIAS = 0xC308C308; static constexpr uint32_t BF16_ONE = 0x3F803F80; +#endif // Finally, we construct the output numbers. #pragma unroll for (int ii = 0; ii < kElements / 2; ++ii) { // Since this section is for Ampere+, we use bf16 fma to do the bias // subtraction +#if defined(USE_ROCM) + result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); +#else asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); +#endif } return result; @@ -123,11 +164,17 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); - - // Vectorize scales and zeros - __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); - __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + if (scales_and_zeros.has_value()) { + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + + // Vectorize scales and zeros + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + } + else { + __nv_bfloat162 scale2 = {1.0f, 1.0f}; + __nv_bfloat162 zero2 = {0.0f, 0.0f}; + } #pragma unroll for (int i = 0; i < 4; i++) { From 0bef6ca901bdc0ec084fd473093cab1f059d70c4 Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Wed, 23 Oct 2024 14:55:45 -0700 Subject: [PATCH 03/15] fix build error related to option --- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index fef650ba88..59f310d872 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,8 +164,9 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - if (scales_and_zeros.has_value()) { - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + if (scales_and_zeros) { + const auto&sz = *scales_and_zeros; + const __nv_bfloat16 *pSZ = reinterpret_cast(&sz[qgroup][n0][0]); // Vectorize scales and zeros __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); From 893ae03b3db8ea3d2c284d83657cdaf3ae431bd8 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 23 Oct 2024 15:04:59 -0700 Subject: [PATCH 04/15] require rocm 6.2 --- .../cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 59f310d872..2fa74bad2d 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 #include #include From a0d37887009c1b9a7c4b9aa5adbaf8cfd0e754b2 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 24 Oct 2024 10:58:36 +0000 Subject: [PATCH 05/15] enable tensor tiled layout extension with successful compilation --- .../tensor_core_tiled_layout.cu | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index fef650ba88..f7ccf6f6bc 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,16 +164,20 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - if (scales_and_zeros.has_value()) { - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + __nv_bfloat162 scale2, zero2; + if (scales_and_zeros) { + const auto& sz = *scales_and_zeros; + const __nv_bfloat16 *pSZ = reinterpret_cast(&sz[qgroup][n0][0]); // Vectorize scales and zeros __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); } else { - __nv_bfloat162 scale2 = {1.0f, 1.0f}; - __nv_bfloat162 zero2 = {0.0f, 0.0f}; + //scale2 = {1.0f, 1.0f}; + //zero2 = {0.0f, 0.0f}; + scale2.x = 1.0f; scale2.y = 1.0f; + zero2.x = 1.0f; 
zero2.y = 1.0f; } #pragma unroll @@ -237,6 +241,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout( group_size == 256); TORCH_CHECK(numQGroups == K / group_size); TORCH_CHECK(scales_and_zeros.dim() == 3); + std::cout << "CHAI: " << scales_and_zeros.size(1) << "," << scales_and_zeros.size(2) << "," << N << std::endl; TORCH_CHECK(scales_and_zeros.size(1) == N); TORCH_CHECK(scales_and_zeros.size(2) == 2); From 3e2c6a1acff3d3c4642a94e1248012b6da900b44 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 29 Oct 2024 14:37:29 -0700 Subject: [PATCH 06/15] clean-up --- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 4835c2c7f3..8542d63d5b 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #include #include @@ -174,8 +174,6 @@ __global__ void _dequantize_int4_kernel( __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); } else { - //scale2 = {1.0f, 1.0f}; - //zero2 = {0.0f, 0.0f}; scale2.x = 1.0f; scale2.y = 1.0f; zero2.x = 1.0f; zero2.y = 1.0f; } @@ -241,7 +239,6 @@ at::Tensor _dequantize_tensor_core_tiled_layout( group_size == 256); TORCH_CHECK(numQGroups == K / group_size); TORCH_CHECK(scales_and_zeros.dim() == 3); - std::cout << "CHAI: " << scales_and_zeros.size(1) << "," << scales_and_zeros.size(2) << "," << N << std::endl; TORCH_CHECK(scales_and_zeros.size(1) == N); TORCH_CHECK(scales_and_zeros.size(2) == 2); From 91d3c752ae72d369fc1af257b430421d2c95d697 Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Tue, 29 Oct 2024 14:50:57 -0700 Subject: [PATCH 07/15] fix potential memory access issue --- .../tensor_core_tiled_layout.cu | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 8542d63d5b..3bbdc032e3 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,18 +164,15 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - __nv_bfloat162 scale2, zero2; - if (scales_and_zeros) { - const auto&sz = *scales_and_zeros; - const __nv_bfloat16 *pSZ = reinterpret_cast(&sz[qgroup][n0][0]); + __nv_bfloat162 scale2 = {1.0f, 1.0f}; + __nv_bfloat162 zero2 = {1.0f, 1.0f}; - // Vectorize scales and zeros - __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); - __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); - } - else { - scale2.x = 1.0f; scale2.y = 1.0f; - zero2.x = 1.0f; zero2.y = 1.0f; + if (scales_and_zeros) { + const auto& sz = *scales_and_zeros; + const __nv_bfloat16* pSZ = reinterpret_cast(&sz[qgroup][n0][0]); + + scale2 = __bfloat162bfloat162(pSZ[0]); + zero2 = __bfloat162bfloat162(pSZ[1]); } #pragma unroll From 38b7d1c45cd663d560f966788ab7ae7a48bd68b9 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Tue, 12 Nov 2024 16:17:10 -0600 Subject: [PATCH 08/15] fix __nv_bfloat162 init --- .../cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 3bbdc032e3..f876c2667a 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,8 +164,8 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - __nv_bfloat162 scale2 = {1.0f, 1.0f}; - __nv_bfloat162 zero2 = {1.0f, 1.0f}; + __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); + __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); if (scales_and_zeros) { const auto& sz = *scales_and_zeros; From 279f4b3994e10340d1e90fbdd520529a568bdc02 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Tue, 12 Nov 2024 16:27:23 -0600 Subject: [PATCH 09/15] add comment for MI300x isa --- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index f876c2667a..10976d9417 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -50,7 +50,10 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. 
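// A plain C++ sketch (illustration only; std::optional stands in for the kernel's optional
// accessor) of why patches 07 and 08 above restructure the scale/zero handling: the
// intermediate version re-declared scale2/zero2 inside the if branch, shadowing the outer
// variables, so the values consumed later in the kernel stayed uninitialized whenever scales
// were present. The fix gives the outer variables a well-defined default and then assigns to
// them inside the branch instead of re-declaring them.
#include <optional>

float resolve_scale(const std::optional<float>& maybe_scale) {
  float scale = 1.0f;       // safe default, as in the hoisted initialization
  if (maybe_scale) {
    scale = *maybe_scale;   // assign to the outer variable; no shadowing re-declaration
  }
  return scale;
}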
uint32_t i4s = source_i4s; - +// AMD MI300X ISA that performs two bitwise operations in a single instruction: +// v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM +// - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values +// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16 #if defined(USE_ROCM) asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(h[0]) From bbf5a727567eb1cb76735c444665326a0037ceba Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Mon, 6 Jan 2025 06:39:17 +0000 Subject: [PATCH 10/15] fix build for non-rocm --- setup.py | 2 +- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b4c5ce061..e0a589c93f 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ def get_extensions(): if not torch.cuda.is_available(): print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions") - if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available(): + if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available(): print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions") print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit") diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 480d30d1c5..d1c5d49fda 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -167,6 +167,7 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; +#if defined(USE_ROCM) __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); @@ -177,6 +178,11 @@ __global__ void _dequantize_int4_kernel( scale2 = __bfloat162bfloat162(pSZ[0]); zero2 = __bfloat162bfloat162(pSZ[1]); } +#else + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); +#endif #pragma unroll for (int i = 0; i < 4; i++) { From 26fa19cbc4fa6cb717475a497128cfc1e6878f1b Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Thu, 9 Jan 2025 16:39:01 -0600 Subject: [PATCH 11/15] better naming --- .../tensor_core_tiled_layout.cu | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index d1c5d49fda..f3f1bf9636 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -83,15 +83,15 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { // This is the BF16 {-136, -136} represented as an integer. 
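// Worked check of the constants below (host-side reasoning, illustration only):
//   bf16 0x4300              = 128.0
//   0x4300 | n  (n = 0..15)  = 128.0 + n   (one mantissa step is exactly 1.0 at this exponent)
//   bf16 0xC308              = -136.0      (packed twice as 0xC308C308, the {-136, -136} pair)
//   bf16 0x3F80              = 1.0         (packed twice as 0x3F803F80)
//   fma(128 + n, 1.0, -136)  = n - 8       so the unsigned nibble lands in the signed range [-8, 7]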
#if defined(USE_ROCM) #if ROCM_VERSION >= 60200 - auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); - auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); + auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); + auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); #else - auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); - auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); + auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16{0xC308}); + auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); #endif #else - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; + static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308; + static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80; #endif // Finally, we construct the output numbers. @@ -100,11 +100,11 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { // Since this section is for Ampere+, we use bf16 fma to do the bias // subtraction #if defined(USE_ROCM) - result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); + result.vals[ii] = __hfma2(result.vals[ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR); #else asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) - : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + : "r"(h[ii]), "r"(BF16_UNIT_VALUE), "r"(BF16_SCALE_FACTOR)); #endif } @@ -369,3 +369,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { } #endif +git checkout main -- file.txt From 8f0209660325deaffb8aee56c29fac5befcab542 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Thu, 9 Jan 2025 16:43:22 -0600 Subject: [PATCH 12/15] lint --- .../cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index f3f1bf9636..8ae24c7dce 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -369,4 +369,3 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { } #endif -git checkout main -- file.txt From ced96477a8ceb00091ad4408fee7128394460982 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Thu, 9 Jan 2025 16:48:20 -0600 Subject: [PATCH 13/15] resolve merge conflict --- setup.py | 89 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 30 deletions(-) diff --git a/setup.py b/setup.py index 5b01da8fa6..fef71dcbdb 100644 --- a/setup.py +++ b/setup.py @@ -3,10 +3,10 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
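# Summary of the build knobs this file reads, unchanged by the refactor in this patch
# (not part of the diff; the environment variable that supplies an explicit version suffix
# is not shown in these hunks):
#   DEBUG=1            -> extensions compile with -O0/-g (or /Od, /ZI and /DEBUG on Windows)
#   USE_CPP            -> consulted when deciding whether to build the C++/CUDA extensions
#   TORCHAO_NIGHTLY=1  -> version becomes "<version.txt>.dev<date>"
#   no explicit suffix -> "+git<short-sha>" from get_git_commit_id() is appended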
-import os import glob -from datetime import datetime +import os import subprocess +from datetime import datetime from setuptools import find_packages, setup @@ -15,14 +15,20 @@ def get_git_commit_id(): try: - return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + .decode("ascii") + .strip() + ) except Exception: return "" + def read_requirements(file_path): with open(file_path, "r") as file: return file.read().splitlines() + def read_version(file_path="version.txt"): with open(file_path, "r") as file: return file.readline().strip() @@ -33,37 +39,49 @@ def read_version(file_path="version.txt"): if version_suffix is None: version_suffix = f"+git{get_git_commit_id()}" -use_cpp = os.getenv('USE_CPP') +use_cpp = os.getenv("USE_CPP") version_prefix = read_version() # Version is version.dev year month date if using nightlies and version if not -version = f"{version_prefix}.dev{current_date}" if os.environ.get("TORCHAO_NIGHTLY") else version_prefix +version = ( + f"{version_prefix}.dev{current_date}" + if os.environ.get("TORCHAO_NIGHTLY") + else version_prefix +) import torch - from torch.utils.cpp_extension import ( - CppExtension, - CUDAExtension, - BuildExtension, CUDA_HOME, + IS_WINDOWS, ROCM_HOME, - IS_WINDOWS + BuildExtension, + CppExtension, + CUDAExtension, ) IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) + def get_extensions(): - debug_mode = os.getenv('DEBUG', '0') == '1' + debug_mode = os.getenv("DEBUG", "0") == "1" if debug_mode: print("Compiling in debug mode") if not torch.cuda.is_available(): - print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions") + print( + "PyTorch GPU support is not available. Skipping compilation of CUDA extensions" + ) if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available(): - print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions") - print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit") + print( + "CUDA toolkit or ROCm is not available. 
Skipping compilation of CUDA extensions" + ) + print( + "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit" + ) - use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None) + use_cuda = torch.cuda.is_available() and ( + CUDA_HOME is not None or ROCM_HOME is not None + ) extension = CUDAExtension if use_cuda else CppExtension extra_link_args = [] @@ -86,10 +104,7 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.extend(["-O0", "-g"]) else: - extra_compile_args["cxx"] = [ - "/O2" if not debug_mode else "/Od", - "/permissive-" - ] + extra_compile_args["cxx"] = ["/O2" if not debug_mode else "/Od", "/permissive-"] if debug_mode: extra_compile_args["cxx"].append("/ZI") @@ -103,26 +118,42 @@ def get_extensions(): cutlass_dir = os.path.join(this_dir, "third_party", "cutlass") cutlass_include_dir = os.path.join(cutlass_dir, "include") if use_cutlass: - extra_compile_args["nvcc"].extend([ - "-DTORCHAO_USE_CUTLASS", - "-I" + cutlass_include_dir, - ]) + extra_compile_args["nvcc"].extend( + [ + "-DTORCHAO_USE_CUTLASS", + "-I" + cutlass_include_dir, + ] + ) this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) + cuda_sources = list( + glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) + ) - extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout") - hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) + extensions_hip_dir = os.path.join( + extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin" + ) + hip_sources = list( + glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) + ) if not IS_ROCM and use_cuda: sources += cuda_sources # TOOD: Remove this and use what CUDA has once we fix all the builds. if IS_ROCM and use_cuda: + # Add ROCm GPU architecture check + gpu_arch = torch.cuda.get_device_properties(0).name + if gpu_arch != "gfx942": + print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") + print( + "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" + ) + return None sources += hip_sources if len(sources) == 0: @@ -143,7 +174,7 @@ def get_extensions(): setup( name="torchao", - version=version+version_suffix, + version=version + version_suffix, packages=find_packages(), include_package_data=True, package_data={ @@ -156,7 +187,5 @@ def get_extensions(): long_description_content_type="text/markdown", url="https://github.com/pytorch/ao", cmdclass={"build_ext": BuildExtension}, - options={"bdist_wheel": { - "py_limited_api": "cp39" - }}, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) From 89bae5fb11857cdd64a94d14c16ded1155e9d2bb Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Wed, 15 Jan 2025 16:09:30 -0800 Subject: [PATCH 14/15] lint --- setup.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/setup.py b/setup.py index 2907f31cf5..187f7e15e2 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,6 @@ def use_debug_mode(): CUDAExtension, ) - IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) # Constant known variables used throughout this file @@ -196,7 +195,6 @@ def __init__(self, name, sourcedir=""): self.sourcedir = os.path.abspath(sourcedir) - def get_extensions(): debug_mode = use_debug_mode() if debug_mode: @@ -279,7 +277,6 @@ def get_extensions(): if not IS_ROCM and use_cuda: sources += cuda_sources - # TOOD: Remove this and use what CUDA has once we fix all the builds. if IS_ROCM and use_cuda: # Add ROCm GPU architecture check @@ -307,7 +304,6 @@ def get_extensions(): ) ) - if build_torchao_experimental: ext_modules.append( CMakeExtension( From 4148828cd7b7efea385a310397c548d9a43cbd9e Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Fri, 28 Feb 2025 10:04:26 -0800 Subject: [PATCH 15/15] Refactor: Rename `this_dir` to `curdir` in setup.py Minor variable renaming for clarity in extension directory path resolution --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 2d92a7c094..c3978a6702 100644 --- a/setup.py +++ b/setup.py @@ -247,8 +247,8 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") - this_dir = os.path.dirname(os.path.curdir) - extensions_dir = os.path.join(this_dir, "torchao", "csrc") + curdir = os.path.dirname(os.path.curdir) + extensions_dir = os.path.join(curdir, "torchao", "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") @@ -278,8 +278,8 @@ def get_extensions(): ] ) - this_dir = os.path.dirname(os.path.curdir) - extensions_dir = os.path.join(this_dir, "torchao", "csrc") + curdir = os.path.dirname(os.path.curdir) + extensions_dir = os.path.join(curdir, "torchao", "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
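# Net effect of the series on GPU extension selection, as a stand-alone sketch (the function
# name is illustrative; "gfx942" is the architecture string the series' setup.py gate compares
# against, per the MI300X-focused patches above):
import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

def gpu_extensions_enabled() -> bool:
    if not torch.cuda.is_available():
        return False                 # no GPU runtime: C++-only build
    if CUDA_HOME is None and ROCM_HOME is None:
        return False                 # no CUDA toolkit and no ROCm: C++-only build
    if torch.version.hip is not None and ROCM_HOME is not None:
        # ROCm path: only the HIP-ified sources are compiled, and only when the reported
        # device/architecture name matches what the gate expects
        return torch.cuda.get_device_properties(0).name == "gfx942"
    return True                      # CUDA path: all .cu sources are compiled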