
Commit 2b5046b

Float8Tensor per row quantization pass bias to fbgemm kernel
Summary:
Previously, bias was not passed to the fbgemm kernel for float8 per-row quantization; this PR adds it. The result should be a faster float8 per-row quantized kernel, without changing numerics or anything else.

Test Plan:
```
python test/dtypes/test_affine_quantized_float.py -k test_expected_kernels_on_gpu
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2884, branch: jerryzh168/stack/60
1 parent f685f8b commit 2b5046b
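
A minimal sketch (not the torchao implementation) contrasting the two bias paths this commit switches between. It assumes `fbgemm_gpu_genai` is installed, and that `xq`/`wq` are float8 tensors with per-row scales `x_scale`/`w_scale`, matching the names used in the diff below; the helper function names are made up for illustration.

```python
import torch

def rowwise_mm_old(xq, wq, x_scale, w_scale, bias, out_shape):
    # Before this commit: bias is added as a separate elementwise op,
    # which shows up as an extra fused-add kernel in the compiled code.
    res = torch.ops.fbgemm.f8f8bf16_rowwise(
        xq, wq, x_scale, w_scale
    ).reshape(out_shape)
    if bias is not None:
        res = res + bias
    return res

def rowwise_mm_new(xq, wq, x_scale, w_scale, bias, use_fast_accum, out_shape):
    # After this commit: bias (and use_fast_accum) are handed to the kernel,
    # so the bias add happens inside f8f8bf16_rowwise itself.
    return torch.ops.fbgemm.f8f8bf16_rowwise(
        xq, wq, x_scale, w_scale, bias=bias, use_fast_accum=use_fast_accum
    ).reshape(out_shape)
```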

File tree

2 files changed: +15 -4 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -415,7 +415,9 @@ def test_moe_weight_reshape_ops(self):
 
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_expected_gpu_kernel_fbgemm(self):
-        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels"""
+        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
+        and the bias add happens in the gemm kernel for per row quantization
+        """
         torch.compiler.reset()
 
         M, K, N = 128, 256, 512
@@ -432,9 +434,13 @@ def test_expected_gpu_kernel_fbgemm(self):
         out, code = run_and_get_code(m, x)
 
         # check at least one occurrence of the quantize op and rowwise gemm op
+        # check that there is no `triton_poi_fused_add_0` since the bias add should
+        # happen in the `f8f8bf16_rowwise.default` op instead of separately
         FileCheck().check_count(
             "torch.ops.triton.quantize_fp8_row.default", 1
-        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])
+        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).check_not(
+            "triton_poi_fused_add_0"
+        ).run(code[0])
 
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
```
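
For context, a hedged sketch of the check pattern the test relies on: compile the module, capture the Inductor-generated code with `run_and_get_code`, and use `FileCheck` to assert which ops appear and which do not. The module `m` and input `x` here are placeholders, not the test's actual fixtures.

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

def assert_bias_fused(m: torch.nn.Module, x: torch.Tensor):
    m = torch.compile(m)
    out, code = run_and_get_code(m, x)
    # Expect one rowwise quantize, one rowwise gemm, and no standalone
    # Triton add kernel for the bias.
    FileCheck().check_count(
        "torch.ops.triton.quantize_fp8_row.default", 1
    ).check_count(
        "torch.ops.fbgemm.f8f8bf16_rowwise.default", 1
    ).check_not(
        "triton_poi_fused_add_0"
    ).run(code[0])
    return out
```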

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -296,6 +296,8 @@ def _(func, types, args, kwargs):
             "Expected fbgemm_gpu_genai package to be installed"
         )
         assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+        mm_config = weight_tensor.mm_config
+        assert mm_config is not None
 
         out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
         xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
@@ -311,6 +313,8 @@ def _(func, types, args, kwargs):
                 wq,
                 x_scale,
                 w_scale,
+                bias=bias,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
         else:
             assert _is_tensorwise_scaled(weight_tensor)
@@ -319,9 +323,10 @@ def _(func, types, args, kwargs):
                 xq,
                 wq,
                 x_scale * w_scale,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
-        if bias is not None:
-            res = res + bias
+            if bias is not None:
+                res = res + bias
         return res
     else:
         assert kernel_choice == "torch"
```
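
A hedged usage sketch (not taken from this PR) of the workflow this dispatch path serves: quantize a linear layer to dynamic float8 with per-row granularity and run it under torch.compile. The import paths and config arguments are assumptions about the torchao API and may differ across versions.

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Quantize the weight to float8 with per-row scales; activations are
# quantized dynamically at runtime.
linear = torch.nn.Linear(256, 512, bias=True, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
# When the fbgemm kernel path is selected (KernelPreference.FBGEMM), the bias
# add should now happen inside fbgemm.f8f8bf16_rowwise rather than as a
# separate op.
y = torch.compile(linear)(x)
```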
