test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -10,6 +10,8 @@
from typing import Tuple

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
run_tests,
@@ -85,6 +87,14 @@ def test_fp8_linear_variants(
kernel_preference: KernelPreference,
sizes: Tuple,
):
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
Contributor:
Nit: let's xfail this

Contributor Author:
we are using unittest, and it seems like we can't do return unittest.expectedFailure("...")?

.../ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py", line 92, in test_fp8_linear_variants
    return unittest.expectedFailure(
  File ".../python3.10/unittest/case.py", line 148, in expectedFailure
    test_item.__unittest_expecting_failure__ = True
AttributeError: 'str' object has no attribute '__unittest_expecting_failure__'

but let me know if there is an example of doing expectedFailure conditionally instead of skipping the entire test
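A minimal sketch of one alternative, assuming the suite is collected by pytest (the condition name below is a hypothetical stand-in for the real PerTensor + FBGEMM check):

import pytest

def test_fp8_linear_variants_sketch():
    # hypothetical stand-in for: isinstance(granularity, PerTensor)
    # and kernel_preference == KernelPreference.FBGEMM
    is_per_tensor_fbgemm = True
    if is_per_tensor_fbgemm:
        # unlike unittest.expectedFailure, which decorates a test callable,
        # pytest.xfail can be called imperatively inside the test body
        pytest.xfail("per tensor with fbgemm kernel preference does not work yet")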

return unittest.skip(
"per tensor with fbgemm kernel preferece does not work yet"
)

error_message = None
if isinstance(granularity, PerRow):
if mode == "dynamic" and dtype != torch.bfloat16:
@@ -237,7 +247,11 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
other_kernel_preferences = [
KernelPreference.AUTO,
]
if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90():
if (
_is_fbgemm_genai_gpu_available()
and is_sm_at_least_90()
and not isinstance(granularity, PerTensor)
):
other_kernel_preferences.append(KernelPreference.FBGEMM)

quantized_outputs = {}
@@ -399,6 +413,32 @@ def test_moe_weight_reshape_ops(self):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
self._test_moe_weight_reshape_ops(config)

# TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743
# that should be moved here after v1 config is deprecated:
# https://github.com/pytorch/ao/issues/2649
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
def test_expected_gpu_kernel_fbgemm(self):
Contributor:
I think this test should be together with the other tests we have which check the same thing for other settings of this config, currently in test_affine_quantized_float.py. Can we add a TODO to unify?

Contributor Author:
yeah I think we can put everything here after we deprecate the AQT path in 9 months

"""Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels"""
torch.compiler.reset()

M, K, N = 128, 256, 512
m = torch.nn.Sequential(
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
)
config = Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(),
kernel_preference=KernelPreference.FBGEMM,
)
quantize_(m, config)
m = torch.compile(m)
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
out, code = run_and_get_code(m, x)

# check at least one occurrence of the quantize op and rowwise gemm op
FileCheck().check_count(
"torch.ops.triton.quantize_fp8_row.default", 1
).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

2 changes: 1 addition & 1 deletion torchao/quantization/quantize_/common/kernel_preference.py
@@ -26,7 +26,7 @@ class KernelPreference(str, Enum):
"""
TORCH = "torch"

"""Use fbgemm quantize and quantized mm kernels, requires fbgemm_gpu_genai library
"""Use quantize and quantized mm kernels from fbgemm_gpu_genai library, requires fbgemm_gpu_genai library
"""
FBGEMM = "fbgemm"

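As a usage sketch mirroring the new test in this PR (the KernelPreference import path is an assumption based on the file layout in this diff):

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

model = torch.nn.Sequential(torch.nn.Linear(256, 512, device="cuda", dtype=torch.bfloat16))
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),  # PerTensor + FBGEMM is not supported yet (see test above)
    kernel_preference=KernelPreference.FBGEMM,  # requires fbgemm_gpu_genai and SM90+
)
quantize_(model, config)  # weights are quantized in-place to float8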
59 changes: 44 additions & 15 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -22,7 +22,7 @@
preprocess_data,
preprocess_scale,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.observer import get_block_size
from torchao.quantization.quant_primitives import (
_choose_scale_float8,
@@ -177,32 +177,61 @@ def to_float8(
block_size = get_block_size(hp_tensor.shape, granularity)
block_size = list(block_size)

# for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
kernel_choice = None
if (
kernel_preference == KernelPreference.AUTO
and _is_fbgemm_genai_gpu_available()
and (
tuple(block_size)
== (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
)
and is_sm_at_least_90()
and isinstance(granularity, PerRow)
and float8_dtype == torch.float8_e4m3fn
and hp_value_lb is None
):
assert float8_dtype == torch.float8_e4m3fn, (
f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
# if kernel_preference is AUTO and granularity is per row,
# we'll use the fbgemm quantize kernel for best performance
kernel_choice = "fbgemm"
elif kernel_preference == KernelPreference.FBGEMM:
# if user explicitly chose FBGEMM kernel preference, we'll also use fbgemm kernel
assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
"Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (>= H100)"
)
assert hp_value_lb is None, (
"hp_value_lb should not be specified if with KerenelPreference.FBGEMM"
)
kernel_choice = "fbgemm"
else:
# fallback quantize kernel for everything else will be torch
kernel_choice = "torch"

if kernel_choice == "fbgemm":
assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
if hp_value_ub is not None:
maybe_hp_value_ub_tensor = torch.tensor(
hp_value_ub, dtype=torch.float, device=hp_tensor.device
)
else:
maybe_hp_value_ub_tensor = None
data, scale = torch.ops.triton.quantize_fp8_row(
hp_tensor, scale_ub=maybe_hp_value_ub_tensor
)
scale_shape = []
for i in range(hp_tensor.ndim):
scale_shape.append(hp_tensor.shape[i] // block_size[i])
scale = scale.reshape(*scale_shape)
if isinstance(granularity, PerRow):
data, scale = torch.ops.triton.quantize_fp8_row(
hp_tensor, scale_ub=maybe_hp_value_ub_tensor
)
scale_shape = []
for i in range(hp_tensor.ndim):
scale_shape.append(hp_tensor.shape[i] // block_size[i])
scale = scale.reshape(*scale_shape)
else:
assert isinstance(granularity, PerTensor), (
f"Expected per tensor, got {granularity}"
)
# current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
# TODO: enable after this is working
# data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
# hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
# )
raise NotImplementedError(
"Currently KernelPreference.FBGEMM does not work for per tensor float8 quant"
)
else:
assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
scale = _choose_scale_float8(
hp_tensor,
float8_dtype=float8_dtype,