Merged
25 commits
6d92e40
enable build for rocm for fp6_llm
lcskrishna Oct 16, 2024
14b3fce
Merge pull request #1 from lcskrishna/cl/rocm-enablement
petrex Oct 17, 2024
f1a22cf
enable tiled layout extension
lcskrishna Oct 23, 2024
0bef6ca
fix build error related to option
Oct 23, 2024
893ae03
require rocm 6.2
Oct 23, 2024
a0d3788
enable tensor tiled layout extension with successful compilation
lcskrishna Oct 24, 2024
e4e654d
enable successful build
lcskrishna Oct 24, 2024
3e2c6a1
clean-up
Oct 29, 2024
c86880e
Merge pull request #3 from lcskrishna/csrikris_enable_tensor_tile
petrex Oct 29, 2024
91d3c75
fix potential memory access issue
Oct 29, 2024
38b7d1c
fix __nv_bfloat162 init
Nov 12, 2024
279f4b3
add comment for MI300x isa
Nov 12, 2024
612ad14
Merge branch 'main' into rocm_enablement_staging
petrex Nov 18, 2024
bbf5a72
fix build for non-rocm
lcskrishna Jan 6, 2025
735570e
Merge pull request #4 from lcskrishna/rocm_enablement
petrex Jan 6, 2025
253c188
Merge branch 'main' into rocm_enablement_staging
petrex Jan 6, 2025
26fa19c
better naming
Jan 9, 2025
8f02096
lint
Jan 9, 2025
ced9647
resolve merge conflict
Jan 9, 2025
452fa2a
Merge branch 'main' into rocm_enablement_staging
petrex Jan 9, 2025
2ec95f9
Merge branch 'main' into rocm_enablement_staging
petrex Jan 16, 2025
89bae5f
lint
Jan 16, 2025
135c2b2
Merge branch 'main' into rocm_enablement_staging
petrex Feb 25, 2025
4148828
Refactor: Rename `this_dir` to `curdir` in setup.py
Feb 28, 2025
2272a3f
Merge branch 'main' into rocm_enablement_staging
petrex Mar 4, 2025
54 changes: 48 additions & 6 deletions setup.py
@@ -71,11 +71,13 @@ def use_debug_mode():
from torch.utils.cpp_extension import (
CUDA_HOME,
IS_WINDOWS,
ROCM_HOME,
BuildExtension,
CppExtension,
CUDAExtension,
)

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

class BuildOptions:
def __init__(self):
@@ -250,13 +252,18 @@ def get_extensions():
print(
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
)
if CUDA_HOME is None and torch.cuda.is_available():
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")

if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
print(
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
)
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)

use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
use_cuda = torch.cuda.is_available() and (
CUDA_HOME is not None or ROCM_HOME is not None
)
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
@@ -272,7 +279,8 @@ def get_extensions():

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
if "nvcc" in extra_compile_args:
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
else:
extra_compile_args["cxx"].extend(
@@ -284,8 +292,8 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
curdir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(curdir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -314,6 +322,25 @@ def get_extensions():
"-I" + cutlass_extensions_include_dir,
]
)

curdir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(curdir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

extensions_hip_dir = os.path.join(
extensions_dir, "cuda", "tensor_core_tiled_layout"
)
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)

if not IS_ROCM and use_cuda:
sources += cuda_sources
else:
# Remove CUTLASS-based kernels from the cuda_sources list. An
# assumption is that these files will have "cutlass" in its
@@ -325,6 +352,21 @@ def get_extensions():
)
sources = [s for s in sources if s not in cutlass_sources]

# TODO: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
# Add ROCm GPU architecture check
gpu_arch = torch.cuda.get_device_properties(0).name
if gpu_arch != "gfx942":
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
print(
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
)
return None
sources += hip_sources

if len(sources) == 0:
return None

ext_modules = []
if len(sources) > 0:
ext_modules.append(
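For orientation, the build logic this diff adds to setup.py reduces to the sketch below (condensed and simplified; the helper name select_gpu_sources is ours, not the PR's):

import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

def select_gpu_sources(sources, cuda_sources, hip_sources):
    # Mirrors the PR: on ROCm builds torch.cuda.is_available() is still the
    # gate, so the toolkit check accepts either CUDA_HOME or ROCM_HOME.
    is_rocm = (torch.version.hip is not None) and (ROCM_HOME is not None)
    use_cuda = torch.cuda.is_available() and (
        CUDA_HOME is not None or ROCM_HOME is not None
    )
    if use_cuda and not is_rocm:
        return sources + cuda_sources       # NVIDIA: build all CUDA kernels
    if use_cuda and is_rocm:
        gpu_arch = torch.cuda.get_device_properties(0).name
        if gpu_arch != "gfx942":            # only MI300-class (gfx942) is supported
            return None                     # skip GPU extensions entirely
        return sources + hip_sources        # ROCm: only the tiled-layout kernels
    return sources                          # CPU-only build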
@@ -1,4 +1,4 @@
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
@@ -7,13 +7,24 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
const uint64_t blocks = a / b + (a % b != 0);
return blocks;
}

#if defined(USE_ROCM)
constexpr int32_t kWarpSize = 64;
#else
constexpr int32_t kWarpSize = 32;
#endif

//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -30,38 +41,71 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
uint32_t const source_i4s = source;

// First, we extract the i4s and construct an intermediate fp16 number.
#if !defined(USE_ROCM)
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
#endif
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

// We don't have enough mantissa to remove as much shift overhead as FP16, so
// we must loop. No shift needed for first item.
uint32_t i4s = source_i4s;
// v_and_or_b32 is an AMD (MI300X) ISA instruction that performs two bitwise
// operations in a single instruction:
//   H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
// - First ANDs `i4s` with `MASK` (0x000f000f) to extract the 4-bit values
// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert
//   them to bfloat16
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[0])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif

#pragma unroll
for (int ii = 1; ii < kElements / 2; ++ii) {
i4s >>= 4; // or is it 8?
// (i4s & 0x000f000f) | 0x43004300
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[ii])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif
}

// This is the BF16 {-136, -136} represented as an integer.
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
#else
auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16{0xC308});
auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
#endif
#else
static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308;
static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80;
#endif

// Finally, we construct the output numbers.
#pragma unroll
for (int ii = 0; ii < kElements / 2; ++ii) {
// Since this section is for Ampere+, we use bf16 fma to do the bias
// subtraction
#if defined(USE_ROCM)
result.vals[ii] = __hfma2(result.vals[ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR);
#else
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
: "=r"(h[ii])
: "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
: "r"(h[ii]), "r"(BF16_UNIT_VALUE), "r"(BF16_SCALE_FACTOR));
#endif
}

return result;
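A note on the magic numbers above: 0x4300 is the bfloat16 bit pattern for 2^7 = 128, so ORing a 4-bit nibble x into the low mantissa bits yields exactly 128 + x; the subsequent fma with 1.0 (0x3F80) and -136 (0xC308) then gives (128 + x) - 136 = x - 8, mapping the unsigned nibble onto the signed int4 range [-8, 7]. A quick host-side check in plain Python (names are ours, not the kernel's):

import struct

def bf16_to_float(bits):
    # bfloat16 is the high half of an IEEE-754 float32
    return struct.unpack("!f", struct.pack("!I", bits << 16))[0]

for x in range(16):
    h = (x & 0x000F) | 0x4300                        # (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
    assert bf16_to_float(h) == 128 + x               # nibble planted in the mantissa of 2^7
    assert bf16_to_float(h) * 1.0 - 136.0 == x - 8   # the bf16 fma above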
@@ -123,11 +167,22 @@ __global__ void _dequantize_int4_kernel(
// All b values within a 16x16 tile should fall within the same q group
// Hence we load 1 scale and zero per loop
int qgroup = ks[0] / groupSize;
#if defined(USE_ROCM)
__nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
__nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));

if (scales_and_zeros) {
const auto& sz = *scales_and_zeros;
const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);

scale2 = __bfloat162bfloat162(pSZ[0]);
zero2 = __bfloat162bfloat162(pSZ[1]);
}
#else
const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);

// Vectorize scales and zeros
__nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
__nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
#endif

#pragma unroll
for (int i = 0; i < 4; i++) {
Expand Down
Loading
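The truncated loop above applies scale2/zero2 with a bf16 fma per element. In NumPy terms, the end-to-end dequantization is roughly the following (a sketch assuming the usual (q - 8) * scale + zero affine form; names and layout are ours, and the real kernel works on bf16 with a [n_groups][N][2] scales_and_zeros tensor):

import numpy as np

def dequantize_int4_groupwise(q, scales, zeros, group_size):
    # q: uint4 codes (0..15) of shape [K]; one (scale, zero) per group of K
    signed = q.astype(np.float32) - 8.0   # same mapping as convert_i4x8_to_bf16x2x4
    g = np.arange(q.size) // group_size   # group index of every element
    return signed * scales[g] + zeros[g]  # fma(value, scale2, zero2)

# Example: 8 codes, group_size=4, one (scale, zero) per group
q = np.array([0, 15, 8, 7, 1, 2, 3, 4], dtype=np.uint8)
print(dequantize_int4_groupwise(q, np.array([0.5, 2.0]), np.array([0.0, 1.0]), 4))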