1- #!/usr/bin/env fbpython
1+ #!/usr/bin/env python3
22# Copyright (c) Meta Platforms, Inc. and affiliates.
33# All rights reserved.
44#
55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
88import argparse
9+ import re
910import os
1011import sys
11- from typing import Any , List
12+ from functools import reduce
13+ from typing import Any , List , Tuple
1214
13- from tools_copy .code_analyzer import gen_oplist_copy_from_core
15+ import yaml
16+ from torchgen .selective_build .selector import (
17+ combine_selective_builders ,
18+ SelectiveBuilder ,
19+ )
20+ from pathlib import Path
1421
1522
23+ def throw_if_any_op_includes_overloads (selective_builder : SelectiveBuilder ) -> None :
24+ ops = []
25+ for op_name , op in selective_builder .operators .items ():
26+ if op .include_all_overloads :
27+ ops .append (op_name )
28+ if ops :
29+ raise Exception ( # noqa: TRY002
30+ (
31+ "Operators that include all overloads are "
32+ + "not allowed since --allow-include-all-overloads "
33+ + "was not specified: {}"
34+ ).format (", " .join (ops ))
35+ )
36+
37+ def resolve_model_file_path_to_buck_target (model_file_path : str ) -> str :
38+ real_path = str (Path (model_file_path ).resolve (strict = True ))
39+ # try my best to convert to buck target
40+ prog = re .compile (r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml" )
41+ match = prog .match (real_path )
42+ if match :
43+ return f"{ match .group (1 )} //{ match .group (2 )} :{ match .group (3 )} "
44+ else :
45+ return real_path
46+
1647def main (argv : List [Any ]) -> None :
17- """This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18- This is needed because we intend to error out for the case where `model_file_list_path`
19- is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20- is provided as a dependency to ExecuTorch build.
48+ """This binary generates 3 files:
49+
50+ 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
51+ dtypes captured by tracing
52+ 2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
2153 """
2254 parser = argparse .ArgumentParser (description = "Generate operator lists" )
2355 parser .add_argument (
56+ "--output-dir" ,
2457 "--output_dir" ,
2558 help = ("The directory to store the output yaml file (selected_operators.yaml)" ),
2659 required = True ,
2760 )
2861 parser .add_argument (
62+ "--model-file-list-path" ,
2963 "--model_file_list_path" ,
3064 help = (
3165 "Path to a file that contains the locations of individual "
@@ -36,6 +70,7 @@ def main(argv: List[Any]) -> None:
3670 required = True ,
3771 )
3872 parser .add_argument (
73+ "--allow-include-all-overloads" ,
3974 "--allow_include_all_overloads" ,
4075 help = (
4176 "Flag to allow operators that include all overloads. "
@@ -46,26 +81,99 @@ def main(argv: List[Any]) -> None:
4681 default = False ,
4782 required = False ,
4883 )
84+ parser .add_argument (
85+ "--check-ops-not-overlapping" ,
86+ "--check_ops_not_overlapping" ,
87+ help = (
88+ "Flag to check if the operators in the model file list are overlapping. "
89+ + "If not set, the script will not error out for overlapping operators."
90+ ),
91+ action = "store_true" ,
92+ default = False ,
93+ required = False ,
94+ )
95+ options = parser .parse_args (argv )
96+
4997
50- # check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
98+ # Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
5199 # 1. a yaml file containing selected ops (could be empty), or
52- # 2. a non-empty list of yaml files in the `model_file_list_path`.
53- # If none of the two things happened, the build target has no dependency on any selective build and we should error out .
54- options = parser . parse_args ( argv )
100+ # 2. a non-empty list of yaml files in the `model_file_list_path` or
101+ # 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file .
102+ # If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55103 if os .path .isfile (options .model_file_list_path ):
56- pass
104+ print ("Processing model file: " , options .model_file_list_path )
105+ model_dicts = []
106+ model_dict = yaml .safe_load (open (options .model_file_list_path ))
107+ model_dicts .append (model_dict )
57108 else :
58- assert (
59- options .model_file_list_path [0 ] == "@"
60- ), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
109+ print ( "Processing model directory: " , options . model_file_list_path )
110+ assert options .model_file_list_path [0 ] == "@" , "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue. "
111+
61112 model_file_list_path = options .model_file_list_path [1 :]
113+
114+ model_dicts = []
62115 with open (model_file_list_path ) as model_list_file :
63116 model_file_names = model_list_file .read ().split ()
64117 assert (
65118 len (model_file_names ) > 0
66119 ), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67120 "build. Please refer to Selective Build wiki page to add at least one."
68- gen_oplist_copy_from_core .main (argv )
121+ for model_file_name in model_file_names :
122+ if not os .path .isfile (model_file_name ):
123+ model_file_name = os .path .join (model_file_name , "selected_operators.yaml" )
124+ print ("Processing model file: " , model_file_name )
125+ assert os .path .isfile (model_file_name ), f"{ model_file_name } is not a valid file path. This is likely a BUCK issue."
126+ with open (model_file_name , "rb" ) as model_file :
127+ model_dict = yaml .safe_load (model_file )
128+ resolved = resolve_model_file_path_to_buck_target (model_file_name )
129+ for op in model_dict ["operators" ]:
130+ model_dict ["operators" ][op ]["debug_info" ] = [resolved ]
131+ model_dicts .append (model_dict )
132+
133+ selective_builders = [SelectiveBuilder .from_yaml_dict (m ) for m in model_dicts ]
134+
135+ # Optionally check if the operators in the model file list are overlapping.
136+ if options .check_ops_not_overlapping :
137+ ops = {}
138+ for model_dict in model_dicts :
139+ for op_name in model_dict ["operators" ]:
140+ if op_name in ops :
141+ debug_info_1 = ',' .join (ops [op_name ]["debug_info" ])
142+ debug_info_2 = ',' .join (model_dict ["operators" ][op_name ]["debug_info" ])
143+ error = f"Operator { op_name } is used in 2 models: { debug_info_1 } and { debug_info_2 } "
144+ if "//" not in debug_info_1 and "//" not in debug_info_2 :
145+ error += "\n We can't determine what BUCK targets these model files belong to."
146+ tail = "."
147+ else :
148+ error += "\n Please run the following commands to find out where is the BUCK target being added as a dependency to your target:\n "
149+ error += f"\n buck2 cquery <mode> \" allpaths(<target>, { debug_info_1 } )\" "
150+ error += f"\n buck2 cquery <mode> \" allpaths(<target>, { debug_info_2 } )\" "
151+ tail = "as well as results from BUCK commands listed above."
152+
153+ error += "\n \n If issue is not resolved, please post in PyTorch Edge Q&A with this error message" + tail
154+ raise Exception (error ) # noqa: TRY002
155+ ops [op_name ] = model_dict ["operators" ][op_name ]
156+ # We may have 0 selective builders since there may not be any viable
157+ # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
158+ # This is potentially an error, and we should probably raise an assertion
159+ # failure here. However, this needs to be investigated further.
160+ selective_builder = SelectiveBuilder .from_yaml_dict ({})
161+ if len (selective_builders ) > 0 :
162+ selective_builder = reduce (
163+ combine_selective_builders ,
164+ selective_builders ,
165+ )
166+
167+ if not options .allow_include_all_overloads :
168+ throw_if_any_op_includes_overloads (selective_builder )
169+ with open (
170+ os .path .join (options .output_dir , "selected_operators.yaml" ), "wb"
171+ ) as out_file :
172+ out_file .write (
173+ yaml .safe_dump (
174+ selective_builder .to_dict (), default_flow_style = False
175+ ).encode ("utf-8" ),
176+ )
69177
70178
71179if __name__ == "__main__" :
0 commit comments