
Commit a78fc11

Fix Float8Tensor quantize op kernel preference dispatch
Summary:
Previously we didn't handle kernel_preference == "fbgemm" properly for the quantize op. This PR makes sure we dispatch to the fbgemm kernels when kernel_preference is fbgemm.

This has little impact on BC: serialized checkpoints use AUTO, which still dispatches to the triton op for quantize. The only change is fixing the kernel choice for the fbgemm kernel preference, which is meant to be a developer-facing API (we expect most users to just use AUTO without worrying about the details).

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2883, branch: jerryzh168/stack/59
1 parent 3bf21d0 commit a78fc11
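
For context, a minimal sketch of the developer-facing path this commit fixes, in the spirit of the numerical-equivalence test named in the Test Plan. The import paths for Float8Tensor and KernelPreference are assumed from the file layout in this diff, and the snippet needs an SM 9.0+ GPU with fbgemm_gpu_genai installed for the FBGEMM preference:

import torch

from torchao.quantization.granularity import PerRow
# Assumed import paths, following the files touched in this PR.
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor

x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")

# Quantize the same tensor under each kernel preference; with this fix,
# FBGEMM actually routes the quantize op to the fbgemm kernels.
results = {}
for pref in (KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM):
    qt = Float8Tensor.from_hp(
        x,
        float8_dtype=torch.float8_e4m3fn,
        granularity=PerRow(),
        kernel_preference=pref,
    )
    results[pref] = qt.dequantize()

# The three preferences should be numerically equivalent (up to fp8 rounding).
baseline = results[KernelPreference.TORCH]
for pref, deq in results.items():
    torch.testing.assert_close(deq, baseline, atol=1e-1, rtol=1e-1)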

File tree: 5 files changed, 59 additions and 15 deletions

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ def setUp(self):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize("granularity", [PerTensor()])
     @common_utils.parametrize(
         "kernel_preference",
         [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],

test/quantization/test_qat.py

Lines changed: 1 addition & 1 deletion

@@ -1859,7 +1859,7 @@ def test_float8_fake_quantize(self, granularity: Granularity):
         torch.manual_seed(self.SEED)
         x = torch.randn(32, 64)
         out = fake_quantizer(x)
-        out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
+        out_expected = Float8Tensor.from_hp(x, dtype, granularity).dequantize()
         sqnr = compute_error(out, out_expected)
         self.assertGreater(sqnr, 16)

torchao/quantization/quant_api.py

Lines changed: 2 additions & 2 deletions

@@ -1568,7 +1568,7 @@ def _float8_weight_only_quant_tensor(weight, config):
     else:
         assert config.version == 2, f"Unexpected version: {config.version}"
         weight_dtype = config.weight_dtype
-        new_weight = Float8Tensor.to_float8(
+        new_weight = Float8Tensor.from_hp(
             weight, float8_dtype=weight_dtype, granularity=PerRow()
         )
         return new_weight

@@ -1766,7 +1766,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         kernel_preference=kernel_preference,
     )

-    quantized_weight = Float8Tensor.to_float8(
+    quantized_weight = Float8Tensor.from_hp(
         weight,
         float8_dtype=weight_dtype,
         granularity=weight_granularity,
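
For reference, these private helpers back the public configs, so the end-user path looks roughly like the sketch below. This assumes Float8DynamicActivationFloat8WeightConfig exposes the kernel_preference field that _float8_dynamic_activation_float8_weight_quantize_tensor reads above; most users would leave it at the AUTO default.

import torch

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
# Assumed import path for the developer-facing kernel preference enum.
from torchao.quantization.quantize_.common import KernelPreference

model = torch.nn.Sequential(torch.nn.Linear(64, 128)).to(torch.bfloat16).cuda()

# kernel_preference is a developer-facing knob; AUTO is the default and what
# serialized checkpoints use, FBGEMM is the preference whose quantize dispatch
# this commit fixes.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        kernel_preference=KernelPreference.FBGEMM,
    ),
)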

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@ class QuantizeTensorKwargs(abc.ABC):

    class Float8Tensor(...)
      @classmethod
-     def to_float8(cls, tensor, quant_kwargs: QuantizeTensorKwargs)
+     def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs)
        ...
    """

@@ -43,7 +43,7 @@ def _choose_quant_func_and_quantize_tensor(
     )

     if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
-        return Float8Tensor.to_float8(
+        return Float8Tensor.from_hp(
             tensor,
             quant_kwargs.float8_dtype,
             quant_kwargs.granularity,

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 53 additions & 9 deletions

@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,

@@ -163,7 +163,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         return _dequantize_affine_float8(qdata, scale, output_dtype)

     @classmethod
-    def to_float8(
+    def from_hp(
         cls,
         hp_tensor: torch.Tensor,
         float8_dtype: torch.dtype = torch.float8_e4m3fn,

@@ -177,18 +177,30 @@ def to_float8(
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)

-        # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
+        kernel_choice = None
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
-            )
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
         ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            # for per row quantization and kernel_preference auto setting
+            # we'll use triton quantize kernel for best performance
+            kernel_choice = "triton"
+        elif kernel_preference == KernelPreference.FBGEMM and hp_value_lb is None:
+            # we'll use fbgemm quantize kernel if it's explicitly chosen by user
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (> H100)"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # fallback quantize kernel for everything else will be torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "triton":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
         if hp_value_ub is not None:
             maybe_hp_value_ub_tensor = torch.tensor(
                 hp_value_ub, dtype=torch.float, device=hp_tensor.device

@@ -202,7 +214,39 @@ def to_float8(
             for i in range(hp_tensor.ndim):
                 scale_shape.append(hp_tensor.shape[i] // block_size[i])
             scale = scale.reshape(*scale_shape)
+        elif kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
+            if hp_value_ub is not None:
+                maybe_hp_value_ub_tensor = torch.tensor(
+                    hp_value_ub, dtype=torch.float, device=hp_tensor.device
+                )
+            else:
+                maybe_hp_value_ub_tensor = None
+            # not used
+            num_tokens = torch.empty([hp_tensor.size(0)], device=hp_tensor.device)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                )
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                # TODO: use fbgemm kernel when it works
+                # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
+                # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                #     hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                # )
+                scale = _choose_scale_float8(
+                    hp_tensor,
+                    float8_dtype=float8_dtype,
+                    block_size=block_size,
+                    hp_value_lb=hp_value_lb,
+                    hp_value_ub=hp_value_ub,
+                )
+                data = _quantize_affine_float8(hp_tensor, scale, float8_dtype)
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,
