
Commit 5f8d5e2

Float8Tensor per-row quantization: pass bias to fbgemm kernel
Summary:
Previously, bias was not passed to the fbgemm kernel for float8 per-row quantization; this PR adds it. The result should be a faster float8 per-row quantized matmul, since the bias add is fused into the gemm kernel instead of running as a separate elementwise op, without changing numerics or anything else.

Test Plan:
```
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2884, branch: jerryzh168/stack/60
1 parent fbe3df9 commit 5f8d5e2
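
To make the change concrete, here is a rough before/after sketch of the per-row gemm call. This is not the literal code from float8_tensor.py (the diff below is authoritative); it assumes fbgemm_gpu_genai is installed and an SM90+ GPU is available, and the tensor names follow the diff.

```python
import torch

def rowwise_gemm_before(xq, wq, x_scale, w_scale, bias=None):
    # Before this PR: the rowwise fbgemm gemm ran without bias, so the bias
    # add showed up as a separate elementwise (triton) kernel afterwards.
    res = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
    if bias is not None:
        res = res + bias
    return res

def rowwise_gemm_after(xq, wq, x_scale, w_scale, bias=None, use_fast_accum=True):
    # After this PR: bias (and use_fast_accum) are passed straight to the fbgemm
    # kernel, so the bias add is fused into the gemm instead of a separate op.
    return torch.ops.fbgemm.f8f8bf16_rowwise(
        xq, wq, x_scale, w_scale, bias=bias, use_fast_accum=use_fast_accum
    )
```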

File tree

2 files changed: +15 -4 lines

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -418,7 +418,9 @@ def test_moe_weight_reshape_ops(self):
     # https://github.com/pytorch/ao/issues/2649
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_expected_gpu_kernel_fbgemm(self):
-        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels"""
+        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
+        and the bias add happens in the gemm kernel for per row quantization
+        """
         torch.compiler.reset()

         M, K, N = 128, 256, 512
@@ -435,9 +437,13 @@ def test_expected_gpu_kernel_fbgemm(self):
         out, code = run_and_get_code(m, x)

         # check at least one occurrence of the quantize op and rowwise gemm op
+        # check that there is no `triton_poi_fused_add_0` since the bias add should
+        # happen in the `f8f8bf16_rowwise.default` op instead of separately
         FileCheck().check_count(
             "torch.ops.triton.quantize_fp8_row.default", 1
-        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])
+        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).check_not(
+            "triton_poi_fused_add_0"
+        ).run(code[0])


 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
```

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -297,6 +297,8 @@ def _(func, types, args, kwargs):
             "Expected fbgemm_gpu_genai package to be installed"
         )
         assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+        mm_config = weight_tensor.mm_config
+        assert mm_config is not None

         out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
         xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
@@ -312,6 +314,8 @@ def _(func, types, args, kwargs):
                 wq,
                 x_scale,
                 w_scale,
+                bias=bias,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
         else:
             assert _is_tensorwise_scaled(weight_tensor)
@@ -320,9 +324,10 @@ def _(func, types, args, kwargs):
                 xq,
                 wq,
                 x_scale * w_scale,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
-        if bias is not None:
-            res = res + bias
+            if bias is not None:
+                res = res + bias
         return res
     else:
         assert kernel_choice == "torch"
```
