From 4211b85abe3996439684c1af583b4a073d99c3c3 Mon Sep 17 00:00:00 2001
From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com>
Date: Tue, 24 Sep 2024 14:18:43 +0200
Subject: [PATCH 1/2] Fix failing FP6 benchmark

---
 benchmarks/benchmark_fp6.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py
index e9f9d21398..da6be403e0 100644
--- a/benchmarks/benchmark_fp6.py
+++ b/benchmarks/benchmark_fp6.py
@@ -1,23 +1,27 @@
 import torch
 import pandas as pd
-import torch.nn.functional as F
-from torchao.dtypes import to_affine_quantized_floatx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
+import torchao
+from torchao.dtypes.floatx import from_scaled_tc_floatx
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
 
 
 def benchmark(m: int, k: int, n: int):
-    float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
-    fp16_weight = fp6_weight.dequantize(torch.half)
+    ebits = 3
+    mbits = 2
+    nbits = 1 + ebits + mbits
 
-    fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
-    fp6_output = F.linear(fp16_act, fp6_weight)
-    fp16_output = F.linear(fp16_act, fp16_weight)
+    fp6_weight = torch.randint(256, (n, k // 8 * nbits), dtype=torch.uint8, device="cuda")
+    scale = torch.rand(n, device="cuda").half() + 0.5
+    fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5
 
-    fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
-    fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
+    fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
+
+    fp16_weight = from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).half()
+    fp16_output = torch.matmul(fp16_act, fp16_weight.T)
+
+    fp6_time = benchmark_torch_function_in_microseconds(torchao.ops.quant_llm_linear, ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
+    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, fp16_act, fp16_weight.T)
 
     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness

From 7fbbcca9ec07a9d80f143dbea4d03c3627999d6a Mon Sep 17 00:00:00 2001
From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com>
Date: Tue, 24 Sep 2024 15:23:00 +0200
Subject: [PATCH 2/2] More elegant weight initialization for FP6 benchmark

---
 benchmarks/benchmark_fp6.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py
index da6be403e0..0709035efa 100644
--- a/benchmarks/benchmark_fp6.py
+++ b/benchmarks/benchmark_fp6.py
@@ -1,7 +1,7 @@
 import torch
 import pandas as pd
 import torchao
-from torchao.dtypes.floatx import from_scaled_tc_floatx
+from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
 
@@ -9,10 +9,9 @@
 def benchmark(m: int, k: int, n: int):
     ebits = 3
     mbits = 2
-    nbits = 1 + ebits + mbits
 
-    fp6_weight = torch.randint(256, (n, k // 8 * nbits), dtype=torch.uint8, device="cuda")
-    scale = torch.rand(n, device="cuda").half() + 0.5
+    fp32_weight = torch.randn(n, k, device="cuda")
+    fp6_weight, scale = to_scaled_tc_floatx(fp32_weight, ebits, mbits)
     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5
 
     fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
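
Note: the snippet below is a minimal self-contained sketch of the API path the fixed benchmark exercises, useful for trying the kernel outside the script. The shapes, the clamp floor, and the error threshold are illustrative assumptions, not values from the patch; to_scaled_tc_floatx, from_scaled_tc_floatx, and torchao.ops.quant_llm_linear are the torchao entry points used in the patches above.

import torch
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx

# Illustrative shapes (assumed, not from the patch); needs a CUDA build of torchao.
m, k, n = 4, 4096, 4096
ebits, mbits = 3, 2  # FP6 = 1 sign bit + 3 exponent bits + 2 mantissa bits

# Quantize an fp32 weight to the tensor-core FP6 layout plus per-row scales.
fp32_weight = torch.randn(n, k, device="cuda")
fp6_weight, scale = to_scaled_tc_floatx(fp32_weight, ebits, mbits)

# Fused FP6 kernel vs. dequantize-then-matmul FP16 reference.
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
fp6_out = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
fp16_weight = from_scaled_tc_floatx(fp6_weight, ebits, mbits, scale).half()
ref_out = torch.matmul(fp16_act, fp16_weight.T)

# Loose sanity check (threshold and clamp floor are assumptions, since FP6
# quantization error dominates any kernel-level mismatch).
rel_err = (fp6_out - ref_out).abs() / ref_out.abs().clamp(min=1e-2)
assert rel_err.mean() < 5e-2, f"mean relative error: {rel_err.mean().item():.4f}"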