Skip to content

Commit 4b99590

Browse files
eigen-kfacebook-github-bot
authored and committed
Fix the math in the FuseMulTensorIntoQuantPass.
Summary: The new scale value was calculated incorrectly, fixing that with this diff. See the details of the new scale calculation in the comments in the pass. Differential Revision: D77267667
1 parent 124758e commit 4b99590

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,13 @@ def attempt_fusion(
882882
assert isinstance(prev_scale, (int, float))
883883
mul_scalar = full_node.args[1]
884884
assert isinstance(mul_scalar, (int, float))
885-
new_scale = float(prev_scale) * float(mul_scalar)
885+
# The reason why we divide previous scale by the mul value to get a new scale:
886+
# y = x * mul_scalar
887+
# q = zp + y / prev_scale
888+
# q = zp + x * mul_scalar / prev_scale
889+
# new_scale = prev_scale / mul_scalar
890+
# q = zp + x / new_scale
891+
new_scale = float(prev_scale) / float(mul_scalar)
886892

887893
logging.debug(
888894
f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
598598
self.assertEqual(deq_scale, dequant_scale * mul_value)
599599

600600
def test_fuse_mul_into_quant(self) -> None:
601-
quant_scale = 1.5
601+
quant_scale = 5
602602
mul_value = 10
603603

604604
builder = GraphBuilder()
@@ -613,7 +613,7 @@ def test_fuse_mul_into_quant(self) -> None:
613613
)
614614
quant = builder.call_operator(
615615
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
616-
args=(mul, quant_scale, 0, 0, 255, torch.uint8),
616+
args=(mul, quant_scale, 7, 0, 255, torch.uint8),
617617
)
618618
builder.output([quant])
619619
original_graph = builder.get_graph_module()
@@ -631,14 +631,20 @@ def test_fuse_mul_into_quant(self) -> None:
631631
)
632632

633633
# verify that the quant scale value was updated correctly
634-
deq_scale = -1
634+
new_quant_scale = -1
635635
for node in converted_graph.graph.nodes:
636636
if (
637637
node.target
638638
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
639639
):
640-
deq_scale = node.args[1]
641-
self.assertEqual(deq_scale, quant_scale * mul_value)
640+
new_quant_scale = node.args[1]
641+
self.assertEqual(new_quant_scale, quant_scale / mul_value)
642+
643+
# verify the math is correct
644+
inp = torch.randn(4, 32, dtype=torch.float32)
645+
original_out = original_graph(inp)[0]
646+
new_out = converted_graph(inp)[0]
647+
assert torch.equal(original_out, new_out)
642648

643649
def test_fuse_then_transpose_pass(self) -> None:
644650
# Create a graph with full -> transpose.

0 commit comments

Comments (0)