Commit 0c70dee

Float8Tensor per-row quantization: pass bias to fbgemm kernel
Summary: Previously, bias was not passed to the fbgemm kernel for float8 per-row quantization; this PR adds it. The only difference should be a faster float8 per-row quantized kernel, with no change to numerics or anything else.

Test Plan:
```
python test/dtypes/test_affine_quantized_float.py -k test_expected_kernels_on_gpu
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2884, branch: jerryzh168/stack/60
1 parent a73fa51 commit 0c70dee
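The core of the change is in the fbgemm per-row dispatch: instead of running the rowwise gemm and then adding the bias in a separate elementwise kernel, the bias (along with the mm config's fast-accum flag) is now passed directly into `torch.ops.fbgemm.f8f8bf16_rowwise`. A minimal before/after sketch, using the variable names from the diff below (`xq`, `wq`, `x_scale`, `w_scale` are the already-quantized activation/weight data and their per-row scales):

```python
# Before: rowwise gemm, then a separate bias add (compiles to an extra
# elementwise Triton kernel such as triton_poi_fused_add_0)
res = torch.ops.fbgemm.f8f8bf16_rowwise(
    xq, wq, x_scale, w_scale
).reshape(out_shape)
if bias is not None:
    res = res + bias

# After: bias and fast accum are handled inside the fused rowwise gemm
res = torch.ops.fbgemm.f8f8bf16_rowwise(
    xq,
    wq,
    x_scale,
    w_scale,
    bias=bias,
    use_fast_accum=mm_config.use_fast_accum,
).reshape(out_shape)
```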

File tree

test/quantization/quantize_/workflows/float8/test_float8_tensor.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

2 files changed: +13 −3 lines changed
test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -415,6 +415,7 @@ def test_moe_weight_reshape_ops(self):
 
     def test_expected_gpu_kernel_fbgemm(self):
         """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
+        and the bias add happens in the gemm kernel for per row quantization
         """
         torch.compiler.reset()
 
@@ -432,9 +433,13 @@ def test_expected_gpu_kernel_fbgemm(self):
         out, code = run_and_get_code(m, x)
 
         # check at least one occurrence of the quantize op and rowwise gemm op
+        # check that there is no `triton_poi_fused_add_0` since the bias add should
+        # happen in the `f8f8bf16_rowwise.default` op instead of separately
         FileCheck().check_count(
             "torch.ops.triton.quantize_fp8_row.default", 1
-        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])
+        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).check_not(
+            "triton_poi_fused_add_0"
+        ).run(code[0])
 
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
```
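For reference, a minimal sketch of the check-the-compiled-code pattern this test relies on: compile the quantized module, grab the Inductor-generated source with `run_and_get_code`, then assert with `FileCheck` that the expected quantize and rowwise gemm ops appear and that no standalone fused-add Triton kernel was emitted for the bias. The module `m` and input `x` are placeholders; the real test builds a float8-quantized linear with `KernelPreference.FBGEMM`.

```python
import torch
from torch.testing import FileCheck
from torch._inductor.utils import run_and_get_code


def assert_bias_fused(m: torch.nn.Module, x: torch.Tensor) -> None:
    """Compile `m`, inspect the generated code, and verify the bias add
    happens inside the fbgemm rowwise gemm rather than in a separate kernel."""
    torch.compiler.reset()
    compiled = torch.compile(m)
    _, code = run_and_get_code(compiled, x)
    FileCheck().check_count(
        "torch.ops.triton.quantize_fp8_row.default", 1
    ).check_count(
        "torch.ops.fbgemm.f8f8bf16_rowwise.default", 1
    ).check_not(
        "triton_poi_fused_add_0"
    ).run(code[0])
```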

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -296,6 +296,8 @@ def _(func, types, args, kwargs):
             "Expected fbgemm_gpu_genai package to be installed"
         )
         assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+        mm_config = weight_tensor.mm_config
+        assert mm_config is not None
 
         out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
         xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
@@ -311,6 +313,8 @@ def _(func, types, args, kwargs):
                 wq,
                 x_scale,
                 w_scale,
+                bias=bias,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
         else:
             assert _is_tensorwise_scaled(weight_tensor)
@@ -319,9 +323,10 @@ def _(func, types, args, kwargs):
                 xq,
                 wq,
                 x_scale * w_scale,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
-        if bias is not None:
-            res = res + bias
+            if bias is not None:
+                res = res + bias
         return res
     else:
         assert kernel_choice == "torch"
```
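For context, the fused-bias rowwise path above is taken when a module is quantized with per-row float8 scaling and the FBGEMM kernel preference, as in the test earlier in this commit. A rough usage sketch follows; the import location of `KernelPreference` and the `kernel_preference` argument name are assumptions based on the test's docstring, not something this diff shows.

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
# Assumed import location; the test above only references KernelPreference.FBGEMM.
from torchao.quantization.quantize_.common import KernelPreference

# bias=True so there is a bias add that can be fused into f8f8bf16_rowwise
m = torch.nn.Linear(512, 256, bias=True).cuda().to(torch.bfloat16)
quantize_(
    m,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),                       # per-row scales -> rowwise gemm
        kernel_preference=KernelPreference.FBGEMM,  # assumed kwarg; routes to fbgemm kernels
    ),
)
x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
out = torch.compile(m)(x)  # requires SM90+ and fbgemm_gpu_genai installed
```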
