@@ -856,19 +856,23 @@ class FuseMulTensorIntoQuantPass(ExportPass):
856
856
def attempt_fusion(
    self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
) -> None:
    """Try to fold a ``mul(x, full(c))`` into the quantize op that consumes it.

    The fusion is applied only when all of the following hold:
      * ``mul_node`` has exactly two args and exactly one user;
      * one of the two args is an ``aten.full`` node (either position);
      * the single user is one of the supported ``quantize_per_tensor`` ops;
      * the non-constant input has the same shape as the quant node's
        value, i.e. the mul does not broadcast;
      * the constant multiplier is non-zero.

    On success the quant node is rewired to read the mul's non-constant
    input directly and its scale (arg 1) is replaced by
    ``old_scale / mul_scalar``. The mul and full nodes are not erased here.
    NOTE(review): presumably they are removed by a later dead-code
    elimination step -- confirm against the caller.

    Args:
        graph_module: The graph module that owns ``mul_node``. Not used by
            this method; kept for interface compatibility.
        mul_node: The candidate multiplication node.
    """
    if len(mul_node.args) != 2 or len(mul_node.users) != 1:
        return

    # Locate the constant (full) operand; it may be on either side of the mul.
    full_op = exir_ops.edge.aten.full.default
    first_arg, second_arg = mul_node.args
    if isinstance(second_arg, torch.fx.Node) and second_arg.target == full_op:
        input_arg, full_node = first_arg, second_arg
    elif isinstance(first_arg, torch.fx.Node) and first_arg.target == full_op:
        input_arg, full_node = second_arg, first_arg
    else:
        # Neither operand is a full op, so there is nothing to fold.
        return

    if not isinstance(input_arg, torch.fx.Node):
        return
    input_node = input_arg

    mul_user = next(iter(mul_node.users))

    # Ensure only the expected quant ops are using the current mul op.
    # This is checked before touching the user's metadata so that an
    # unrelated user without a "val" entry cannot raise here.
    if mul_user.target not in {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
    }:
        return

    quant_node = mul_user

    # Ensure that the mul op does not do any broadcasting.
    if input_node.meta["val"].shape != quant_node.meta["val"].shape:
        return

    # Calculate the new scale value.
    old_scale = quant_node.args[1]
    assert isinstance(old_scale, (int, float))
    mul_scalar = full_node.args[1]
    assert isinstance(mul_scalar, (int, float))
    if mul_scalar == 0:
        # A zero multiplier cannot be expressed as a finite scale.
        return
    # The reason why we divide old scale by the mul value to get a new scale:
    #   y = x * mul_scalar
    #   q = zp + y / old_scale
    #   q = zp + x * mul_scalar / old_scale
    #   new_scale = old_scale / mul_scalar
    #   q = zp + x / new_scale
    new_scale = float(old_scale) / float(mul_scalar)

    logging.debug(
        f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
    )

    # Update quant node input and scale. The new input must be the mul's
    # non-constant operand, which is not necessarily mul_node.args[0].
    old_quant_input = cast(torch.fx.Node, quant_node.args[0])
    quant_node.replace_input_with(old_quant_input, input_node)
    quant_node.update_arg(1, new_scale)
908
906
909
907
def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
910
908
for node in graph_module .graph .find_nodes (
0 commit comments