1- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
1+ #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
22
33#include < ATen/ATen.h>
44#include < ATen/core/Tensor.h>
77#include < c10/cuda/CUDAGuard.h>
88#include < torch/library.h>
99
10+ #if defined(USE_ROCM)
11+ #include < hip/hip_bf16.h>
12+ #include < hip/hip_fp16.h>
13+ #include < hip/hip_runtime.h>
14+ #endif
15+
1016template <typename U, typename V>
1117constexpr __host__ __device__ auto divUp (U a, V b) -> decltype(a + b) {
1218 static_assert (std::is_integral<U>::value && std::is_integral<V>::value, " " );
1319 const uint64_t blocks = a / b + (a % b != 0 );
1420 return blocks;
1521}
22+
23+ #if defined(USE_ROCM)
24+ constexpr int32_t kWarpSize = 64 ;
25+ #else
1626constexpr int32_t kWarpSize = 32 ;
27+ #endif
1728
1829// Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
1930// https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -30,38 +41,71 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
3041 uint32_t const source_i4s = source;
3142
3243 // First, we extract the i4s and construct an intermediate fp16 number.
44+ #if !defined(USE_ROCM)
3345 static constexpr uint32_t immLut = (0xf0 & 0xcc ) | 0xaa ;
46+ #endif
3447 static constexpr uint32_t MASK = 0x000f000f ;
3548 static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300 ;
3649
3750 // We don't have enough mantissa to remove as much shift overhead as FP16, so
3851 // we must loop. No shift needed for first item.
3952 uint32_t i4s = source_i4s;
53+ // AMD MI300X ISA that performs two bitwise operations in a single instruction:
54+ // v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
55+ // - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values
56+ // - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
57+ #if defined(USE_ROCM)
58+ asm volatile (" v_and_or_b32 %0, %1, %2, %3"
59+ : " =v" (h[0 ])
60+ : " v" (i4s), " v" (MASK), " v" (I4s_TO_BF16s_MAGIC_NUM));
61+ #else
4062 asm volatile (" lop3.b32 %0, %1, %2, %3, %4;\n "
4163 : " =r" (h[0 ])
4264 : " r" (i4s), " n" (MASK), " n" (I4s_TO_BF16s_MAGIC_NUM), " n" (immLut));
65+ #endif
66+
4367#pragma unroll
4468 for (int ii = 1 ; ii < kElements / 2 ; ++ii) {
4569 i4s >>= 4 ; // or is it 8?
4670 // (i4s & 0x000f000f) | 0x43004300
71+ #if defined(USE_ROCM)
72+ asm volatile (" v_and_or_b32 %0, %1, %2, %3"
73+ : " =v" (h[ii])
74+ : " v" (i4s), " v" (MASK), " v" (I4s_TO_BF16s_MAGIC_NUM));
75+ #else
4776 asm volatile (
4877 " lop3.b32 %0, %1, %2, %3, %4;\n "
4978 : " =r" (h[ii])
5079 : " r" (i4s), " n" (MASK), " n" (I4s_TO_BF16s_MAGIC_NUM), " n" (immLut));
80+ #endif
5181 }
5282
5383 // This is the BF16 {-136, -136} represented as an integer.
54- static constexpr uint32_t BF16_BIAS = 0xC308C308 ;
55- static constexpr uint32_t BF16_ONE = 0x3F803F80 ;
84+ #if defined(USE_ROCM)
85+ #if ROCM_VERSION >= 60200
86+ auto BF16_SCALE_FACTOR = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0xC308 }));
87+ auto BF16_UNIT_VALUE = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0x3F80 }));
88+ #else
89+ auto BF16_SCALE_FACTOR = __bfloat162bfloat162 (__hip_bfloat16{0xC308 });
90+ auto BF16_UNIT_VALUE = __bfloat162bfloat162 (__hip_bfloat16{0x3F80 });
91+ #endif
92+ #else
93+ static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308 ;
94+ static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80 ;
95+ #endif
5696
5797// Finally, we construct the output numbers.
5898#pragma unroll
5999 for (int ii = 0 ; ii < kElements / 2 ; ++ii) {
60100 // Since this section is for Ampere+, we use bf16 fma to do the bias
61101 // subtraction
102+ #if defined(USE_ROCM)
103+ result.vals [ii] = __hfma2 (result.vals [ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR);
104+ #else
62105 asm (" fma.rn.bf16x2 %0, %1, %2, %3;\n "
63106 : " =r" (h[ii])
64- : " r" (h[ii]), " r" (BF16_ONE), " r" (BF16_BIAS));
107+ : " r" (h[ii]), " r" (BF16_UNIT_VALUE), " r" (BF16_SCALE_FACTOR));
108+ #endif
65109 }
66110
67111 return result;
@@ -123,11 +167,22 @@ __global__ void _dequantize_int4_kernel(
123167 // All b values within a 16x16 tile should fall within the same q group
124168 // Hence we load 1 scale and zero per loop
125169 int qgroup = ks[0 ] / groupSize;
170+ #if defined(USE_ROCM)
171+ __nv_bfloat162 scale2 = __bfloat162bfloat162 (__hip_bfloat16 (1 .0f ));
172+ __nv_bfloat162 zero2 = __bfloat162bfloat162 (__hip_bfloat16 (1 .0f ));
173+
174+ if (scales_and_zeros) {
175+ const auto & sz = *scales_and_zeros;
176+ const __nv_bfloat16* pSZ = reinterpret_cast <const __nv_bfloat16*>(&sz[qgroup][n0][0 ]);
177+
178+ scale2 = __bfloat162bfloat162 (pSZ[0 ]);
179+ zero2 = __bfloat162bfloat162 (pSZ[1 ]);
180+ }
181+ #else
126182 const __nv_bfloat16 *pSZ = reinterpret_cast <const __nv_bfloat16*>(&scales_and_zeros.value ()[qgroup][n0][0 ]);
127-
128- // Vectorize scales and zeros
129183 __nv_bfloat162 scale2 = __bfloat162bfloat162 (pSZ[0 ]);
130184 __nv_bfloat162 zero2 = __bfloat162bfloat162 (pSZ[1 ]);
185+ #endif
131186
132187 #pragma unroll
133188 for (int i = 0 ; i < 4 ; i++) {
0 commit comments