Fix the math in the FuseMulTensorIntoQuantPass. #11946

Merged: 1 commit, Jun 30, 2025
64 changes: 36 additions & 28 deletions backends/cadence/aot/fuse_ops.py
@@ -856,19 +856,32 @@ class FuseMulTensorIntoQuantPass(ExportPass):
     def attempt_fusion(
         self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
     ) -> None:
-        full_nodes = [
-            arg
-            for arg in mul_node.args
-            if isinstance(arg, torch.fx.Node)
-            and arg.target == exir_ops.edge.aten.full.default
-        ]
+        if len(mul_node.args) != 2 or len(mul_node.users) != 1:
+            return
+
+        first_arg = cast(torch.fx.Node, mul_node.args[0])
+        second_arg = cast(torch.fx.Node, mul_node.args[1])
+
+        input_node = first_arg
+        full_node = second_arg
+        if second_arg.target == exir_ops.edge.aten.full.default:
+            # Most common case, nothing to change.
+            pass
+        elif first_arg.target == exir_ops.edge.aten.full.default:
+            # Input and full nodes are swapped.
+            full_node = first_arg
+            input_node = second_arg
+        else:
+            # Full node is not found, skip.
+            return
 
-        if len(full_nodes) != 1 or len(mul_node.users) != 1:
+        # Ensure that the mul op does not do any broadcasting.
+        if input_node.meta["val"].shape != mul_node.meta["val"].shape:
             return
 
-        full_node = full_nodes[0]
         mul_user = list(mul_node.users.keys())[0]
 
+        # Ensure only the expected quant ops are using the current mul op.
         if mul_user.target not in {
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.cadence.quantize_per_tensor.default,
@@ -878,33 +878,28 @@ def attempt_fusion(
         quant_node = mul_user
 
         # Calculate the new scale value.
-        prev_scale = quant_node.args[1]
-        assert isinstance(prev_scale, (int, float))
+        old_scale = quant_node.args[1]
+        assert isinstance(old_scale, (int, float))
         mul_scalar = full_node.args[1]
         assert isinstance(mul_scalar, (int, float))
-        new_scale = float(prev_scale) * float(mul_scalar)
+        """ The reason why we divide old scale by the mul value to get a new scale:
+        y = x * mul_scalar
+        q = zp + y / old_scale
+        q = zp + x * mul_scalar / old_scale
+        new_scale = old_scale / mul_scalar
+        q = zp + x / new_scale
+        """
+        new_scale = float(old_scale) / float(mul_scalar)
 
         logging.debug(
             f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
         )
 
-        # Replace the input first
-        quant_node.replace_input_with(
-            cast(torch.fx.Node, quant_node.args[0]),
-            cast(torch.fx.Node, mul_node.args[0]),
-        )
-
-        # Now update the scale in the args
-        new_quant_args = list(quant_node.args)
-        new_quant_args[1] = new_scale
-        quant_node.args = tuple(new_quant_args)
-
-        # Clean up the mul_node
-        mul_node.args = ()
-        mul_node.users = {}
-
-        graph_module.graph.erase_node(mul_node)
-        graph_module.graph.erase_node(full_node)
+        # Update quant node input and scale.
+        old_quant_input = cast(torch.fx.Node, quant_node.args[0])
+        new_quant_input = cast(torch.fx.Node, mul_node.args[0])
+        quant_node.replace_input_with(old_quant_input, new_quant_input)
+        quant_node.update_arg(1, new_scale)
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.find_nodes(
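As an aside on the corrected arithmetic: a small standalone sanity check (a sketch, not part of the PR) that assumes the usual affine quantization formula q = clamp(round(x / scale) + zero_point, qmin, qmax) and borrows the scale and multiplier values from the updated test:

import torch

old_scale, mul_scalar, zero_point = 5.0, 10.0, 7
new_scale = old_scale / mul_scalar  # corrected formula; the old code multiplied instead

x = torch.randn(4, 32)

# Quantizing (x * mul_scalar) with the original scale ...
q_before = torch.clamp(torch.round(x * mul_scalar / old_scale) + zero_point, 0, 255)
# ... matches quantizing x directly with the fused scale.
q_after = torch.clamp(torch.round(x / new_scale) + zero_point, 0, 255)
assert torch.equal(q_before, q_after)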
24 changes: 14 additions & 10 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -598,7 +598,7 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         self.assertEqual(deq_scale, dequant_scale * mul_value)
 
     def test_fuse_mul_into_quant(self) -> None:
-        quant_scale = 1.5
+        quant_scale = 5
         mul_value = 10
 
         builder = GraphBuilder()
@@ -613,7 +613,7 @@ def test_fuse_mul_into_quant(self) -> None:
         )
         quant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-            args=(mul, quant_scale, 0, 0, 255, torch.uint8),
+            args=(mul, quant_scale, 7, 0, 255, torch.uint8),
         )
         builder.output([quant])
         original_graph = builder.get_graph_module()
@@ -631,14 +631,18 @@ def test_fuse_mul_into_quant(self) -> None:
         )
 
         # verify that the quant scale value was updated correctly
-        deq_scale = -1
-        for node in converted_graph.graph.nodes:
-            if (
-                node.target
-                == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
-            ):
-                deq_scale = node.args[1]
-        self.assertEqual(deq_scale, quant_scale * mul_value)
+        for node in converted_graph.graph.find_nodes(
+            op="call_function",
+            target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        ):
+            new_quant_scale = node.args[1]
+            self.assertEqual(new_quant_scale, quant_scale / mul_value)
 
+        # verify the math is correct
+        inp = torch.randn(4, 32, dtype=torch.float32)
+        original_out = original_graph(inp)[0]
+        new_out = converted_graph(inp)[0]
+        assert torch.equal(original_out, new_out)
+
     def test_fuse_then_transpose_pass(self) -> None:
         # Create a graph with full -> transpose.
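A side note on the test refactor above: it swaps a manual scan over graph.nodes for torch.fx's find_nodes helper. A minimal standalone sketch of that API (assuming a recent PyTorch where Graph.find_nodes is available; the function f is made up for illustration):

import torch
import torch.fx

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(f)

# find_nodes filters by opcode and target in one call, replacing the
# `for node in graph.nodes: if node.target == ...` pattern.
for node in gm.graph.find_nodes(op="call_function", target=torch.relu):
    print(node.name, node.target)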