From 6d92e407b63afed56e94f8ef6c5e86c280e12886 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Wed, 16 Oct 2024 05:18:03 +0000 Subject: [PATCH 01/24] enable build for rocm for fp6_llm --- setup.py | 53 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 229e18eec6..eb56f863ab 100644 --- a/setup.py +++ b/setup.py @@ -46,9 +46,11 @@ def read_version(file_path="version.txt"): CUDAExtension, BuildExtension, CUDA_HOME, + ROCM_HOME, IS_WINDOWS ) +IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) def get_extensions(): debug_mode = os.getenv('DEBUG', '0') == '1' @@ -57,11 +59,11 @@ def get_extensions(): if not torch.cuda.is_available(): print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions") - if CUDA_HOME is None and torch.cuda.is_available(): - print("CUDA toolkit is not available. Skipping compilation of CUDA extensions") + if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available(): + print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions") print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit") - use_cuda = torch.cuda.is_available() and CUDA_HOME is not None + use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None) extension = CUDAExtension if use_cuda else CppExtension if not IS_WINDOWS: @@ -71,15 +73,14 @@ def get_extensions(): "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", ], - "nvcc": [ - "-O3" if not debug_mode else "-O0", - "-t=0", - ] } + if use_cuda and not IS_ROCM: + extra_compile_args["nvcc"] = ["-O3" if not debug_mode else "-O0", "-t=0",] if debug_mode: extra_compile_args["cxx"].append("-g") - extra_compile_args["nvcc"].append("-g") + if "nvcc" in extra_compile_args: + extra_compile_args["nvcc"].append("-g") extra_link_args.extend(["-O0", "-g"]) else: @@ -107,17 +108,35 @@ def get_extensions(): extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) - if use_cuda: + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "fp6_llm") + hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) + + if not IS_ROCM and use_cuda: sources += cuda_sources - ext_modules = [ - extension( - "torchao._C", - sources, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ] + # TOOD: Remove this and use what CUDA has once we fix all the builds. + if IS_ROCM and use_cuda: + sources += hip_sources + + ## TODO: remove this condition and use what we have in CUDA once we fix the individual builds. 
+ if not IS_ROCM: + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] + else: + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] return ext_modules From f1a22cf227ea1ce6757b1db13d3af5f6cf1d51f5 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Wed, 23 Oct 2024 07:58:49 +0000 Subject: [PATCH 02/24] enable tiled layout extension --- setup.py | 28 +++------ .../tensor_core_tiled_layout.cu | 59 +++++++++++++++++-- 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/setup.py b/setup.py index eb56f863ab..d3710a46da 100644 --- a/setup.py +++ b/setup.py @@ -108,7 +108,7 @@ def get_extensions(): extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) - extensions_hip_dir = os.path.join(extensions_dir, "cuda", "fp6_llm") + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout") hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) if not IS_ROCM and use_cuda: @@ -119,24 +119,14 @@ def get_extensions(): sources += hip_sources ## TODO: remove this condition and use what we have in CUDA once we fix the individual builds. - if not IS_ROCM: - ext_modules = [ - extension( - "torchao._C", - sources, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ] - else: - ext_modules = [ - extension( - "torchao._C", - sources, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - ] + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] return ext_modules diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 7af29caac9..fef650ba88 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 #include #include @@ -7,13 +7,24 @@ #include #include +#if defined(USE_ROCM) +#include +#include +#include +#endif + template constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { static_assert(std::is_integral::value && std::is_integral::value, ""); const uint64_t blocks = a / b + (a % b != 0); return blocks; } + +#if defined(USE_ROCM) +constexpr int32_t kWarpSize = 64; +#else constexpr int32_t kWarpSize = 32; +#endif //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 @@ -30,38 +41,68 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { uint32_t const source_i4s = source; // First, we extract the i4s and construct an intermediate fp16 number. 
+#if !defined(USE_ROCM) static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; +#endif static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. uint32_t i4s = source_i4s; + +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[0]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + #pragma unroll for (int ii = 1; ii < kElements / 2; ++ii) { i4s >>= 4; // or is it 8? // (i4s & 0x000f000f) | 0x43004300 +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[ii]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else asm volatile( "lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[ii]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif } // This is the BF16 {-136, -136} represented as an integer. +#if defined(USE_ROCM) +#if ROCM_VERSION >= 60200 + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); +#else + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); +#endif +#else static constexpr uint32_t BF16_BIAS = 0xC308C308; static constexpr uint32_t BF16_ONE = 0x3F803F80; +#endif // Finally, we construct the output numbers. #pragma unroll for (int ii = 0; ii < kElements / 2; ++ii) { // Since this section is for Ampere+, we use bf16 fma to do the bias // subtraction +#if defined(USE_ROCM) + result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); +#else asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); +#endif } return result; @@ -123,11 +164,17 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); - - // Vectorize scales and zeros - __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); - __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + if (scales_and_zeros.has_value()) { + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + + // Vectorize scales and zeros + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + } + else { + __nv_bfloat162 scale2 = {1.0f, 1.0f}; + __nv_bfloat162 zero2 = {0.0f, 0.0f}; + } #pragma unroll for (int i = 0; i < 4; i++) { From 0bef6ca901bdc0ec084fd473093cab1f059d70c4 Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Wed, 23 Oct 2024 14:55:45 -0700 Subject: [PATCH 03/24] fix build error related to option --- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index fef650ba88..59f310d872 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,8 +164,9 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - if (scales_and_zeros.has_value()) { - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + if (scales_and_zeros) { + const auto&sz = *scales_and_zeros; + const __nv_bfloat16 *pSZ = reinterpret_cast(&sz[qgroup][n0][0]); // Vectorize scales and zeros __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); From 893ae03b3db8ea3d2c284d83657cdaf3ae431bd8 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 23 Oct 2024 15:04:59 -0700 Subject: [PATCH 04/24] require rocm 6.2 --- .../cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 59f310d872..2fa74bad2d 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 #include #include From a0d37887009c1b9a7c4b9aa5adbaf8cfd0e754b2 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 24 Oct 2024 10:58:36 +0000 Subject: [PATCH 05/24] enable tensor tiled layout extension with successful compilation --- .../tensor_core_tiled_layout.cu | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index fef650ba88..f7ccf6f6bc 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,16 +164,20 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - if (scales_and_zeros.has_value()) { - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + __nv_bfloat162 scale2, zero2; + if (scales_and_zeros) { + const auto& sz = *scales_and_zeros; + const __nv_bfloat16 *pSZ = reinterpret_cast(&sz[qgroup][n0][0]); // Vectorize scales and zeros __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); } else { - __nv_bfloat162 scale2 = {1.0f, 1.0f}; - __nv_bfloat162 zero2 = {0.0f, 0.0f}; + //scale2 = {1.0f, 1.0f}; + //zero2 = {0.0f, 0.0f}; + scale2.x = 1.0f; scale2.y = 1.0f; + zero2.x = 1.0f; 
zero2.y = 1.0f; } #pragma unroll @@ -237,6 +241,7 @@ at::Tensor _dequantize_tensor_core_tiled_layout( group_size == 256); TORCH_CHECK(numQGroups == K / group_size); TORCH_CHECK(scales_and_zeros.dim() == 3); + std::cout << "CHAI: " << scales_and_zeros.size(1) << "," << scales_and_zeros.size(2) << "," << N << std::endl; TORCH_CHECK(scales_and_zeros.size(1) == N); TORCH_CHECK(scales_and_zeros.size(2) == 2); From 3e2c6a1acff3d3c4642a94e1248012b6da900b44 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 29 Oct 2024 14:37:29 -0700 Subject: [PATCH 06/24] clean-up --- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 4835c2c7f3..8542d63d5b 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere and ROCm > 5.7 +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #include #include @@ -174,8 +174,6 @@ __global__ void _dequantize_int4_kernel( __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); } else { - //scale2 = {1.0f, 1.0f}; - //zero2 = {0.0f, 0.0f}; scale2.x = 1.0f; scale2.y = 1.0f; zero2.x = 1.0f; zero2.y = 1.0f; } @@ -241,7 +239,6 @@ at::Tensor _dequantize_tensor_core_tiled_layout( group_size == 256); TORCH_CHECK(numQGroups == K / group_size); TORCH_CHECK(scales_and_zeros.dim() == 3); - std::cout << "CHAI: " << scales_and_zeros.size(1) << "," << scales_and_zeros.size(2) << "," << N << std::endl; TORCH_CHECK(scales_and_zeros.size(1) == N); TORCH_CHECK(scales_and_zeros.size(2) == 2); From 91d3c752ae72d369fc1af257b430421d2c95d697 Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Tue, 29 Oct 2024 14:50:57 -0700 Subject: [PATCH 07/24] fix potential memory access issue --- .../tensor_core_tiled_layout.cu | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 8542d63d5b..3bbdc032e3 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,18 +164,15 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - __nv_bfloat162 scale2, zero2; - if (scales_and_zeros) { - const auto&sz = *scales_and_zeros; - const __nv_bfloat16 *pSZ = reinterpret_cast(&sz[qgroup][n0][0]); + __nv_bfloat162 scale2 = {1.0f, 1.0f}; + __nv_bfloat162 zero2 = {1.0f, 1.0f}; - // Vectorize scales and zeros - __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); - __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); - } - else { - scale2.x = 1.0f; scale2.y = 1.0f; - zero2.x = 1.0f; zero2.y = 1.0f; + if (scales_and_zeros) { + const auto& sz = *scales_and_zeros; + const __nv_bfloat16* pSZ = reinterpret_cast(&sz[qgroup][n0][0]); + + scale2 = __bfloat162bfloat162(pSZ[0]); + zero2 = __bfloat162bfloat162(pSZ[1]); } #pragma unroll From 38b7d1c45cd663d560f966788ab7ae7a48bd68b9 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Tue, 12 Nov 2024 16:17:10 -0600 Subject: [PATCH 08/24] fix __nv_bfloat162 init --- .../cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 3bbdc032e3..f876c2667a 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -164,8 +164,8 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; - __nv_bfloat162 scale2 = {1.0f, 1.0f}; - __nv_bfloat162 zero2 = {1.0f, 1.0f}; + __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); + __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); if (scales_and_zeros) { const auto& sz = *scales_and_zeros; From 279f4b3994e10340d1e90fbdd520529a568bdc02 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Tue, 12 Nov 2024 16:27:23 -0600 Subject: [PATCH 09/24] add comment for MI300x isa --- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index f876c2667a..10976d9417 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -50,7 +50,10 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. 
uint32_t i4s = source_i4s; - +// AMD MI300X ISA that performs two bitwise operations in a single instruction: +// v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM +// - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values +// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16 #if defined(USE_ROCM) asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(h[0]) From bbf5a727567eb1cb76735c444665326a0037ceba Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Mon, 6 Jan 2025 06:39:17 +0000 Subject: [PATCH 10/24] fix build for non-rocm --- setup.py | 2 +- .../tensor_core_tiled_layout/tensor_core_tiled_layout.cu | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b4c5ce061..e0a589c93f 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ def get_extensions(): if not torch.cuda.is_available(): print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions") - if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available(): + if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available(): print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions") print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit") diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index 480d30d1c5..d1c5d49fda 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -167,6 +167,7 @@ __global__ void _dequantize_int4_kernel( // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; +#if defined(USE_ROCM) __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f)); @@ -177,6 +178,11 @@ __global__ void _dequantize_int4_kernel( scale2 = __bfloat162bfloat162(pSZ[0]); zero2 = __bfloat162bfloat162(pSZ[1]); } +#else + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); +#endif #pragma unroll for (int i = 0; i < 4; i++) { From a2f1736ed006f0b81de07279a465928ba38c90d7 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Thu, 17 Oct 2024 14:49:30 -0500 Subject: [PATCH 11/24] add sparse_marlin kernel to the build --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5b01da8fa6..9b554fbe3a 100644 --- a/setup.py +++ b/setup.py @@ -115,8 +115,9 @@ def get_extensions(): extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)) - extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout") + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin") hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) + hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.h"), recursive=True)) if not IS_ROCM and use_cuda: sources += cuda_sources From f817edf135c14e39e94546bde7fb871dbf3bf27d Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Thu, 17 Oct 
2024 15:29:21 -0500 Subject: [PATCH 12/24] drop .h from conversion --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 9b554fbe3a..1acdb76760 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,6 @@ def get_extensions(): extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin") hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) - hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.h"), recursive=True)) if not IS_ROCM and use_cuda: sources += cuda_sources From c9bc1bcac1887327d902a1b6b01b00ed8f9e50e9 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Thu, 17 Oct 2024 13:50:48 -0700 Subject: [PATCH 13/24] cp_asyc4_pred_zfill() AMD implementation --- torchao/csrc/cuda/sparse_marlin/mem.h | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 59d5af38e7..5d78ee42a3 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -27,6 +27,16 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, const bool zfill = false) { const int BYTES = 16; int src_in_bytes = (zfill ? 0 : BYTES); + #ifdef USE_ROCM + uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" @@ -35,6 +45,7 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" ::"r"((int)pred), "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + #endif } __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, @@ -133,4 +144,4 @@ __device__ inline void barrier_release(int* lock, bool reset = false) { : "l"(lock), "r"(val)); } } -} // namespace torchao \ No newline at end of file +} // namespace torchao From 16feff449be6308ba82b0e37589fb828ea56959e Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Thu, 17 Oct 2024 19:39:04 -0700 Subject: [PATCH 14/24] implement matching mem utility with amd GCN isa --- torchao/csrc/cuda/sparse_marlin/mem.h | 78 ++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 5d78ee42a3..0a3f980f44 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -51,6 +51,16 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; + #ifdef USE_ROCM + uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr)); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" @@ -59,70 +69,125 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" ::"r"((int)pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)); + #endif } // Asynchronous global->shared copy __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; + #ifdef USE_ROCM + uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + asm volatile( + "{\n" + " ds_read_b128 %0, %1 offset:0;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr)); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); + #endif } // Async copy fence. __device__ inline void cp_async_fence() { +#ifdef USE_ROCM + __builtin_amdgcn_s_waitcnt(0); +#else asm volatile("cp.async.commit_group;\n" ::); +#endif } // Wait until at most `n` async copy stages are still pending. template __device__ inline void cp_async_wait() { +#ifdef USE_ROCM + // For AMD GPUs, we use s_waitcnt + // This waits for all outstanding memory operations to complete + __builtin_amdgcn_s_waitcnt(0); +#else + // For NVIDIA GPUs, use the original instruction asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. 
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + #ifdef USE_ROCM + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + asm volatile( + "ds_read_b128 %0, %1 offset:0\n" + "ds_read_b128 %2, %1 offset:16\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + : "v"(smem)); + #else uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); + #endif } __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_m); + #ifdef USE_ROCM + uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + asm volatile( + "ds_read_b64 %0, %2 offset:0\n" + : "=v"(a[0]), "=v"(a[1]) + : "v"(smem)); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + #endif } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); + #ifdef USE_ROCM + uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + asm volatile( + "ds_read_b128 %0, %4 offset:0\n" + "ds_read_b128 %2, %4 offset:16\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + : "v"(smem)); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); + #endif } // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; - do + do { // Guarantee that subsequent writes by this threadblock will be visible // globally. + #ifdef USE_ROCM + asm volatile("flat_load_dword %0, %1 glc\n\t" + "s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t" + : "=v"(state) + : "v"(lock)); + #else asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - while (state != count); + #endif + } while (state != count); } __syncthreads(); } @@ -138,10 +203,19 @@ __device__ inline void barrier_release(int* lock, bool reset = false) { int val = 1; // Make sure that all writes since acquiring this barrier are visible // globally, while releasing the barrier. + #ifdef USE_ROCM + asm volatile("s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t" + "s_memrealtime\n\t" + "s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t" + "flat_atomic_add_i32 %0, %1\n\t" + : "+v"(*lock) + : "v"(val)); + #else asm volatile("fence.acq_rel.gpu;\n"); asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + #endif } } } // namespace torchao From 0b215558f051caac09af346445bfc0abe19880a7 Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Thu, 17 Oct 2024 19:50:30 -0700 Subject: [PATCH 15/24] implement mma util with amd gcn isa --- torchao/csrc/cuda/sparse_marlin/mma.h | 73 ++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index dde6938d83..b8da31870b 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -27,7 +27,11 @@ namespace torchao { // | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction // | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially // | reduced performance on some future architectures -#if defined CUDA_VERSION && CUDA_VERSION >= 12050 + +#if defined(USE_ROCM) + // HIP ISA doesn't have an equivalent for ordered_metadata, so we'll use the standard mma instruction + #define MMA_SP_INST "v_mfma_f32_16x16x16f16 " +#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12050 #define MMA_SP_INST \ "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " #else @@ -84,15 +88,28 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, template __device__ inline int lop3(int a, int b, int c) { int res; + #ifdef USE_ROCM + // AMD GPUs don't have a direct equivalent to lop3, so we implement it using bitwise operations + res = (a & b & c) | (a & b & ~c) | (a & ~b & c) | (~a & b & c); + // Apply the LUT + res = (res & lut) | (~res & ~lut); + #else asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + #endif return res; } __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, float c3) { uint2 r; + #ifdef USE_ROCM + // AMD implementation + r.x = __builtin_amdgcn_cvt_pkrtz(c0, c1); + r.y = __builtin_amdgcn_cvt_pkrtz(c2, c3); + #else + // NVIDIA implementation asm("{\n\t" ".reg .f16 a, b, c, d; \n\t" "cvt.rn.f16.f32 a, %2; \n\t" @@ -104,6 +121,7 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, "}" : "=r"(r.x), "=r"(r.y) : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); + #endif return r; } @@ -112,9 +130,16 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; + #ifdef USE_ROCM + // AMD implementation + res = ((a & 0xFF) << 24) | ((a & 0xFF00) << 8) | ((a & 0xFF0000) >> 8) | ((a & 0xFF000000) >> 24); + res = (res >> (start_byte * 8)) & mask; + #else + // NVIDIA implementation asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + #endif return res; } @@ -136,11 +161,24 @@ __device__ inline FragB dequant_4bit(int q) { const int ADD = 0xd480d480; FragB frag_b; + #ifdef USE_ROCM + // AMD implementation + __half2* lo_ptr = reinterpret_cast<__half2*>(&lo); + __half2* hi_ptr = reinterpret_cast<__half2*>(&hi); + const __half2* SUB_ptr = reinterpret_cast(&SUB); + const __half2* MUL_ptr = reinterpret_cast(&MUL); + const __half2* ADD_ptr = reinterpret_cast(&ADD); + + frag_b[0] = __hsub(*lo_ptr, *SUB_ptr); + frag_b[1] = __hfma(*hi_ptr, *MUL_ptr, *ADD_ptr); + #else + // NVIDIA implementation frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + #endif return frag_b; } @@ -159,24 +197,56 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; FragB frag_b; + #ifdef USE_ROCM + // AMD implementation + __half2* lo_ptr = 
reinterpret_cast<__half2*>(&lo); + __half2* hi_ptr = reinterpret_cast<__half2*>(&hi); + const __half2* magic_num_ptr = reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM); + + frag_b[0] = __hsub(*lo_ptr, *magic_num_ptr); + frag_b[1] = __hsub(*hi_ptr, *magic_num_ptr); + #else + // NVIDIA implementation frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); frag_b[1] = __hsub2(*reinterpret_cast(&hi), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + #endif return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + #ifdef USE_ROCM + // AMD implementation + __half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul(frag_b[0], s); + frag_b[1] = __hmul(frag_b[1], s); + #else + // NVIDIA implementation half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); + #endif } __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, FragS& s0, float* c4, float* c5, float* c6, float* c7, FragS& s1) { + #ifdef USE_ROCM + // AMD implementation + *c0 = __builtin_amdgcn_fmul_legacy(*c0, __half2float(s0[0].x)); + *c1 = __builtin_amdgcn_fmul_legacy(*c1, __half2float(s0[0].y)); + *c2 = __builtin_amdgcn_fmul_legacy(*c2, __half2float(s0[1].x)); + *c3 = __builtin_amdgcn_fmul_legacy(*c3, __half2float(s0[1].y)); + + *c4 = __builtin_amdgcn_fmul_legacy(*c4, __half2float(s1[0].x)); + *c5 = __builtin_amdgcn_fmul_legacy(*c5, __half2float(s1[0].y)); + *c6 = __builtin_amdgcn_fmul_legacy(*c6, __half2float(s1[1].x)); + *c7 = __builtin_amdgcn_fmul_legacy(*c7, __half2float(s1[1].y)); + #else + // NVIDIA implementation *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); @@ -186,6 +256,7 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); + #endif } } // namespace torchao \ No newline at end of file From f23b194b99a52db40eb3164de82b664f810d47ae Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Thu, 17 Oct 2024 20:00:24 -0700 Subject: [PATCH 16/24] enable rocm path --- torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu index 380d69130c..eafa8df05b 100644 --- a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu +++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu @@ -52,7 +52,7 @@ static constexpr int min_thread_n = 128; static constexpr int tile_size = 16; static constexpr int max_par = 64; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && !defined(USE_ROCM) template Date: Tue, 22 Oct 2024 13:30:45 +0000 Subject: [PATCH 17/24] update copy from global to lds --- torchao/csrc/cuda/sparse_marlin/mem.h | 62 +++++++++++++++++---------- torchao/csrc/cuda/sparse_marlin/mma.h | 5 ++- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 0a3f980f44..54f38fb358 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -19,6 +19,17 @@ #include "base.h" namespace torchao { + +#ifdef USE_ROCM +#include + +// utility function for ROCm for equivalent for cvta_to_shared. +template +__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) { + return (uint32_t)(uint64_t)(ptr); +} +#endif + // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. __device__ inline void cp_async4_pred_zfill(void* smem_ptr, @@ -28,14 +39,16 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, const int BYTES = 16; int src_in_bytes = (zfill ? 
0 : BYTES); #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + //asm volatile( + // "{\n" + // " .reg .pred p;\n" + // " setp.ne.b32 p, %0, 0;\n" + // " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent + // "}\n" ::"r"((int)pred), + // "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + uint32_t smem = cvta_to_shared(smem_ptr); + __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -52,14 +65,16 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr)); + //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + //asm volatile( + // "{\n" + // " .reg .pred p;\n" + // " setp.ne.b32 p, %0, 0;\n" + // " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent + // "}\n" ::"r"((int)pred), + // "r"(smem), "l"(glob_ptr)); + uint32_t smem = cvta_to_shared(smem_ptr); + __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -76,12 +91,15 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - asm volatile( - "{\n" - " ds_read_b128 %0, %1 offset:0;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr)); + //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + //asm volatile( + // "{\n" + // " ds_read_b128 %0, %1 offset:0;\n" + // "}\n" ::"r"(smem), + // "l"(glob_ptr)); + uint32_t smem = cvta_to_shared(smem_ptr); + __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); + #else uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index b8da31870b..9e9a9be519 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -17,7 +17,10 @@ #pragma once #include "base.h" + +#ifndef USE_ROCM #include +#endif namespace torchao { @@ -259,4 +262,4 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, #endif } -} // namespace torchao \ No newline at end of file +} // namespace torchao From a80730b3ec01cb1c4fc572d441cca4e34199dad0 Mon Sep 17 00:00:00 2001 From: "Peter Y. 
Yeh" Date: Wed, 23 Oct 2024 15:27:53 -0700 Subject: [PATCH 18/24] implement cvta_to_shared() --- torchao/csrc/cuda/sparse_marlin/mem.h | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 54f38fb358..06bf45eb8d 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -23,10 +23,21 @@ namespace torchao { #ifdef USE_ROCM #include -// utility function for ROCm for equivalent for cvta_to_shared. +// Convert generic pointer to shared memory address for ROCm template -__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) { - return (uint32_t)(uint64_t)(ptr); +__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) { + // First get the address as a size_t to handle all pointer sizes + size_t addr = reinterpret_cast(ptr); + + // Extract the lower 32 bits which represent the shared memory offset + // This is safe because shared memory addresses are always within 32-bit range + return static_cast(addr & 0xFFFFFFFF); +} +#else +// For CUDA, use the native intrinsic +template +__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) { + return static_cast(__cvta_generic_to_shared(ptr)); } #endif From d2c7ce4028d1f2301cc7efb524f9683cc46d699a Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 23 Oct 2024 15:35:09 -0700 Subject: [PATCH 19/24] consolidate code with cvta_to_shared() --- torchao/csrc/cuda/sparse_marlin/mem.h | 48 +++++---------------------- 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 06bf45eb8d..1569e3cdda 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -49,19 +49,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, const bool zfill = false) { const int BYTES = 16; int src_in_bytes = (zfill ? 
0 : BYTES); - #ifdef USE_ROCM - //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - //asm volatile( - // "{\n" - // " .reg .pred p;\n" - // " setp.ne.b32 p, %0, 0;\n" - // " @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent - // "}\n" ::"r"((int)pred), - // "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); uint32_t smem = cvta_to_shared(smem_ptr); + #ifdef USE_ROCM __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); #else - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " .reg .pred p;\n" @@ -75,19 +66,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; - #ifdef USE_ROCM - //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - //asm volatile( - // "{\n" - // " .reg .pred p;\n" - // " setp.ne.b32 p, %0, 0;\n" - // " @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent - // "}\n" ::"r"((int)pred), - // "r"(smem), "l"(glob_ptr)); uint32_t smem = cvta_to_shared(smem_ptr); + #ifdef USE_ROCM __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); #else - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " .reg .pred p;\n" @@ -101,18 +83,10 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, // Asynchronous global->shared copy __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; - #ifdef USE_ROCM - //uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); - //asm volatile( - // "{\n" - // " ds_read_b128 %0, %1 offset:0;\n" - // "}\n" ::"r"(smem), - // "l"(glob_ptr)); uint32_t smem = cvta_to_shared(smem_ptr); + #ifdef USE_ROCM __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); - #else - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " cp.async.cg.shared.global [%0], [%1], %2;\n" @@ -146,17 +120,15 @@ __device__ inline void cp_async_wait() { // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. 
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - #ifdef USE_ROCM uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); + uint32_t smem = cvta_to_shared(smem_ptr); + #ifdef USE_ROCM asm volatile( "ds_read_b128 %0, %1 offset:0\n" "ds_read_b128 %2, %1 offset:16\n" : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) : "v"(smem)); #else - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); @@ -165,14 +137,13 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_m); + uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); asm volatile( "ds_read_b64 %0, %2 offset:0\n" : "=v"(a[0]), "=v"(a[1]) : "v"(smem)); #else - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); @@ -183,15 +154,14 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { // memory, directly in tensor core layout. __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - uint32_t smem = static_cast(__builtin_amdgcn_s_getpc()); asm volatile( - "ds_read_b128 %0, %4 offset:0\n" - "ds_read_b128 %2, %4 offset:16\n" + "ds_read_b128 %0, %1 offset:0\n" + "ds_read_b128 %2, %1 offset:16\n" : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) : "v"(smem)); #else - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) From a4e8c300fff7e5a7a915d0ddf03660dae7f3b36a Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 8 Jan 2025 17:03:43 -0600 Subject: [PATCH 20/24] lint --- setup.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 542624386a..7f85adae3e 100644 --- a/setup.py +++ b/setup.py @@ -52,8 +52,8 @@ def read_version(file_path="version.txt"): import torch from torch.utils.cpp_extension import ( CUDA_HOME, - ROCM_HOME, IS_WINDOWS, + ROCM_HOME, BuildExtension, CppExtension, CUDAExtension, @@ -61,18 +61,27 @@ def read_version(file_path="version.txt"): IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) + def get_extensions(): debug_mode = os.getenv("DEBUG", "0") == "1" if debug_mode: print("Compiling in debug mode") if not torch.cuda.is_available(): - print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions") + print( + "PyTorch GPU support is not available. Skipping compilation of CUDA extensions" + ) if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available(): - print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions") - print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit") + print( + "CUDA toolkit or ROCm is not available. 
Skipping compilation of CUDA extensions" + ) + print( + "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit" + ) - use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None) + use_cuda = torch.cuda.is_available() and ( + CUDA_HOME is not None or ROCM_HOME is not None + ) extension = CUDAExtension if use_cuda else CppExtension extra_link_args = [] @@ -125,8 +134,12 @@ def get_extensions(): glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) ) - extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin") - hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)) + extensions_hip_dir = os.path.join( + extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin" + ) + hip_sources = list( + glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) + ) if not IS_ROCM and use_cuda: sources += cuda_sources From c678cb026c810098c60347b05f80c19dbe682ad0 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Thu, 9 Jan 2025 16:06:40 -0600 Subject: [PATCH 21/24] add GPU arch check for MI300x --- setup.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup.py b/setup.py index 7f85adae3e..36810d0a49 100644 --- a/setup.py +++ b/setup.py @@ -146,6 +146,12 @@ def get_extensions(): # TOOD: Remove this and use what CUDA has once we fix all the builds. if IS_ROCM and use_cuda: + # Add ROCm GPU architecture check + gpu_arch = torch.cuda.get_device_properties(0).name + if gpu_arch != 'gfx942': + print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") + print("Currently only gfx942 is supported. Skipping compilation of ROCm extensions") + return None sources += hip_sources if len(sources) == 0: From 08d1cfbe515098dcadfed25eb67a736b1440d910 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Thu, 9 Jan 2025 16:15:11 -0600 Subject: [PATCH 22/24] revert change in tensor_core_tile_layout.cu --- setup.py | 6 +- .../tensor_core_tiled_layout.cu | 61 +------------------ 2 files changed, 7 insertions(+), 60 deletions(-) diff --git a/setup.py b/setup.py index 36810d0a49..fef71dcbdb 100644 --- a/setup.py +++ b/setup.py @@ -148,9 +148,11 @@ def get_extensions(): if IS_ROCM and use_cuda: # Add ROCm GPU architecture check gpu_arch = torch.cuda.get_device_properties(0).name - if gpu_arch != 'gfx942': + if gpu_arch != "gfx942": print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") - print("Currently only gfx942 is supported. Skipping compilation of ROCm extensions") + print( + "Currently only gfx942 is supported. 
Skipping compilation of ROCm extensions" + ) return None sources += hip_sources diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index d1c5d49fda..d3ddd66fe6 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -1,4 +1,4 @@ -#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere #include #include @@ -7,24 +7,13 @@ #include #include -#if defined(USE_ROCM) -#include -#include -#include -#endif - template constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { static_assert(std::is_integral::value && std::is_integral::value, ""); const uint64_t blocks = a / b + (a % b != 0); return blocks; } - -#if defined(USE_ROCM) -constexpr int32_t kWarpSize = 64; -#else constexpr int32_t kWarpSize = 32; -#endif //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 @@ -41,71 +30,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { uint32_t const source_i4s = source; // First, we extract the i4s and construct an intermediate fp16 number. -#if !defined(USE_ROCM) static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; -#endif static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. uint32_t i4s = source_i4s; -// AMD MI300X ISA that performs two bitwise operations in a single instruction: -// v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM -// - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values -// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16 -#if defined(USE_ROCM) - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(h[0]) - : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); -#else asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); -#endif - #pragma unroll for (int ii = 1; ii < kElements / 2; ++ii) { i4s >>= 4; // or is it 8? // (i4s & 0x000f000f) | 0x43004300 -#if defined(USE_ROCM) - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(h[ii]) - : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); -#else asm volatile( "lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[ii]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); -#endif } // This is the BF16 {-136, -136} represented as an integer. -#if defined(USE_ROCM) -#if ROCM_VERSION >= 60200 - auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); - auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); -#else - auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); - auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); -#endif -#else static constexpr uint32_t BF16_BIAS = 0xC308C308; static constexpr uint32_t BF16_ONE = 0x3F803F80; -#endif // Finally, we construct the output numbers. 
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
-#if defined(USE_ROCM)
-    result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
-#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
         : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
-#endif
   }

   return result;
@@ -167,22 +123,11 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
-#if defined(USE_ROCM)
-  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
-  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
-
-  if (scales_and_zeros) {
-    const auto& sz = *scales_and_zeros;
-    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
-
-    scale2 = __bfloat162bfloat162(pSZ[0]);
-    zero2 = __bfloat162bfloat162(pSZ[1]);
-  }
-#else
   const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
+
+  // Vectorize scales and zeros
   __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
   __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
-#endif

 #pragma unroll
   for (int i = 0; i < 4; i++) {

From aea9d81a34871d01d04b1563a1208d7070d307af Mon Sep 17 00:00:00 2001
From: "Peter Y. Yeh"
Date: Wed, 15 Jan 2025 15:09:16 -0800
Subject: [PATCH 23/24] lint refactor for better readibility

---
 setup.py | 57 +++++++++++++++++++++++++++++----------------------------
 1 file changed, 29 insertions(+), 28 deletions(-)

diff --git a/setup.py b/setup.py
index 0f64e6107e..d9b3c7e562 100644
--- a/setup.py
+++ b/setup.py
@@ -74,7 +74,6 @@ def use_debug_mode():
     CUDAExtension,
 )

-
 IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

 # Constant known variables used throughout this file
@@ -258,38 +257,41 @@ def get_extensions():
         ]
     )

+    # Get base directory and source paths
     this_dir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(this_dir, "torchao", "csrc")
-    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

-    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    cuda_sources = list(
-        glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
-    )
-
-    extensions_hip_dir = os.path.join(
-        extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin"
-    )
-    hip_sources = list(
-        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-    )
+    # Collect C++ source files
+    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

-    if not IS_ROCM and use_cuda:
-        sources += cuda_sources
-
-    # TOOD: Remove this and use what CUDA has once we fix all the builds.
-    if IS_ROCM and use_cuda:
-        # Add ROCm GPU architecture check
-        gpu_arch = torch.cuda.get_device_properties(0).name
-        if gpu_arch != "gfx942":
-            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print(
-                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+    # Collect CUDA source files if needed
+    if use_cuda:
+        if not IS_ROCM:
+            # Regular CUDA sources
+            extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+            cuda_sources = list(
+                glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
+            )
+            sources += cuda_sources
+        else:
+            # ROCm sources
+            extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
+            hip_sources = list(
+                glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
             )
-            return None
-        sources += hip_sources

-    if len(sources) == 0:
+            # Check ROCm GPU architecture compatibility
+            gpu_arch = torch.cuda.get_device_properties(0).name
+            if gpu_arch != "gfx942":
+                print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
+                print(
+                    "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+                )
+                return None
+            sources += hip_sources
+
+    # Return None if no sources found
+    if not sources:
         return None

     ext_modules = []
@@ -304,7 +306,6 @@ def get_extensions():
             )
         )

-
     if build_torchao_experimental:
         ext_modules.append(
             CMakeExtension(

From 15e29f1433e411fd91d0c2ee91d98445f9dfe2bb Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 4 Mar 2025 13:51:43 -0800
Subject: [PATCH 24/24] fix setup

---
 setup.py                                    | 39 ++++++++++++++-----------------------
 torchao/_models/llama/bsr_bench_results.txt | 27 +++++++++++++++++++++++++++
 torchao/_models/llama/bsr_benchmarks.sh     | 12 ++++++++++++
 3 files changed, 55 insertions(+), 23 deletions(-)
 create mode 100644 torchao/_models/llama/bsr_bench_results.txt
 create mode 100644 torchao/_models/llama/bsr_benchmarks.sh

diff --git a/setup.py b/setup.py
index ef9ef959be..533afefcb6 100644
--- a/setup.py
+++ b/setup.py
@@ -330,30 +330,23 @@ def get_extensions():

     # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

-    # Collect CUDA source files if needed
-    if use_cuda:
-        if not IS_ROCM:
-            # Regular CUDA sources
-            extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-            cuda_sources = list(
-                glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
-            )
-            sources += cuda_sources
-        else:
-            # ROCm sources
-            # Add sparse marlin support
-            extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
-            hip_sources = list(
-                glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-            )
-            # Add tensor core tiled layout support
-            extensions_hip_dir = os.path.join(
-                extensions_dir, "cuda", "tensor_core_tiled_layout"
-            )
-            hip_sources += list(
-                glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-            )
+    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    cuda_sources = list(
+        glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
+    )
+    extensions_hip_dir = os.path.join(
+        extensions_dir, "cuda", "tensor_core_tiled_layout"
+    )
+    hip_sources = list(
+        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
+    )
+    extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
+    hip_sources += list(
+        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
+    )
+
+    # Collect CUDA source files if needed
     if not IS_ROCM and use_cuda:
         sources += cuda_sources
     else:
diff --git a/torchao/_models/llama/bsr_bench_results.txt b/torchao/_models/llama/bsr_bench_results.txt
new file mode 100644
index 0000000000..172581dedb
--- /dev/null
+++ b/torchao/_models/llama/bsr_bench_results.txt
@@ -0,0 +1,27 @@
+
+20250226151422, tok/s=133.29, tok/s_decode=134.40, ttft=0.0118, mem/s=2000.68 GB/s, peak_mem=16.30 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226151926, tok/s=242.08, tok/s_decode=256.68, ttft=0.0464, mem/s=1182.14 GB/s, peak_mem= 6.74 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226152416, tok/s=252.18, tok/s_decode=267.48, ttft=0.0448, mem/s=1229.49 GB/s, peak_mem= 6.73 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226153215, tok/s=204.19, tok/s_decode=213.86, ttft=0.0438, mem/s=1226.65 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226153628, tok/s=180.14, tok/s_decode=187.54, ttft=0.0433, mem/s=1081.56 GB/s, peak_mem= 8.26 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226160622, tok/s=246.20, tok/s_decode=255.21, ttft=0.0281, mem/s= 956.89 GB/s, peak_mem= 5.56 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226160651, tok/s=145.07, tok/s_decode=163.13, ttft=0.1522, mem/s=1461.87 GB/s, peak_mem=22.76 GB, model_size=10.08 GB quant: None, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+
+20250226161533, tok/s=250.71, tok/s_decode=254.78, ttft=0.0121, mem/s= 974.38 GB/s, peak_mem= 5.56 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226161913, tok/s=251.19, tok/s_decode=254.95, ttft=0.0112, mem/s= 976.26 GB/s, peak_mem= 5.63 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226181326, tok/s=134.44, tok/s_decode=140.82, ttft=0.0669, mem/s= 807.62 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226181520, tok/s=138.03, tok/s_decode=164.08, ttft=0.2295, mem/s=1390.97 GB/s, peak_mem=22.74 GB, model_size=10.08 GB quant: None, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226181738, tok/s=192.65, tok/s_decode=205.62, ttft=0.0649, mem/s=1157.32 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226182045, tok/s=192.75, tok/s_decode=206.24, ttft=0.0673, mem/s=1157.27 GB/s, peak_mem= 8.26 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226182350, tok/s=236.36, tok/s_decode=257.62, ttft=0.0693, mem/s=1154.19 GB/s, peak_mem= 6.74 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226182712, tok/s=231.24, tok/s_decode=250.55, ttft=0.0661, mem/s=1127.37 GB/s, peak_mem= 6.73 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226183255, tok/s=169.58, tok/s_decode=179.82, ttft=0.0665, mem/s=1018.74 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226183527, tok/s=184.74, tok/s_decode=196.38, ttft=0.0637, mem/s=1109.18 GB/s, peak_mem= 8.26 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226183734, tok/s=232.60, tok/s_decode=252.51, ttft=0.0673, mem/s=1135.85 GB/s, peak_mem= 6.74 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250226183953, tok/s=232.47, tok/s_decode=251.15, ttft=0.0635, mem/s=1133.40 GB/s, peak_mem= 6.73 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250227084325, tok/s=200.72, tok/s_decode=210.91, ttft=0.0475, mem/s=1205.82 GB/s, peak_mem= 8.00 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250227084708, tok/s=211.76, tok/s_decode=222.43, ttft=0.0447, mem/s=1271.42 GB/s, peak_mem= 7.99 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250227085051, tok/s=241.09, tok/s_decode=255.19, ttft=0.0452, mem/s=1177.31 GB/s, peak_mem= 6.47 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250227085446, tok/s=247.53, tok/s_decode=262.94, ttft=0.0468, mem/s=1206.80 GB/s, peak_mem= 6.46 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250227090411, tok/s=250.11, tok/s_decode=263.99, ttft=0.0416, mem/s=1219.39 GB/s, peak_mem= 6.46 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20250227091144, tok/s=249.14, tok/s_decode=263.74, ttft=0.0439, mem/s=1214.68 GB/s, peak_mem= 6.46 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
\ No newline at end of file
diff --git a/torchao/_models/llama/bsr_benchmarks.sh b/torchao/_models/llama/bsr_benchmarks.sh
new file mode 100644
index 0000000000..0baa527fef
--- /dev/null
+++ b/torchao/_models/llama/bsr_benchmarks.sh
@@ -0,0 +1,12 @@
+
+# BSR benchmarks
+export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result bsr_bench_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result bsr_bench_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result bsr_bench_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.8-32
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.8-64
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.9-32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.9-64
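
The standalone sketch below is not part of the patch series; it only condenses, under stated assumptions, the source-selection logic that setup.py converges on across these patches: ROCm is detected from torch.version.hip together with ROCM_HOME, the HIP-ready kernels are globbed from the sparse_marlin and tensor_core_tiled_layout directories, only gfx942 is accepted on ROCm, and the NVIDIA path keeps globbing every .cu file under torchao/csrc/cuda. The helper name collect_gpu_sources and the HIP_KERNEL_DIRS constant are illustrative names, not torchao APIs.

# Illustrative sketch only; mirrors the selection logic in the patches above.
# collect_gpu_sources() and HIP_KERNEL_DIRS are hypothetical names, not torchao APIs.
import glob
import os

import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

# Directories whose .cu files are currently built under ROCm (per PATCH 24/24).
HIP_KERNEL_DIRS = ["tensor_core_tiled_layout", "sparse_marlin"]


def collect_gpu_sources(extensions_dir):
    """Return GPU kernel sources to compile, [] for a CPU-only build, or None to skip."""
    use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
    if not use_cuda:
        return []  # CPU-only build: no .cu sources

    cuda_dir = os.path.join(extensions_dir, "cuda")
    if not IS_ROCM:
        # NVIDIA path: every .cu file under torchao/csrc/cuda
        return glob.glob(os.path.join(cuda_dir, "**/*.cu"), recursive=True)

    # ROCm path: only gfx942 is supported by these kernels, mirroring the setup.py check
    gpu_arch = torch.cuda.get_device_properties(0).name
    if gpu_arch != "gfx942":
        print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
        return None

    hip_sources = []
    for subdir in HIP_KERNEL_DIRS:
        hip_sources += glob.glob(os.path.join(cuda_dir, subdir, "*.cu"))
    return hip_sources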