@@ -856,19 +856,32 @@ class FuseMulTensorIntoQuantPass(ExportPass):
856
856
def attempt_fusion (
857
857
self , graph_module : torch .fx .GraphModule , mul_node : torch .fx .Node
858
858
) -> None :
859
- full_nodes = [
860
- arg
861
- for arg in mul_node .args
862
- if isinstance (arg , torch .fx .Node )
863
- and arg .target == exir_ops .edge .aten .full .default
864
- ]
859
+ if len (mul_node .args ) != 2 or len (mul_node .users ) != 1 :
860
+ return
861
+
862
+ first_arg = cast (torch .fx .Node , mul_node .args [0 ])
863
+ second_arg = cast (torch .fx .Node , mul_node .args [1 ])
864
+
865
+ input_node = first_arg
866
+ full_node = second_arg
867
+ if second_arg .target == exir_ops .edge .aten .full .default :
868
+ # Most common case, nothing to change.
869
+ pass
870
+ elif first_arg .target == exir_ops .edge .aten .full .default :
871
+ # Input and full nodes are swapped.
872
+ full_node = first_arg
873
+ input_node = second_arg
874
+ else :
875
+ # Full node is not found, skip.
876
+ return
865
877
866
- if len (full_nodes ) != 1 or len (mul_node .users ) != 1 :
878
+ # Ensure that the mul op does not do any broadcasting.
879
+ if input_node .meta ["val" ].shape != mul_node .meta ["val" ].shape :
867
880
return
868
881
869
- full_node = full_nodes [0 ]
870
882
mul_user = list (mul_node .users .keys ())[0 ]
871
883
884
+ # Ensure only the expected quant ops are using the current mul op.
872
885
if mul_user .target not in {
873
886
exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
874
887
exir_ops .edge .cadence .quantize_per_tensor .default ,
@@ -878,33 +891,28 @@ def attempt_fusion(
878
891
quant_node = mul_user
879
892
880
893
# Calculate the new scale value.
881
- prev_scale = quant_node .args [1 ]
882
- assert isinstance (prev_scale , (int , float ))
894
+ old_scale = quant_node .args [1 ]
895
+ assert isinstance (old_scale , (int , float ))
883
896
mul_scalar = full_node .args [1 ]
884
897
assert isinstance (mul_scalar , (int , float ))
885
- new_scale = float (prev_scale ) * float (mul_scalar )
898
+ """ The reason why we divide old scale by the mul value to get a new scale:
899
+ y = x * mul_scalar
900
+ q = zp + y / old_scale
901
+ q = zp + x * mul_scalar / old_scale
902
+ new_scale = old_scale / mul_scalar
903
+ q = zp + x / new_scale
904
+ """
905
+ new_scale = float (old_scale ) / float (mul_scalar )
886
906
887
907
logging .debug (
888
908
f"Fused { mul_node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
889
909
)
890
910
891
- # Replace the input first
892
- quant_node .replace_input_with (
893
- cast (torch .fx .Node , quant_node .args [0 ]),
894
- cast (torch .fx .Node , mul_node .args [0 ]),
895
- )
896
-
897
- # Now update the scale in the args
898
- new_quant_args = list (quant_node .args )
899
- new_quant_args [1 ] = new_scale
900
- quant_node .args = tuple (new_quant_args )
901
-
902
- # Clean up the mul_node
903
- mul_node .args = ()
904
- mul_node .users = {}
905
-
906
- graph_module .graph .erase_node (mul_node )
907
- graph_module .graph .erase_node (full_node )
911
+ # Update quant node input and scale.
912
+ old_quant_input = cast (torch .fx .Node , quant_node .args [0 ])
913
+ new_quant_input = cast (torch .fx .Node , mul_node .args [0 ])
914
+ quant_node .replace_input_with (old_quant_input , new_quant_input )
915
+ quant_node .update_arg (1 , new_scale )
908
916
909
917
def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
910
918
for node in graph_module .graph .find_nodes (
0 commit comments