Commit 8c07d22

BF16 support for Quant-LLM kernel (#1147)
* Add FP6 benchmark option to use BF16
* Change dequant bit-shifting logic for BF16
* Modify dequant + tensor core ops for bf16
* Template progress
* Modify fpx quant logic to include bf16
* Add tests for FP6 BF16
* Use type punning for large exponent multiplication
* Fix some TODOs
* Remove option to add exponent bias directly to the exponent bits. This approach is (much) slower than multiplying by 2^bias after the fact, so that's why it's not usable.
* Reformat
* Cleanup
* Fix alignment
* Remove templated input type whenever possible
* Remove templated input type whenever possible 2
* Remove templated input type whenever possible 3
* Less hacky way to construct a float with a large exponent
* rtol=1e-2 instead of 1e-3 for bfloat16 test
* Guards for SM75
* Remove redundant `__CUDA_ARCH__` guards in host code. Any check for `__CUDA_ARCH__` in `fp6_linear.cu` will always fail because `__CUDA_ARCH__` is undefined there, since all of the functions in `fp6_linear.cu` are host functions.
* Fix consistency in checking for `CUDA_ARCH` versions
* Update docs
* Make float bias a constexpr
* Update docs more
* Fix SM75 support
* Compile guard for sm<75
* Check for CUDA synchronous errors after kernel launch. If this is not done, the kernel may still run but fail silently, leading to unexpected behavior.
* Updated compile guard
* Fix problematic usage of `__CUDA_ARCH__`. There are currently several ways of using `__CUDA_ARCH__` that lead to undefined behavior. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch for details of how `__CUDA_ARCH__` should not be used.
* Fix incorrect CUDA error handling
* Make the kernel fail for sm75 + bfloat16 inputs
1 parent f99b667 commit 8c07d22
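
The exponent-bias change described above is the core of the BF16 dequant path: instead of adding the bias difference into the exponent bits of every dequantized value, the FP6 bits are shifted into place and the result is multiplied once by a power-of-two constant whose bit pattern is built directly (the "type punning" mentioned in the commit message). Below is a minimal Python sketch of that idea, not the kernel's actual CUDA code; the 2^124 constant is an assumption based on the BF16 exponent bias (127) minus the FP6 E3M2 bias (3).

import math
import struct

def pow2_from_bits(k: int) -> float:
    # Build the float32 value 2**k by writing the biased exponent (k + 127)
    # directly into the exponent field; sign and mantissa stay zero.
    bits = (k + 127) << 23
    return struct.unpack("f", struct.pack("I", bits))[0]

# Multiplying once by this constant has the same effect as adding the bias
# difference to every element's exponent bits, but is much cheaper.
assert pow2_from_bits(124) == math.ldexp(1.0, 124)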

15 files changed: 258 additions & 153 deletions

benchmarks/benchmark_fp6.py

Lines changed: 25 additions & 12 deletions
@@ -8,29 +8,42 @@
 def benchmark(m: int, k: int, n: int):
-    float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2))
-    fp16_weight = fp6_weight.dequantize(torch.half)
-
-    fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
-    fp6_output = F.linear(fp16_act, fp6_weight)
+    float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda")
+    float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
+    fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2))
+    fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2))
+    fp16_weight = fp6_weight_fp16.dequantize(torch.float16)
+    bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16)
+
+    fp16_act = torch.randn(m, k, dtype=torch.float16, device="cuda")
+    bf16_act = fp16_act.to(torch.bfloat16)
+    fp6_output_fp16 = F.linear(fp16_act, fp6_weight_fp16)
+    fp6_output_bf16 = F.linear(bf16_act, fp6_weight_bf16)
     fp16_output = F.linear(fp16_act, fp16_weight)
+    bf16_output = F.linear(bf16_act, bf16_weight)
 
-    fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
     fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
+    bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight)
+    fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16)
+    fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16)
 
     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness
-    correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
+    correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
+    correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2
 
     return {
         "m": m,
         "k": k,
         "n": n,
-        "fp6_latency (ms)": fp6_time,
-        "fp16_latency (ms)": fp16_time,
-        "speedup (d/s)": fp16_time / fp6_time,
-        "correct": correct,
+        "fp6-fp16 latency (ms)": fp6_time_fp16,
+        "fp16 latency (ms)": fp16_time,
+        "speedup fp16": fp16_time / fp6_time_fp16,
+        "correct fp16": correct_fp16,
+        "fp6-bf16 latency (ms)": fp6_time_bf16,
+        "bf16 latency (ms)": bf16_time,
+        "speedup bf16": bf16_time / fp6_time_bf16,
+        "correct bf16": correct_bf16,
     }
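
To reproduce the BF16 numbers outside the benchmark harness, here is a minimal sketch of the flow the updated benchmark measures. The import paths and shapes are assumptions based on how benchmarks/benchmark_fp6.py uses these names, not part of the diff.

import torch
import torch.nn.functional as F
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

# Quantize a BF16 weight to FP6 (ebits=3, mbits=2) and compare the fused kernel
# against a dequantize-then-matmul reference, mirroring the correctness check above.
weight = torch.randn(8192, 8192, dtype=torch.bfloat16, device="cuda")
act = torch.randn(4, 8192, dtype=torch.bfloat16, device="cuda")

fp6_weight = to_affine_quantized_fpx(weight, FloatxTensorCoreLayout(3, 2))
out_fp6 = F.linear(act, fp6_weight)
out_ref = F.linear(act, fp6_weight.dequantize(torch.bfloat16))

rel_err = (out_fp6 - out_ref).abs().mean() / out_ref.abs().mean()
print(f"bf16 relative error: {rel_err.item():.3e}")  # the benchmark treats < 1e-2 as correct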

test/dtypes/test_floatx.py

Lines changed: 4 additions & 3 deletions
@@ -91,16 +91,17 @@ def test_to_copy_device(self, ebits, mbits):
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
+    @parametrize("dtype", [torch.half, torch.bfloat16])
     @pytest.mark.skipif(is_fbcode(), reason="broken in fbcode")
-    def test_fpx_weight_only(self, ebits, mbits, bias):
+    def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
         device = "cuda"
 
-        linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half)
+        linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
         fpx_linear = copy.deepcopy(linear)
         quantize_(fpx_linear, fpx_weight_only(ebits, mbits))
 
-        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        x = torch.randn(N, IC, device=device, dtype=dtype)
         expected = fpx_linear(x)
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
         # somehow compile now changes the result a bit
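
The same dtype coverage can be exercised outside the test suite through the module-level API. A hedged sketch follows; the import locations are assumptions based on the names used in the test above.

import copy
import torch
from torchao.quantization import quantize_, fpx_weight_only

# Quantize a BF16 linear layer's weight to FP6 (E3M2) and compare against the
# original module, following the structure of test_fpx_weight_only.
linear = torch.nn.Linear(64, 256, bias=True, device="cuda", dtype=torch.bfloat16)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fpx_weight_only(3, 2))   # ebits=3, mbits=2

x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16)
expected = linear(x)                           # BF16 reference
actual = torch.compile(fpx_linear, fullgraph=True)(x)
print((actual - expected).abs().mean().item())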

test/test_ops.py

Lines changed: 12 additions & 9 deletions
@@ -33,22 +33,23 @@
 
 
 class TestOps(TestCase):
-    def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+    def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype):
         # Randomly initialize each byte
         nbits = 1 + ebits + mbits
         floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
-        scale = torch.rand(OC).half() + 0.5
-        fp16_act = torch.rand(BS, IC).half() + 0.5
+        scale = torch.rand(OC).to(dtype) + 0.5
+        fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
         return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    def test_quant_llm_linear(self, ebits, mbits):
+    @parametrize("dtype", [torch.half, torch.bfloat16])
+    def test_quant_llm_linear(self, ebits, mbits, dtype):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype)
 
         # smoke test
         torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

@@ -60,19 +61,21 @@ def test_quant_llm_linear(self, ebits, mbits):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
+    @parametrize("dtype", [torch.half, torch.bfloat16])
+    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype):
         # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype)
 
         results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
 
-        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
+        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype)
         results_fp16 = fp16_act @ fp16_weight.T
 
         error = (results_floatx - results_fp16).abs().mean()
         gt = results_fp16.abs().mean()
         relative_error = error / gt
-        assert relative_error < 1e-3
+        rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
+        assert relative_error < rtol
 
 instantiate_parametrized_tests(TestOps)
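
For completeness, a small sketch of driving the raw op with BF16 inputs, following the packed-weight layout used by _create_floatx_inputs above. The from_scaled_tc_floatx import path is an assumption; note that the commit makes the kernel reject BF16 on sm75, so this needs sm80 or newer.

import torch
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx

ebits, mbits = 3, 2                  # FP6 E3M2
BS, OC, IC, splitK = 2, 256, 256, 1
nbits = 1 + ebits + mbits

# Packed FP6 weight: each output channel stores IC * nbits / 8 bytes.
fp6_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = (torch.rand(OC, device="cuda") + 0.5).to(torch.bfloat16)
act = (torch.rand(BS, IC, device="cuda") + 0.5).to(torch.bfloat16)

out = torchao.ops.quant_llm_linear(ebits, mbits, act, fp6_weight, scale, splitK)
ref = act @ from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).to(torch.bfloat16).T
print(((out - ref).abs().mean() / ref.abs().mean()).item())  # the test expects < 1e-2 for bf16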

FP6-LLM kernel README

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 # FP6-LLM kernel
 
-This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN).
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).
 
 On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.
 
-See https://github.com/pytorch/ao/pull/223 for some benchmark results.
+See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results.
