2121)
2222from executorch .exir .dialects ._ops import ops as exir_ops
2323from executorch .exir .graph_module import get_control_flow_submodules
24+ from torch ._export .utils import is_buffer , is_lifted_tensor_constant , is_param
2425from torch .export .exported_program import ExportedProgram
2526from torch .fx .passes .operator_support import any_chain , OperatorSupportBase
2627
2728
29+ def is_param_node (exp_prog : ExportedProgram , node : torch .fx .Node ) -> bool :
30+ return (
31+ is_param (exp_prog , node )
32+ or is_buffer (exp_prog , node )
33+ or is_lifted_tensor_constant (exp_prog , node )
34+ )
35+
36+
37+ def get_total_num_ops_in_ep (edge_programs , supported_ops ):
38+ total_number_of_ops = 0
39+ for edge_program in edge_programs .values ():
40+ for partitioned_program in edge_program :
41+ for node in partitioned_program .graph .nodes :
42+ if node .op == "call_function" :
43+ if node .target in supported_ops :
44+ total_number_of_ops += 1
45+ return total_number_of_ops
46+
47+
2848def _preprocess_multimethod (
2949 edge_programs : Dict [str , List [ExportedProgram ]],
3050 compile_specs : Dict [str , List [List [CompileSpec ]]],
@@ -37,13 +57,7 @@ def _preprocess_multimethod(
3757 in testing for a partitioner which tags different partitions for different backends
3858 to be lowered to
3959 """
40- total_number_of_ops = 0
41- for edge_program in edge_programs .values ():
42- for partitioned_program in edge_program :
43- for node in partitioned_program .graph .nodes :
44- if node .op == "call_function" :
45- if node .target in supported_ops :
46- total_number_of_ops += 1
60+ total_number_of_ops = get_total_num_ops_in_ep (edge_programs , supported_ops )
4761 all_processed_results = {key : [] for key in edge_programs .keys ()}
4862
4963 for method_name , partitioned_programs in edge_programs .items ():
@@ -67,6 +81,8 @@ def _preprocess_multimethod(
6781 raise RuntimeError (
6882 f"{ node .op } { node .target .__name__ } is not supported in backend { backend_name } "
6983 )
84+ if is_param_node (partitioned_program , node ):
85+ processed_bytes += f"CONST{ node .name } :"
7086
7187 processed_bytes += "#"
7288 for cs in compile_spec_for_partition :
@@ -171,14 +187,30 @@ def preprocess_multimethod(
171187
172188
173189class AddSinOperatorSupport (OperatorSupportBase ):
190+ def __init__ (self , original_program ):
191+ self .original_program = original_program
192+ super ().__init__ ()
193+
174194 def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
175- return node . op == "call_function" and node . target in [
195+ supported_targets = [
176196 exir_ops .edge .aten .add .Tensor ,
177197 exir_ops .edge .aten .sin .default ,
178198 ]
199+ if node .op == "call_function" and node .target in supported_targets :
200+ return True
201+
202+ if node .op == "placeholder" and is_param_node (self .original_program , node ):
203+ for user in node .users .keys ():
204+ if user .target in supported_targets :
205+ return True
206+ return False
179207
180208
181209class SubCosOperatorSupport (OperatorSupportBase ):
210+ def __init__ (self , original_program ):
211+ self .original_program = original_program
212+ super ().__init__ ()
213+
182214 def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
183215 return node .op == "call_function" and node .target in [
184216 exir_ops .edge .aten .sub .Tensor ,
@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
199231 """
200232
201233 def __init__ (self ) -> None :
202- self .add_sin_support = any_chain (AddSinOperatorSupport ())
203- self .add_sin_backend_id = FirstBackendWithPreprocessAll .__name__
204-
205- self .sub_cos_support = any_chain (SubCosOperatorSupport ())
206234 self .sub_cos_backend_id = SecondBackendWithPreprocessAll .__name__
235+ self .add_sin_backend_id = FirstBackendWithPreprocessAll .__name__
207236
208237 def _partition_graph_module (
209238 self ,
@@ -260,6 +289,8 @@ def _partition_graph_module(
260289 return partition_tags , start_idx_for_submodules
261290
262291 def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
292+ self .add_sin_support = any_chain (AddSinOperatorSupport (exported_program ))
293+ self .sub_cos_support = any_chain (SubCosOperatorSupport (exported_program ))
263294 partition_tags , _ = self ._partition_graph_module (exported_program .graph_module )
264295 return PartitionResult (
265296 tagged_exported_program = exported_program , partition_tags = partition_tags
0 commit comments