Skip to content

Commit bf31998

Browse files
eigen-k authored and facebook-github-bot committed
Fix the math in the FuseMulTensorIntoQuantPass. (#11946)
Summary: The new scale value was calculated incorrectly; this diff fixes that. See the details of the new scale calculation in the comments in the pass. Reviewed By: zonglinpeng, abeakkas. Differential Revision: D77267667
1 parent 083663b commit bf31998

File tree

2 files changed

+50
-38
lines changed

2 files changed

+50
-38
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 36 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -856,19 +856,32 @@ class FuseMulTensorIntoQuantPass(ExportPass):
856856
def attempt_fusion(
857857
self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
858858
) -> None:
859-
full_nodes = [
860-
arg
861-
for arg in mul_node.args
862-
if isinstance(arg, torch.fx.Node)
863-
and arg.target == exir_ops.edge.aten.full.default
864-
]
859+
if len(mul_node.args) != 2 or len(mul_node.users) != 1:
860+
return
861+
862+
first_arg = cast(torch.fx.Node, mul_node.args[0])
863+
second_arg = cast(torch.fx.Node, mul_node.args[1])
864+
865+
input_node = first_arg
866+
full_node = second_arg
867+
if second_arg.target == exir_ops.edge.aten.full.default:
868+
# Most common case, nothing to change.
869+
pass
870+
elif first_arg.target == exir_ops.edge.aten.full.default:
871+
# Input and full nodes are swapped.
872+
full_node = first_arg
873+
input_node = second_arg
874+
else:
875+
# Full node is not found, skip.
876+
return
865877

866-
if len(full_nodes) != 1 or len(mul_node.users) != 1:
878+
# Ensure that the mul op does not do any broadcasting.
879+
if input_node.meta["val"].shape != mul_node.meta["val"].shape:
867880
return
868881

869-
full_node = full_nodes[0]
870882
mul_user = list(mul_node.users.keys())[0]
871883

884+
# Ensure only the expected quant ops are using the current mul op.
872885
if mul_user.target not in {
873886
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
874887
exir_ops.edge.cadence.quantize_per_tensor.default,
@@ -878,33 +891,28 @@ def attempt_fusion(
878891
quant_node = mul_user
879892

880893
# Calculate the new scale value.
881-
prev_scale = quant_node.args[1]
882-
assert isinstance(prev_scale, (int, float))
894+
old_scale = quant_node.args[1]
895+
assert isinstance(old_scale, (int, float))
883896
mul_scalar = full_node.args[1]
884897
assert isinstance(mul_scalar, (int, float))
885-
new_scale = float(prev_scale) * float(mul_scalar)
898+
""" The reason why we divide old scale by the mul value to get a new scale:
899+
y = x * mul_scalar
900+
q = zp + y / old_scale
901+
q = zp + x * mul_scalar / old_scale
902+
new_scale = old_scale / mul_scalar
903+
q = zp + x / new_scale
904+
"""
905+
new_scale = float(old_scale) / float(mul_scalar)
886906

887907
logging.debug(
888908
f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
889909
)
890910

891-
# Replace the input first
892-
quant_node.replace_input_with(
893-
cast(torch.fx.Node, quant_node.args[0]),
894-
cast(torch.fx.Node, mul_node.args[0]),
895-
)
896-
897-
# Now update the scale in the args
898-
new_quant_args = list(quant_node.args)
899-
new_quant_args[1] = new_scale
900-
quant_node.args = tuple(new_quant_args)
901-
902-
# Clean up the mul_node
903-
mul_node.args = ()
904-
mul_node.users = {}
905-
906-
graph_module.graph.erase_node(mul_node)
907-
graph_module.graph.erase_node(full_node)
911+
# Update quant node input and scale.
912+
old_quant_input = cast(torch.fx.Node, quant_node.args[0])
913+
new_quant_input = cast(torch.fx.Node, mul_node.args[0])
914+
quant_node.replace_input_with(old_quant_input, new_quant_input)
915+
quant_node.update_arg(1, new_scale)
908916

909917
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
910918
for node in graph_module.graph.find_nodes(

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
598598
self.assertEqual(deq_scale, dequant_scale * mul_value)
599599

600600
def test_fuse_mul_into_quant(self) -> None:
601-
quant_scale = 1.5
601+
quant_scale = 5
602602
mul_value = 10
603603

604604
builder = GraphBuilder()
@@ -613,7 +613,7 @@ def test_fuse_mul_into_quant(self) -> None:
613613
)
614614
quant = builder.call_operator(
615615
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
616-
args=(mul, quant_scale, 0, 0, 255, torch.uint8),
616+
args=(mul, quant_scale, 7, 0, 255, torch.uint8),
617617
)
618618
builder.output([quant])
619619
original_graph = builder.get_graph_module()
@@ -631,14 +631,18 @@ def test_fuse_mul_into_quant(self) -> None:
631631
)
632632

633633
# verify that the quant scale value was updated correctly
634-
deq_scale = -1
635-
for node in converted_graph.graph.nodes:
636-
if (
637-
node.target
638-
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
639-
):
640-
deq_scale = node.args[1]
641-
self.assertEqual(deq_scale, quant_scale * mul_value)
634+
for node in converted_graph.graph.find_nodes(
635+
op="call_function",
636+
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
637+
):
638+
new_quant_scale = node.args[1]
639+
self.assertEqual(new_quant_scale, quant_scale / mul_value)
640+
641+
# verify the math is correct
642+
inp = torch.randn(4, 32, dtype=torch.float32)
643+
original_out = original_graph(inp)[0]
644+
new_out = converted_graph(inp)[0]
645+
assert torch.equal(original_out, new_out)
642646

643647
def test_fuse_then_transpose_pass(self) -> None:
644648
# Create a graph with full -> transpose.

0 commit comments

Comments (0)