-#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
 
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
-#if defined(USE_ROCM)
-#include <hip/hip_bf16.h>
-#include <hip/hip_fp16.h>
-#include <hip/hip_runtime.h>
-#endif
-
 template <typename U, typename V>
 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
   const uint64_t blocks = a / b + (a % b != 0);
   return blocks;
 }
-
-#if defined(USE_ROCM)
-constexpr int32_t kWarpSize = 64;
-#else
 constexpr int32_t kWarpSize = 32;
-#endif
 
 // Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
 // https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
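For context when reading the next hunk: the struct referenced above is defined in the linked int4mm.cu roughly as below (a sketch for orientation, not part of this diff). Eight packed int4 values dequantize into four __nv_bfloat162 pairs, which is what convert_i4x8_to_bf16x2x4 returns.

#include <cuda_bf16.h>

// Sketch of the referenced helper struct: 4 pairs of bf16 = 8 values,
// 16-byte aligned so a bf16x2x4 can move in a single vectorized access.
struct __align__(16) bf16x2x4 {
  __nv_bfloat162 vals[4];
};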
@@ -41,71 +30,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
   uint32_t const source_i4s = source;
 
   // First, we extract the i4s and construct an intermediate fp16 number.
-#if !defined(USE_ROCM)
   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
-#endif
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
 
   // We don't have enough mantissa to remove as much shift overhead as FP16, so
   // we must loop. No shift needed for first item.
   uint32_t i4s = source_i4s;
-  // AMD MI300X ISA that performs two bitwise operations in a single instruction:
-  // v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
-  //   - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values
-  //   - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
-#if defined(USE_ROCM)
-  asm volatile("v_and_or_b32 %0, %1, %2, %3"
-               : "=v"(h[0])
-               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
-#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                : "=r"(h[0])
                : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-#endif
-
 #pragma unroll
   for (int ii = 1; ii < kElements / 2; ++ii) {
     i4s >>= 4; // or is it 8?
     // (i4s & 0x000f000f) | 0x43004300
-#if defined(USE_ROCM)
-    asm volatile("v_and_or_b32 %0, %1, %2, %3"
-                 : "=v"(h[ii])
-                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
-#else
     asm volatile(
         "lop3.b32 %0, %1, %2, %3, %4;\n"
         : "=r"(h[ii])
         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-#endif
   }
 
   // This is the BF16 {-136, -136} represented as an integer.
-#if defined(USE_ROCM)
-#if ROCM_VERSION >= 60200
-  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
-  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
-#else
-  auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
-  auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
-#endif
-#else
   static constexpr uint32_t BF16_BIAS = 0xC308C308;
   static constexpr uint32_t BF16_ONE = 0x3F803F80;
-#endif
 
   // Finally, we construct the output numbers.
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
-#if defined(USE_ROCM)
-    result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
-#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
         : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
-#endif
   }
 
   return result;
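Why these constants work: lop3.b32 with immLut = (0xf0 & 0xcc) | 0xaa = 0xEA computes (a & b) | c, i.e. (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM, the same operation the removed v_and_or_b32 path performed. 0x4300 is bf16 128.0 with the low four mantissa bits free, so OR-ing in a nibble q yields 128 + q, and the fma against BF16_BIAS (bf16 -136) leaves q - 8, the signed int4 value. A minimal host-side sketch that checks this arithmetic (illustration only, not part of the patch):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret bf16 bits as float: bf16 is the high 16 bits of an fp32.
static float bf16_bits_to_float(uint16_t b) {
  uint32_t u = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  for (uint16_t q = 0; q < 16; ++q) {
    // (q & 0xF) | 0x4300 is the bf16 for 128 + q, since 0x4300 is 128.0
    // and the nibble lands in the low mantissa bits.
    float biased = bf16_bits_to_float(uint16_t(q | 0x4300));
    // fma(x, 1.0, -136.0) then yields q - 8, the signed int4 value.
    float signed_val = biased * 1.0f + (-136.0f);
    std::printf("q=%2u -> %5.1f\n", q, signed_val);
  }
  return 0;
}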
@@ -167,22 +123,11 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
-#if defined(USE_ROCM)
-  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
-  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
-
-  if (scales_and_zeros) {
-    const auto& sz = *scales_and_zeros;
-    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
-
-    scale2 = __bfloat162bfloat162(pSZ[0]);
-    zero2 = __bfloat162bfloat162(pSZ[1]);
-  }
-#else
   const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
+
+  // Vectorize scales and zeros
   __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
   __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
-#endif
 
 #pragma unroll
   for (int i = 0; i < 4; i++) {
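Downstream of this load, each 4-bit value q is dequantized as (q - 8) * scale2 + zero2 via packed bf16 fma, with one (scale, zero) pair per quantization group. A scalar host reference under that assumption (hypothetical helper, not part of the patch):

#include <cstdint>
#include <vector>

// Scalar reference for group-wise int4 dequantization (assumes the
// (q - 8) * scale + zero convention used by the kernel above).
std::vector<float> dequantizeInt4Reference(
    const std::vector<uint8_t>& q,     // one int4 value per entry, in [0, 15]
    const std::vector<float>& scales,  // one scale per group
    const std::vector<float>& zeros,   // one zero point per group
    int groupSize) {
  std::vector<float> out(q.size());
  for (size_t k = 0; k < q.size(); ++k) {
    const size_t g = k / groupSize;  // all values in a group share scale/zero
    out[k] = (static_cast<float>(q[k]) - 8.0f) * scales[g] + zeros[g];
  }
  return out;
}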