diff --git a/setup.py b/setup.py
index b16f78eb35..6a3e3678ff 100644
--- a/setup.py
+++ b/setup.py
@@ -71,11 +71,13 @@ def use_debug_mode():
 from torch.utils.cpp_extension import (
     CUDA_HOME,
     IS_WINDOWS,
+    ROCM_HOME,
     BuildExtension,
     CppExtension,
     CUDAExtension,
 )
 
+IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
 
 class BuildOptions:
     def __init__(self):
@@ -250,13 +252,18 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
-    if CUDA_HOME is None and torch.cuda.is_available():
-        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
+
+    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
+        print(
+            "Neither CUDA toolkit nor ROCm is available. Skipping compilation of CUDA extensions"
+        )
         print(
             "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
         )
 
-    use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
+    use_cuda = torch.cuda.is_available() and (
+        CUDA_HOME is not None or ROCM_HOME is not None
+    )
     extension = CUDAExtension if use_cuda else CppExtension
 
     extra_link_args = []
@@ -272,7 +279,8 @@
         if debug_mode:
             extra_compile_args["cxx"].append("-g")
-            extra_compile_args["nvcc"].append("-g")
+            if "nvcc" in extra_compile_args:
+                extra_compile_args["nvcc"].append("-g")
             extra_link_args.extend(["-O0", "-g"])
 
     else:
         extra_compile_args["cxx"].extend(
@@ -284,8 +292,8 @@
             extra_compile_args["nvcc"].append("-g")
             extra_link_args.append("/DEBUG")
 
-    this_dir = os.path.dirname(os.path.curdir)
-    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
+    curdir = os.path.dirname(os.path.curdir)
+    extensions_dir = os.path.join(curdir, "torchao", "csrc")
 
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -314,6 +322,25 @@
                 "-I" + cutlass_extensions_include_dir,
             ]
         )
+
+        curdir = os.path.dirname(os.path.curdir)
+        extensions_dir = os.path.join(curdir, "torchao", "csrc")
+        sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
+
+        extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+        cuda_sources = list(
+            glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
+        )
+
+        extensions_hip_dir = os.path.join(
+            extensions_dir, "cuda", "tensor_core_tiled_layout"
+        )
+        hip_sources = list(
+            glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
+        )
+
+        if not IS_ROCM and use_cuda:
+            sources += cuda_sources
     else:
         # Remove CUTLASS-based kernels from the cuda_sources list. An
         # assumption is that these files will have "cutlass" in its
         # names.
@@ -325,6 +352,21 @@
             )
         sources = [s for s in sources if s not in cutlass_sources]
 
+    # TODO: Remove this and use what CUDA has once we fix all the builds.
+    if IS_ROCM and use_cuda:
+        # Add ROCm GPU architecture check
+        gpu_arch = torch.cuda.get_device_properties(0).name
+        if gpu_arch != "gfx942":
+            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
+            print(
+                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+            )
+            return None
+        sources += hip_sources
+
+    if len(sources) == 0:
+        return None
+
     ext_modules = []
     if len(sources) > 0:
         ext_modules.append(
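Reviewer note (not part of the patch): the setup.py change reduces to two flags. Below is a minimal sketch of how they evaluate, assuming a standard PyTorch install where `torch.version.hip` is set only on ROCm builds and `torch.utils.cpp_extension` exposes both `CUDA_HOME` and `ROCM_HOME`:

```python
# Sketch only: how the flags introduced in setup.py evaluate on a given machine.
import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

# On a ROCm build of PyTorch, torch.version.hip is a version string and
# ROCM_HOME points at the ROCm install; on a CUDA build it is None.
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

# GPU extensions are built when a device is visible and either toolkit exists.
use_cuda = torch.cuda.is_available() and (
    CUDA_HOME is not None or ROCM_HOME is not None
)

print(f"IS_ROCM={IS_ROCM}, use_cuda={use_cuda}")
if IS_ROCM and torch.cuda.is_available():
    # The gfx942 gate in the patch reads the device name via the
    # HIP-aliased torch.cuda API; print it to see what your GPU reports.
    print(torch.cuda.get_device_properties(0).name)
```

Note that the architecture gate accepts only gfx942 (MI300 class) for now; the TODO in the patch marks this as temporary.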
diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
index ea0f24c202..1fc96f60ec 100644
--- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
+++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
@@ -1,4 +1,4 @@
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere, or ROCm 6.2+
 
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
@@ -7,13 +7,24 @@
 #include <torch/library.h>
 #include <torch/types.h>
 
+#if defined(USE_ROCM)
+#include <hip/hip_runtime.h>
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#endif
+
 template <typename U, typename V>
 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
   const uint64_t blocks = a / b + (a % b != 0);
   return blocks;
 }
 
+#if defined(USE_ROCM)
+constexpr int32_t kWarpSize = 64;
+#else
 constexpr int32_t kWarpSize = 32;
+#endif
 
 //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
 //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -30,38 +41,71 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
   uint32_t const source_i4s = source;
 
   // First, we extract the i4s and construct an intermediate fp16 number.
+#if !defined(USE_ROCM)
   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+#endif
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
 
   // We don't have enough mantissa to remove as much shift overhead as FP16, so
   // we must loop. No shift needed for first item.
   uint32_t i4s = source_i4s;
+
+// On AMD (MI300X), v_and_or_b32 fuses two bitwise operations into a single
+// instruction: h[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
+// - First ANDs `i4s` with `MASK` (0x000f000f) to extract the 4-bit values
+// - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to position them as bfloat16s
+#if defined(USE_ROCM)
+  asm volatile("v_and_or_b32 %0, %1, %2, %3"
+               : "=v"(h[0])
+               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
              : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
+
 #pragma unroll
   for (int ii = 1; ii < kElements / 2; ++ii) {
     i4s >>= 4; // or is it 8?
     // (i4s & 0x000f000f) | 0x43004300
+#if defined(USE_ROCM)
+    asm volatile("v_and_or_b32 %0, %1, %2, %3"
+                 : "=v"(h[ii])
+                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
     asm volatile(
        "lop3.b32 %0, %1, %2, %3, %4;\n"
         : "=r"(h[ii])
        : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
   }
 
   // This is the BF16 {-136, -136} represented as an integer.
-  static constexpr uint32_t BF16_BIAS = 0xC308C308;
-  static constexpr uint32_t BF16_ONE = 0x3F803F80;
+#if defined(USE_ROCM)
+#if ROCM_VERSION >= 60200
+  auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
+  auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
+#else
+  auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16{0xC308});
+  auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
+#endif
+#else
+  static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308;
+  static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80;
+#endif
 
   // Finally, we construct the output numbers.
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
+#if defined(USE_ROCM)
+    result.vals[ii] = __hfma2(result.vals[ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR);
+#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
-        : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
+        : "r"(h[ii]), "r"(BF16_UNIT_VALUE), "r"(BF16_SCALE_FACTOR));
+#endif
   }
 
   return result;
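The `lop3.b32` and `v_and_or_b32` paths above compute the same thing: `h = (i4s & 0x000f000f) | 0x43004300`, followed by an fma against {-136, -136}. A self-contained numeric check of why that recovers a signed int4 (illustration only, not part of the patch):

```python
# Why (v & 0xF) | 0x4300 followed by fma(x, 1.0, -136) recovers a signed int4.
import struct

def bf16_bits_to_float(bits: int) -> float:
    # bf16 is the top 16 bits of an IEEE-754 float32.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

MAGIC = 0x4300  # bf16 with exponent 2^7 and an empty mantissa, i.e. 128.0
BIAS = 0xC308   # bf16 for -136.0 (one half of BF16_SCALE_FACTOR above)

assert bf16_bits_to_float(MAGIC) == 128.0
assert bf16_bits_to_float(BIAS) == -136.0
for v in range(16):  # every unsigned 4-bit code
    x = bf16_bits_to_float(MAGIC | v)   # v lands in the mantissa: 2^7 * (1 + v/128)
    assert x == 128.0 + v
    assert x * 1.0 - 136.0 == v - 8     # the fma recenters to the signed range [-8, 7]
print("magic-number int4 -> bf16 conversion checks out")
```

The kWarpSize change above comes from the same porting exercise: CDNA GPUs execute 64-lane wavefronts, so lane arithmetic keyed to a 32-wide warp must be parameterized.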
@@ -123,11 +167,22 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
+#if defined(USE_ROCM)
+  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
+  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
+
+  if (scales_and_zeros) {
+    const auto& sz = *scales_and_zeros;
+    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
+
+    scale2 = __bfloat162bfloat162(pSZ[0]);
+    zero2 = __bfloat162bfloat162(pSZ[1]);
+  }
+#else
   const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
-
-  // Vectorize scales and zeros
   __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
   __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+#endif
 
 #pragma unroll
   for (int i = 0; i < 4; i++) {
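Taken together, the two kernel hunks implement, per 4-bit code q: unpack to q - 8, then one fma with the group's scale and zero. A hedged sketch of that arithmetic, assuming the tinygemm scales_and_zeros layout implied above (scale at index 0, zero at index 1 per quantization group); the helper name is hypothetical:

```python
# Hypothetical helper, illustration only: per-element dequantization assuming
# w = (q - 8) * scale + zero, with scales_and_zeros[group][n][0] = scale and
# scales_and_zeros[group][n][1] = zero.
def dequantize_int4(q: int, scale: float, zero: float) -> float:
    assert 0 <= q <= 15                 # unsigned 4-bit code
    centered = q - 8                    # convert_i4x8_to_bf16x2x4: (128 + q) - 136
    return centered * scale + zero      # the kernel's __hfma2 with scale2 / zero2

# Example: with group scale 0.05 and zero 0.1, codes 0, 8, 15
# map to -0.3, 0.1, 0.45.
print([round(dequantize_int4(q, 0.05, 0.1), 3) for q in (0, 8, 15)])
```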