Commit fbe3df9

Fix Float8Tensor quantize op kernel preference dispatch (#2883)
Summary:

Previously, if the user specified kernel_preference == "fbgemm", we would still use torch ops such as `_choose_scale_float8` and `_quantize_affine_float8` to quantize the high-precision Tensor into a float8 Tensor. This PR makes sure we use fbgemm kernels when kernel_preference is "fbgemm": `torch.ops.triton.quantize_fp8_row` for per-row quantization and `torch.ops.fbgemm.quantize_fp8_per_tensor` for per-tensor quantization (`torch.ops.fbgemm.quantize_fp8_per_tensor` has some issues right now, so we'll enable it later once it is fixed).

This has no impact on BC: previously serialized models can still be loaded and run. The only change is the kernel choice for the fbgemm kernel preference, so users who requested the FBGEMM kernel preference now actually run the fbgemm quantize op instead of the torch op.

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 7ea5410 commit fbe3df9
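For illustration, the user-facing path this dispatch fix affects looks roughly like the sketch below, adapted from the new test added in this commit. The import paths for quantize_, the config, and KernelPreference are assumptions that may vary across torchao versions, and the example assumes a CUDA device with SM 9.0+ and fbgemm_gpu_genai installed.

import torch
# Import paths below are assumptions; see the test diff for the exact names used.
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.common import KernelPreference

m = torch.nn.Sequential(torch.nn.Linear(256, 512, device="cuda", dtype=torch.bfloat16))
# With kernel_preference=FBGEMM and per-row granularity, quantization now goes
# through torch.ops.triton.quantize_fp8_row instead of the torch quantize ops.
quantize_(
    m,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        kernel_preference=KernelPreference.FBGEMM,
    ),
)
out = m(torch.randn(128, 256, device="cuda", dtype=torch.bfloat16))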

File tree

3 files changed: +86 -17 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 41 additions & 1 deletion
@@ -10,6 +10,8 @@
 from typing import Tuple

 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     run_tests,
@@ -85,6 +87,14 @@ def test_fp8_linear_variants(
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            return unittest.skip(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+
         error_message = None
         if isinstance(granularity, PerRow):
             if mode == "dynamic" and dtype != torch.bfloat16:
@@ -237,7 +247,11 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         other_kernel_preferences = [
             KernelPreference.AUTO,
         ]
-        if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90():
+        if (
+            _is_fbgemm_genai_gpu_available()
+            and is_sm_at_least_90()
+            and not isinstance(granularity, PerTensor)
+        ):
             other_kernel_preferences.append(KernelPreference.FBGEMM)

         quantized_outputs = {}
@@ -399,6 +413,32 @@ def test_moe_weight_reshape_ops(self):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         self._test_moe_weight_reshape_ops(config)

+    # TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743
+    # that should be moved here after v1 config is deprecated:
+    # https://github.com/pytorch/ao/issues/2649
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    def test_expected_gpu_kernel_fbgemm(self):
+        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels"""
+        torch.compiler.reset()
+
+        M, K, N = 128, 256, 512
+        m = torch.nn.Sequential(
+            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+        )
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+            kernel_preference=KernelPreference.FBGEMM,
+        )
+        quantize_(m, config)
+        m = torch.compile(m)
+        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+        out, code = run_and_get_code(m, x)
+
+        # check at least one occurrence of the quantize op and rowwise gemm op
+        FileCheck().check_count(
+            "torch.ops.triton.quantize_fp8_row.default", 1
+        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])
+

 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
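To reproduce the kernel check locally, run the command from the Test Plan above on an SM 9.0+ GPU with fbgemm_gpu_genai installed:

python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm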

torchao/quantization/quantize_/common/kernel_preference.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ class KernelPreference(str, Enum):
     """
     TORCH = "torch"

-    """Use fbgemm quantize and quantized mm kernels, requires fbgemm_gpu_genai library
+    """Use quantize and quantized mm kernels from fbgemm_gpu_genai library, requires fbgemm_gpu_genai library
     """
     FBGEMM = "fbgemm"
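As a hedged illustration of the "requires fbgemm_gpu_genai library" note above, a caller can guard the preference on the same availability helpers this commit uses; the import locations below are assumptions.

# Sketch only: pick FBGEMM when the prerequisites are met, otherwise let AUTO decide.
from torchao.quantization.quantize_.common import KernelPreference  # assumed path
from torchao.utils import _is_fbgemm_genai_gpu_available, is_sm_at_least_90  # assumed path

def pick_kernel_preference() -> KernelPreference:
    # FBGEMM kernels need the fbgemm_gpu_genai library and an SM 9.0+ (H100-class) GPU.
    if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90():
        return KernelPreference.FBGEMM
    return KernelPreference.AUTO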

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 44 additions & 15 deletions
@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -178,32 +178,61 @@ def from_hp(
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)

-        # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
+        kernel_choice = None
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
-            )
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
         ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            # if kernel_preference is AUTO and per row quantization
+            # we'll use fbgemm quantize kernel for best performance
+            kernel_choice = "fbgemm"
+        elif kernel_preference == KernelPreference.FBGEMM:
+            # if user explicitly chose FBGEMM kernel preference, we'll also use fbgemm kernel
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (>= H100)"
+            )
+            assert hp_value_lb is None, (
+                "hp_value_lb should not be specified with KernelPreference.FBGEMM"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # fallback quantize kernel for everything else will be torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
             if hp_value_ub is not None:
                 maybe_hp_value_ub_tensor = torch.tensor(
                     hp_value_ub, dtype=torch.float, device=hp_tensor.device
                 )
             else:
                 maybe_hp_value_ub_tensor = None
-            data, scale = torch.ops.triton.quantize_fp8_row(
-                hp_tensor, scale_ub=maybe_hp_value_ub_tensor
-            )
-            scale_shape = []
-            for i in range(hp_tensor.ndim):
-                scale_shape.append(hp_tensor.shape[i] // block_size[i])
-            scale = scale.reshape(*scale_shape)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.triton.quantize_fp8_row(
+                    hp_tensor, scale_ub=maybe_hp_value_ub_tensor
+                )
+                scale_shape = []
+                for i in range(hp_tensor.ndim):
+                    scale_shape.append(hp_tensor.shape[i] // block_size[i])
+                scale = scale.reshape(*scale_shape)
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
+                # TODO: enable after this is working
+                # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                #     hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                # )
+                raise NotImplementedError(
+                    "Currently KernelPreference.FBGEMM does not work for per tensor float8 quant"
+                )
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,
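For readability, the dispatch introduced in the hunk above can be condensed to the following sketch. It restates the conditions from the diff and is not the exact source; the helper import paths are assumptions.

import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.common import KernelPreference  # assumed path
from torchao.utils import _is_fbgemm_genai_gpu_available, is_sm_at_least_90  # assumed path

def choose_quantize_kernel(kernel_preference, granularity, float8_dtype, hp_value_lb):
    """Condensed restatement of the kernel_choice logic in from_hp above."""
    fbgemm_ok = _is_fbgemm_genai_gpu_available() and is_sm_at_least_90()
    if (
        kernel_preference == KernelPreference.AUTO
        and fbgemm_ok
        and isinstance(granularity, PerRow)
        and float8_dtype == torch.float8_e4m3fn
        and hp_value_lb is None
    ):
        return "fbgemm"  # per-row: torch.ops.triton.quantize_fp8_row
    if kernel_preference == KernelPreference.FBGEMM:
        assert fbgemm_ok and hp_value_lb is None
        return "fbgemm"  # per-tensor fbgemm path currently raises NotImplementedError
    return "torch"  # _choose_scale_float8 / _quantize_affine_float8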
