Commit 75b6816

Merge branch 'main' into rocm_sparse_marlin

2 parents 8b34390 + 883dc65

3 files changed: +80, -11 lines


.github/workflows/float8nocompile_test.yaml
Lines changed: 0 additions & 2 deletions

@@ -7,14 +7,12 @@ on:
       - 'gh/**'
     paths:
       - 'torchao/prototype/float8nocompile/**'
-      - '!torchao/prototype/float8nocompile/**'
   pull_request:
     branches:
       - main
       - 'gh/**'
     paths:
       - 'torchao/prototype/float8nocompile/**'
-      - '!torchao/prototype/float8nocompile/**'
 
 concurrency:
   group: floatnocompile_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}

setup.py
Lines changed: 19 additions & 3 deletions

@@ -252,6 +252,7 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
+
     if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
         print(
             "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
@@ -291,8 +292,8 @@ def get_extensions():
         extra_compile_args["nvcc"].append("-g")
         extra_link_args.append("/DEBUG")
 
-    this_dir = os.path.dirname(os.path.curdir)
-    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
+    curdir = os.path.dirname(os.path.curdir)
+    extensions_dir = os.path.join(curdir, "torchao", "csrc")
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -340,7 +341,7 @@ def get_extensions():
         sources += cuda_sources
     else:
         # ROCm sources
-        extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
+        extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin", "tensor_core_tiled_layout")
         hip_sources = list(
             glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
         )
@@ -369,6 +370,21 @@ def get_extensions():
         )
         sources = [s for s in sources if s not in cutlass_sources]
 
+    # TODO: Remove this and use what CUDA has once we fix all the builds.
+    if IS_ROCM and use_cuda:
+        # Add ROCm GPU architecture check
+        gpu_arch = torch.cuda.get_device_properties(0).name
+        if gpu_arch != "gfx942":
+            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
+            print(
+                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+            )
+            return None
+        sources += hip_sources
+
+    if len(sources) == 0:
+        return None
+
     ext_modules = []
     if len(sources) > 0:
         ext_modules.append(
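
The substantive setup.py change is the new gate above: HIP sources are only added when the build sees a gfx942 device, and get_extensions() bails out when no sources remain. As a rough, standalone illustration (not the commit's code; the helper name is invented, and a PyTorch install with a visible GPU is assumed), the same check in isolation looks like this:

# Standalone sketch of the gate added to get_extensions() -- hypothetical helper,
# not the commit's code. It reads the same device property the diff reads and
# applies the same "gfx942 only" rule before HIP sources would be added.
import torch

def should_build_rocm_extensions() -> bool:
    if not torch.cuda.is_available():
        return False
    gpu_arch = torch.cuda.get_device_properties(0).name
    if gpu_arch != "gfx942":
        print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
        print("Currently only gfx942 is supported. Skipping compilation of ROCm extensions")
        return False
    return True

if __name__ == "__main__":
    print("Build ROCm extensions:", should_build_rocm_extensions())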

torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
Lines changed: 61 additions & 6 deletions

@@ -1,4 +1,4 @@
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
 
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
@@ -7,13 +7,24 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
 
+#if defined(USE_ROCM)
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+#endif
+
 template <typename U, typename V>
 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
   const uint64_t blocks = a / b + (a % b != 0);
   return blocks;
 }
+
+#if defined(USE_ROCM)
+constexpr int32_t kWarpSize = 64;
+#else
 constexpr int32_t kWarpSize = 32;
+#endif
 
 //Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization
 //https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180
@@ -30,38 +41,71 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
   uint32_t const source_i4s = source;
 
   // First, we extract the i4s and construct an intermediate fp16 number.
+#if !defined(USE_ROCM)
   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+#endif
   static constexpr uint32_t MASK = 0x000f000f;
   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
 
   // We don't have enough mantissa to remove as much shift overhead as FP16, so
   // we must loop. No shift needed for first item.
   uint32_t i4s = source_i4s;
+  // AMD MI300X ISA that performs two bitwise operations in a single instruction:
+  // v_and_or_b32 performs H[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM
+  // - First ANDs `i4s` with `MASK` (0x000f000f) to extract 4-bit values
+  // - Then ORs the result with `I4s_TO_BF16s_MAGIC_NUM` (0x43004300) to convert them to bfloat16
+#if defined(USE_ROCM)
+  asm volatile("v_and_or_b32 %0, %1, %2, %3"
+               : "=v"(h[0])
+               : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
              : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
+
 #pragma unroll
   for (int ii = 1; ii < kElements / 2; ++ii) {
     i4s >>= 4; // or is it 8?
     // (i4s & 0x000f000f) | 0x43004300
+#if defined(USE_ROCM)
+    asm volatile("v_and_or_b32 %0, %1, %2, %3"
+                 : "=v"(h[ii])
+                 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
+#else
     asm volatile(
         "lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[ii])
       : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+#endif
   }
 
   // This is the BF16 {-136, -136} represented as an integer.
-  static constexpr uint32_t BF16_BIAS = 0xC308C308;
-  static constexpr uint32_t BF16_ONE = 0x3F803F80;
+#if defined(USE_ROCM)
+#if ROCM_VERSION >= 60200
+  auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
+  auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
+#else
+  auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16{0xC308});
+  auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
+#endif
+#else
+  static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308;
+  static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80;
+#endif
 
   // Finally, we construct the output numbers.
 #pragma unroll
   for (int ii = 0; ii < kElements / 2; ++ii) {
     // Since this section is for Ampere+, we use bf16 fma to do the bias
     // subtraction
+#if defined(USE_ROCM)
+    result.vals[ii] = __hfma2(result.vals[ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR);
+#else
     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
         : "=r"(h[ii])
-        : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
+        : "r"(h[ii]), "r"(BF16_UNIT_VALUE), "r"(BF16_SCALE_FACTOR));
+#endif
   }
 
   return result;
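
The comments added in this hunk spell out what both inline-asm variants compute: a single AND+OR packs each 4-bit value v into the mantissa of bf16 128 (0x4300), producing 128 + v per 16-bit lane, and the final bf16 FMA with 1.0 (0x3F80) and -136 (0xC308) shifts that to the signed value v - 8. Below is a small host-side Python sketch, not part of the commit, that reproduces the arithmetic on two lanes so the constants can be sanity-checked:

# Host-side sketch (not the kernel code) of the dequantization trick used by
# both the CUDA lop3.b32 path and the new ROCm v_and_or_b32 path:
#   h = (i4s & 0x000f000f) | 0x43004300
# places each 4-bit value v into the mantissa of bf16(128), giving 128 + v,
# and fma(h, 1.0, -136) then yields the signed int4 value v - 8.
import struct

MASK = 0x000F000F
I4s_TO_BF16s_MAGIC_NUM = 0x43004300  # bf16 {128, 128}
BF16_BIAS = -136.0                   # bf16 0xC308, broadcast as {-136, -136} on device

def bf16_bits_to_float(bits16: int) -> float:
    # A bfloat16 value is the upper 16 bits of a float32 bit pattern.
    return struct.unpack(">f", struct.pack(">I", bits16 << 16))[0]

def dequant_two_nibbles(i4s: int) -> tuple:
    h = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM       # the single AND+OR step
    lo = bf16_bits_to_float(h & 0xFFFF)             # low 16-bit lane
    hi = bf16_bits_to_float((h >> 16) & 0xFFFF)     # high 16-bit lane
    return lo * 1.0 + BF16_BIAS, hi * 1.0 + BF16_BIAS  # the bf16 FMA step

# Nibble 0x3 in the low lane and 0xC in the high lane dequantize to -5 and 4.
print(dequant_two_nibbles(0x000C0003))  # (-5.0, 4.0)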
@@ -123,11 +167,22 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
+#if defined(USE_ROCM)
+  __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
+  __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
+
+  if (scales_and_zeros) {
+    const auto& sz = *scales_and_zeros;
+    const __nv_bfloat16* pSZ = reinterpret_cast<const __nv_bfloat16*>(&sz[qgroup][n0][0]);
+
+    scale2 = __bfloat162bfloat162(pSZ[0]);
+    zero2 = __bfloat162bfloat162(pSZ[1]);
+  }
+#else
   const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
-
-  // Vectorize scales and zeros
   __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
   __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+#endif
 
 #pragma unroll
   for (int i = 0; i < 4; i++) {
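
The ROCm branch in this last hunk treats scales_and_zeros as optional: scale2 and zero2 start out as 1.0 and are only overwritten when the accessor is present, while the CUDA branch keeps unconditionally calling .value(). A minimal Python sketch of just that control flow (the function name and the nested-list stand-in for the accessor are invented for illustration):

# Sketch of the control flow only -- defaults of 1.0 that are replaced when an
# optional scales/zeros table exists. The [qgroup][n0][0..1] indexing mirrors the
# diff; everything else here is a placeholder.
from typing import Optional, Sequence, Tuple

def load_scale_and_zero(
    scales_and_zeros: Optional[Sequence], qgroup: int, n0: int
) -> Tuple[float, float]:
    scale, zero = 1.0, 1.0                      # ROCm-path defaults
    if scales_and_zeros is not None:
        entry = scales_and_zeros[qgroup][n0]    # one (scale, zero) pair per q group
        scale, zero = float(entry[0]), float(entry[1])
    return scale, zero

# Example: quantization group 2, column block 0 has scale 0.05 and zero 0.5.
table = [[[1.0, 0.0]], [[1.0, 0.0]], [[0.05, 0.5]]]
print(load_scale_and_zero(table, qgroup=2, n0=0))   # (0.05, 0.5)
print(load_scale_and_zero(None, qgroup=2, n0=0))    # (1.0, 1.0)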
