11import logging
2+ import operator
23from typing import Callable , List , Optional , Set , Tuple
34
45import torch
@@ -33,8 +34,7 @@ def __repr__(self):
3334
3435
3536class ComplexOpDetector :
36- def __init__ (self , logger ):
37- self .logger = logger
37+ def __init__ (self ):
3838 pass
3939
4040 def is_complex_dtype (self , node : Node ) -> bool :
@@ -45,15 +45,13 @@ def is_complex_dtype(self, node: Node) -> bool:
4545 if hasattr (val , "dtype" ):
4646 dtype = val .dtype
4747
48- self . logger .debug (f"dtype of node: { dtype } " )
48+ logger .debug (f"dtype of node: { dtype } " )
4949 return dtype in {torch .complex64 , torch .complex128 }
5050
5151 def node_include_in_subgraph (self , node : Node ) -> bool :
5252 # Include only call_function ops on complex tensors
53- self .logger .debug (f"node.op: { node .op } , node name: { node .name } " )
54- self .logger .debug (f"is_complex_dtype: { self .is_complex_dtype (node )} " )
5553 if node .op == "call_function" and self .is_complex_dtype (node ):
56- self . logger .debug (
54+ logger .debug (
5755 f"node.op is added to subgraph: { node .op } , node name: { node .name } is complex"
5856 )
5957 return node .op == "call_function" and self .is_complex_dtype (node )
@@ -67,7 +65,7 @@ def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo:
6765 if n in subgraph_nodes :
6866 continue
6967 subgraph_nodes .add (n )
70- self . logger .debug (f"node { n .name } is added to subgraph" )
68+ logger .debug (f"node { n .name } is added to subgraph" )
7169 for inp in n .all_input_nodes :
7270 if self .node_include_in_subgraph (inp ):
7371 print ("node inp is added to stack:" , inp .name )
@@ -85,13 +83,12 @@ def find_complex_op_subgraphs(
8583 complex_op_subgraphs : List [ComplexSubGraphInfo ] = []
8684 for node in gm .graph .nodes :
8785 if node .target == anchor_target :
88- self .logger .debug (f"node.target { node .target } node.name: { node .name } " )
8986 new_sub = self .subgraph_from_anchor (node )
9087 # if any intersecting nodes between seen and sub.subgraph_nodes they should be merged
9188 merged = False
9289 for existing_sub in complex_op_subgraphs :
9390 if set (existing_sub .subgraph_nodes ) & set (new_sub .subgraph_nodes ):
94- self . logger .debug (f"merging subgraphs { existing_sub } { new_sub } " )
91+ logger .debug (f"merging subgraphs { existing_sub } { new_sub } " )
9592 # merge the two subgraphs
9693 existing_sub .subgraph_nodes = list (
9794 set (existing_sub .subgraph_nodes )
@@ -113,7 +110,7 @@ def find_complex_op_subgraphs(
113110def complex_graph_detection (
114111 gm : GraphModule , settings : CompilationSettings
115112) -> List [ComplexSubGraphInfo ]:
116- complex_op_detector = ComplexOpDetector (logger )
113+ complex_op_detector = ComplexOpDetector ()
117114 complex_subgraphs = complex_op_detector .find_complex_op_subgraphs (
118115 gm , anchor_target = torch .ops .aten .view_as_real .default
119116 )
@@ -174,17 +171,24 @@ def replace_input_node(self, input_node):
174171
175172 elif input_node .op == "get_attr" :
176173 new_attr_name = input_node .target + "_reshaped"
177- original_tensor = self .get_attr_tensor (input_node .target )
178- stacked_tensor = torch .stack (
179- [original_tensor .real , original_tensor .imag ], dim = - 1
180- )
181- self .gm .register_buffer (new_attr_name , stacked_tensor )
174+ from torch ._subclasses .fake_tensor import unset_fake_temporarily
175+
176+ with unset_fake_temporarily ():
177+ original_tensor = self .get_attr_tensor (input_node .target )
178+ stacked_tensor = torch .stack (
179+ [original_tensor .real , original_tensor .imag ], dim = - 1
180+ )
181+ self .gm .register_buffer (new_attr_name , stacked_tensor )
182182 with self .gm .graph .inserting_after (input_node ):
183183 new_node = self .gm .graph .get_attr (new_attr_name )
184184
185185 else :
186- logger .debug (f"Unsupported node type: { input_node .op } " )
187- logger .debug ("This node type does not need to replaced" )
186+ logger .debug (
187+ f"Unsupported node type in replacement of input node: { input_node .op } "
188+ )
189+ logger .debug (
190+ "This complex subgraph inputnode type does not need to replaced"
191+ )
188192
189193 input_node .replace_all_uses_with (new_node )
190194 self .gm .graph .erase_node (input_node )
@@ -211,6 +215,8 @@ def rewrite_subgraph_nodes(self, subgraphs):
211215
212216 def match_complex_mul (
213217 match : torch .fx .subgraph_rewriter .Match ,
218+ original_graph ,
219+ pattern_graph ,
214220 ) -> bool :
215221 for original_node in match .nodes_map .values ():
216222 if original_node .name == node .name :
@@ -230,10 +236,9 @@ def match_complex_mul(
230236 self .gm .graph .erase_node (node )
231237 else :
232238 logger .debug (f"Unsupported node target: { node .target } " )
233- logger .debug (f"This node type does not need to replaced" )
234- if modified :
235- self .gm .graph .lint ()
236- self .gm .recompile ()
239+ logger .debug (
240+ "This complex subgraphnode type does not need to replaced"
241+ )
237242
238243 if modified :
239244 self .gm .graph .lint ()
@@ -256,16 +261,28 @@ def complex_mul_replacement() -> Tuple[
256261
257262 # Original pattern: torch.mul for complex tensors
258263 def original_mul (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
259- return torch .mul (x , y )
264+ return torch .ops . aten . mul . Tensor (x , y )
260265
261266 # Replacement function: manual complex multiplication on real/imag stacked tensors
262267 def replacement (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
263- x_real , x_imag = x [..., 0 ], x [..., 1 ]
264- y_real , y_imag = y [..., 0 ], y [..., 1 ]
265-
266- real = x_real * y_real - x_imag * y_imag
267- imag = x_real * y_imag + x_imag * y_real
268-
269- return torch .stack ((real , imag ), dim = - 1 )
268+ x_real = torch .ops .aten .select .int (x , - 1 , 0 )
269+ x_imag = torch .ops .aten .select .int (x , - 1 , 1 ) # x is reshape tensor
270+ y_real , y_imag = y [..., 0 ], y [..., 1 ] # y is frozen param
271+
272+ real_part1 = torch .ops .aten .mul .Tensor (x_real , y_real )
273+ real_part2 = torch .ops .aten .mul .Tensor (x_imag , y_imag )
274+ real = torch .ops .aten .sub .Tensor (real_part1 , real_part2 )
275+
276+ imag_part1 = torch .ops .aten .mul .Tensor (x_real , y_imag )
277+ imag_part2 = torch .ops .aten .mul .Tensor (x_imag , y_real )
278+ imag = torch .ops .aten .add .Tensor (imag_part1 , imag_part2 )
279+
280+ return torch .ops .aten .cat .default (
281+ [
282+ torch .ops .aten .unsqueeze .default (real , - 1 ),
283+ torch .ops .aten .unsqueeze .default (imag , - 1 ),
284+ ],
285+ dim = - 1 ,
286+ )
270287
271288 return (original_mul , replacement )
0 commit comments