@@ -105,24 +105,12 @@ def find_complex_op_subgraphs(
105105 return complex_op_subgraphs
106106
107107
108- def complex_graph_detection (
109- gm : GraphModule , settings : CompilationSettings
110- ) -> List [ComplexSubGraphInfo ]:
111- complex_op_detector = ComplexOpDetector ()
112- complex_subgraphs = complex_op_detector .find_complex_op_subgraphs (
113- gm , anchor_target = torch .ops .aten .view_as_real .default
114- )
115- complex_graph_rewriter = ComplexGraphRewriter (gm , settings .truncate_double )
116- complex_graph_rewriter .rewrite_subgraph_nodes (complex_subgraphs )
117- return gm
118-
119-
120108class ComplexGraphRewriter :
121109 def __init__ (self , gm : GraphModule , truncate_double : bool = False ):
122110 self .gm = gm
123111 self .truncate_double = truncate_double
124112
125- def extract_shape_and_dtype (self , input_node ):
113+ def extract_shape_dtype_device (self , input_node ):
126114 if input_node .op == "placeholder" :
127115 tensor_val = input_node .meta ["val" ]
128116
@@ -135,6 +123,7 @@ def extract_shape_and_dtype(self, input_node):
135123 node_shape = tensor_val .size ()
136124 dtype = tensor_val .dtype
137125 new_node_shape = node_shape + (2 ,)
126+ device = tensor_val .device
138127
139128 if dtype == torch .complex64 :
140129 new_node_dtype = torch .float32
@@ -143,7 +132,7 @@ def extract_shape_and_dtype(self, input_node):
143132 else :
144133 new_node_dtype = torch .float64
145134
146- return new_node_shape , new_node_dtype
135+ return new_node_shape , new_node_dtype , device
147136
148137 def get_attr_tensor (self , target ):
149138 # Check if target is param or buffer
@@ -157,14 +146,16 @@ def get_attr_tensor(self, target):
157146 )
158147
159148 def replace_input_node (self , input_node ):
149+ modified = False
160150 logger .debug (f"Replacing input node: { input_node .name } " )
161- new_shape , new_dtype = self .extract_shape_and_dtype (input_node )
162- real_tensor = torch .empty (new_shape , dtype = new_dtype )
151+ new_shape , new_dtype , device = self .extract_shape_dtype_device (input_node )
152+ real_tensor = torch .empty (new_shape , dtype = new_dtype , device = device )
163153
164154 if input_node .op == "placeholder" :
165155 with FakeTensorMode () as fake_mode :
166156 fake_tensor = fake_mode .from_tensor (real_tensor )
167- new_node = self .gm .graph .placeholder (input_node .target + "_reshaped" )
157+ with self .gm .graph .inserting_before (input_node ):
158+ new_node = self .gm .graph .placeholder (input_node .target + "_reshaped" )
168159 new_node .meta ["val" ] = fake_tensor
169160
170161 elif input_node .op == "get_attr" :
@@ -179,17 +170,16 @@ def replace_input_node(self, input_node):
179170 self .gm .register_buffer (new_attr_name , stacked_tensor )
180171 with self .gm .graph .inserting_after (input_node ):
181172 new_node = self .gm .graph .get_attr (new_attr_name )
182-
183173 else :
184174 logger .debug (
185175 f"Unsupported node type in replacement of input node: { input_node .op } "
186176 )
187177 logger .debug (
188178 "This complex subgraph inputnode type does not need to replaced"
189179 )
190-
191180 input_node .replace_all_uses_with (new_node )
192181 self .gm .graph .erase_node (input_node )
182+ clean_up_graph_after_modifications (self .gm )
193183
194184 def rewrite_subgraph_nodes (self , subgraphs ):
195185 modified = False
@@ -198,7 +188,6 @@ def rewrite_subgraph_nodes(self, subgraphs):
198188 logger .debug (f"Input node rewrite: { input_node .name } " )
199189 if input_node .op not in ("call_function" ):
200190 self .replace_input_node (input_node )
201- modified = True
202191 for node in subgraph .subgraph_nodes :
203192 logger .debug (f"Subgraph Node rewrite: { node .name } " )
204193 if node .target == torch .ops .aten .view_as_complex .default :
@@ -229,6 +218,7 @@ def match_complex_mul(
229218 ignore_literals = True ,
230219 )
231220 replaced_nodes += nodes
221+ modified = True
232222 elif node .target == torch .ops .aten .view_as_real .default :
233223 node .replace_all_uses_with (node .args [0 ])
234224 self .gm .graph .erase_node (node )
@@ -239,9 +229,36 @@ def match_complex_mul(
239229 )
240230
241231 if modified :
232+ self .propagate_metadata ()
242233 self .gm .graph .lint ()
243234 self .gm .recompile ()
244235
def propagate_metadata(self):
    """Run ``FakeTensorProp`` over ``self.gm`` using one example input per placeholder.

    For every ``placeholder`` node the example input is taken from
    ``node.meta["val"]`` when that key exists; otherwise a tensor is
    allocated with ``torch.empty`` from ``node.meta["tensor_meta"]``.
    The collected inputs are then passed, in graph order, to
    ``FakeTensorProp(...).propagate``.

    Raises:
        KeyError: If a placeholder has neither ``"val"`` nor
            ``"tensor_meta"`` in its ``meta`` dict.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.passes.fake_tensor_prop import FakeTensorProp

    fake_inputs = []
    for node in self.gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if "val" in node.meta:
            with FakeTensorMode(allow_non_fake_inputs=True):
                fake_val = node.meta["val"]
                # NOTE(review): CUDA values are passed through ``.to("cuda")``,
                # presumably to land on the current CUDA device -- confirm intent.
                fake_inputs.append(
                    fake_val.to("cuda")
                    if fake_val.device.type == "cuda"
                    else fake_val
                )
            continue

        tensor_meta = node.meta["tensor_meta"]
        fake_inputs.append(
            torch.empty(
                # Zero-sized dimensions are replaced by 1 before allocating.
                [s if s != 0 else 1 for s in tensor_meta.shape],
                dtype=tensor_meta.dtype,
                # The stock torch.fx ``TensorMetadata`` record carries no
                # ``device`` field; fall back to the default device (None)
                # instead of raising AttributeError.
                device=getattr(tensor_meta, "device", None),
            )
        )

    FakeTensorProp(
        self.gm, mode=FakeTensorMode(allow_non_fake_inputs=True)
    ).propagate(*fake_inputs)
261+
245262
246263def complex_mul_replacement () -> Tuple [
247264 Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ],
@@ -280,7 +297,23 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
280297 torch .ops .aten .unsqueeze .default (real , - 1 ),
281298 torch .ops .aten .unsqueeze .default (imag , - 1 ),
282299 ],
283- dim = - 1 ,
300+ - 1 ,
284301 )
285302
286303 return (original_mul , replacement )
304+
305+
# Lowering pass that detects complex-valued subgraphs in the graph and rewrites them.
# It handles complex tensors used in mul that are parameters or buffers of the graph.
def complex_graph_detection(
    gm: GraphModule, settings: CompilationSettings
) -> GraphModule:
    """Detect complex-op subgraphs in ``gm`` and rewrite them.

    Subgraphs are located with ``ComplexOpDetector.find_complex_op_subgraphs``
    anchored on ``torch.ops.aten.view_as_real.default``; each one is logged at
    debug level and the whole list is handed to
    ``ComplexGraphRewriter.rewrite_subgraph_nodes``.

    Args:
        gm: Graph module to scan and rewrite.
        settings: Compilation settings; only ``settings.truncate_double`` is
            read here and forwarded to the rewriter.

    Returns:
        The same ``gm`` object that was passed in.
    """
    detector = ComplexOpDetector()
    complex_subgraphs = detector.find_complex_op_subgraphs(
        gm, anchor_target=torch.ops.aten.view_as_real.default
    )
    for subgraph in complex_subgraphs:
        logger.debug(f"Complex subgraph info: {subgraph}")

    rewriter = ComplexGraphRewriter(gm, settings.truncate_double)
    rewriter.rewrite_subgraph_nodes(complex_subgraphs)
    # The function returns the graph module itself, not the list of detected
    # subgraphs, hence the GraphModule return annotation.
    return gm
0 commit comments