
Commit a207d64

eigen-k authored and facebook-github-bot committed

Fix the math in the FuseMulTensorIntoQuantPass. (#11946)

Summary: The new scale value was calculated incorrectly; this diff fixes that. See the details of the new scale calculation in the comments in the pass.

Differential Revision: D77267667

1 parent 85cf6ce

File tree

2 files changed: +38 −38 lines


backends/cadence/aot/fuse_ops.py

Lines changed: 26 additions & 28 deletions

@@ -856,19 +856,23 @@ class FuseMulTensorIntoQuantPass(ExportPass):
     def attempt_fusion(
         self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
     ) -> None:
-        full_nodes = [
-            arg
-            for arg in mul_node.args
-            if isinstance(arg, torch.fx.Node)
-            and arg.target == exir_ops.edge.aten.full.default
-        ]
+        if len(mul_node.args) != 2 or len(mul_node.users) != 1:
+            return
+
+        second_arg = cast(torch.fx.Node, mul_node.args[1])
+        input_index = 0 if second_arg.target == exir_ops.edge.aten.full.default else 1
 
-        if len(full_nodes) != 1 or len(mul_node.users) != 1:
+        input_node = cast(torch.fx.Node, mul_node.args[input_index])
+        full_node = cast(torch.fx.Node, mul_node.args[1 - input_index])
+        output_node = list(mul_node.users.keys())[0]
+
+        # Ensure that the mul op does not do any broadcasting.
+        if input_node.meta["val"].shape != output_node.meta["val"].shape:
             return
 
-        full_node = full_nodes[0]
         mul_user = list(mul_node.users.keys())[0]
 
+        # Ensure only the expected quant ops are using the current mul op.
         if mul_user.target not in {
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.cadence.quantize_per_tensor.default,
@@ -878,33 +882,27 @@ def attempt_fusion(
         quant_node = mul_user
 
         # Calculate the new scale value.
-        prev_scale = quant_node.args[1]
-        assert isinstance(prev_scale, (int, float))
+        old_scale = quant_node.args[1]
+        assert isinstance(old_scale, (int, float))
         mul_scalar = full_node.args[1]
         assert isinstance(mul_scalar, (int, float))
-        new_scale = float(prev_scale) * float(mul_scalar)
+        # The reason we divide the old scale by the mul value to get the new scale:
+        # y = x * mul_scalar
+        # q = zp + y / old_scale
+        # q = zp + x * mul_scalar / old_scale
+        # new_scale = old_scale / mul_scalar
+        # q = zp + x / new_scale
+        new_scale = float(old_scale) / float(mul_scalar)
 
         logging.debug(
             f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
         )
 
-        # Replace the input first
-        quant_node.replace_input_with(
-            cast(torch.fx.Node, quant_node.args[0]),
-            cast(torch.fx.Node, mul_node.args[0]),
-        )
-
-        # Now update the scale in the args
-        new_quant_args = list(quant_node.args)
-        new_quant_args[1] = new_scale
-        quant_node.args = tuple(new_quant_args)
-
-        # Clean up the mul_node
-        mul_node.args = ()
-        mul_node.users = {}
-
-        graph_module.graph.erase_node(mul_node)
-        graph_module.graph.erase_node(full_node)
+        # Update quant node input and scale.
+        old_quant_input = cast(torch.fx.Node, quant_node.args[0])
+        new_quant_input = cast(torch.fx.Node, mul_node.args[0])
+        quant_node.replace_input_with(old_quant_input, new_quant_input)
+        quant_node.update_arg(1, new_scale)
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.find_nodes(
backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 12 additions & 10 deletions

@@ -598,7 +598,7 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         self.assertEqual(deq_scale, dequant_scale * mul_value)
 
     def test_fuse_mul_into_quant(self) -> None:
-        quant_scale = 1.5
+        quant_scale = 5
         mul_value = 10
 
         builder = GraphBuilder()
@@ -613,7 +613,7 @@ def test_fuse_mul_into_quant(self) -> None:
         )
         quant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-            args=(mul, quant_scale, 0, 0, 255, torch.uint8),
+            args=(mul, quant_scale, 7, 0, 255, torch.uint8),
         )
         builder.output([quant])
         original_graph = builder.get_graph_module()
@@ -631,14 +631,16 @@ def test_fuse_mul_into_quant(self) -> None:
         )
 
         # verify that the quant scale value was updated correctly
-        deq_scale = -1
-        for node in converted_graph.graph.nodes:
-            if (
-                node.target
-                == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
-            ):
-                deq_scale = node.args[1]
-        self.assertEqual(deq_scale, quant_scale * mul_value)
+        for node in converted_graph.graph.find_nodes(op="call_function",
+            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default):
+            new_quant_scale = node.args[1]
+            self.assertEqual(new_quant_scale, quant_scale / mul_value)
+
+        # verify the math is correct
+        inp = torch.randn(4, 32, dtype=torch.float32)
+        original_out = original_graph(inp)[0]
+        new_out = converted_graph(inp)[0]
+        assert torch.equal(original_out, new_out)
 
     def test_fuse_then_transpose_pass(self) -> None:
         # Create a graph with full -> transpose.
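The updated test also switches from hand-filtering converted_graph.graph.nodes to torch.fx's find_nodes lookup. A small illustrative sketch of that pattern on a plain traced module (the Toy module here is made up for illustration and is not part of the commit):

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.add(x, x) * 2

gm = fx.symbolic_trace(Toy())

# find_nodes() looks up nodes by (op, target) via the graph's internal
# index, replacing the manual `for node in graph.nodes: if node.target == ...` scan.
for node in gm.graph.find_nodes(op="call_function", target=torch.add):
    print(node.name, node.target)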
