2121
2222
2323def read_configs ( # noqa: C901: TODO: move modules into config
24- benchmarks : list [str ] = None ,
25- implementations : list [str ] = None ,
24+ benchmarks : set [str ] = None ,
25+ implementations : set [str ] = None ,
26+ no_dpbench : bool = False ,
27+ with_npbench : bool = False ,
28+ with_polybench : bool = False ,
29+ load_implementations : bool = True ,
2630) -> Config :
2731 """Read all configuration files and populate those settings into Config.
2832
@@ -53,11 +57,9 @@ def read_configs( # noqa: C901: TODO: move modules into config
5357 ),
5458 ]
5559
56- no_dpbench = os .getenv ("NO_DPBENCH" )
5760 if no_dpbench :
5861 modules [0 ].benchmark_configs_path = ""
5962
60- with_npbench = os .getenv ("WITH_NPBENCH" )
6163 if with_npbench :
6264 modules .append (
6365 Module (
@@ -69,7 +71,6 @@ def read_configs( # noqa: C901: TODO: move modules into config
6971 )
7072 )
7173
72- with_polybench = os .getenv ("WITH_POLYBENCH" )
7374 if with_polybench :
7475 modules .append (
7576 Module (
@@ -92,7 +93,7 @@ def read_configs( # noqa: C901: TODO: move modules into config
9293 benchmarks = benchmarks ,
9394 )
9495 if mod .framework_configs_path != "" :
95- read_frameworks (config , mod .framework_configs_path )
96+ read_frameworks (config , mod .framework_configs_path , implementations )
9697 if mod .precision_dtypes_path != "" :
9798 read_precision_dtypes (config , mod .precision_dtypes_path )
9899 if mod .path != "" :
@@ -102,13 +103,20 @@ def read_configs( # noqa: C901: TODO: move modules into config
102103 config .implementations += framework .postfixes
103104
104105 if implementations is None :
105- implementations = [ impl .postfix for impl in config .implementations ]
106+ implementations = { impl .postfix for impl in config .implementations }
106107
107- for benchmark in config .benchmarks :
108- read_benchmark_implementations (
109- benchmark ,
110- implementations ,
111- )
108+ if load_implementations :
109+ for benchmark in config .benchmarks :
110+ read_benchmark_implementations (
111+ benchmark ,
112+ implementations ,
113+ )
114+
115+ config .benchmarks = [
116+ benchmark
117+ for benchmark in config .benchmarks
118+ if len (benchmark .implementations ) > 0
119+ ]
112120
113121 return config
114122
@@ -118,7 +126,7 @@ def read_benchmarks(
118126 bench_info_dir : str ,
119127 recursive : bool = False ,
120128 parent_package : str = "dpbench.benchmarks" ,
121- benchmarks : list [str ] = None ,
129+ benchmarks : set [str ] = None ,
122130):
123131 """Read and populate benchmark configuration files.
124132
@@ -163,12 +171,20 @@ def read_benchmarks(
163171 config .benchmarks .append (benchmark )
164172
165173
166- def read_frameworks (config : Config , framework_info_dir : str ) -> None :
174+ def read_frameworks (
175+ config : Config ,
176+ framework_info_dir : str ,
177+ implementations : set [str ] = None ,
178+ ) -> None :
167179 """Read and populate framework configuration files.
168180
169181 Args:
170182 config: Configuration object where settings should be populated.
171183 framework_info_dir: Path to the directory with configuration files.
184+ implementations: Set of the implementations to load. If framework
185+ does not have any implementation from this list - it won't be
186+ loaded. If set None or empty - all frameworks/implementations get
187+ loaded.
172188 """
173189 for framework_info_file in os .listdir (framework_info_dir ):
174190 if not framework_info_file .endswith (".toml" ):
@@ -183,9 +199,20 @@ def read_frameworks(config: Config, framework_info_dir: str) -> None:
183199
184200 framework_info = tomli .loads (file_contents )
185201 framework_dict = framework_info .get ("framework" )
186- if framework_dict :
187- framework = Framework .from_dict (framework_dict )
188- config .frameworks .append (framework )
202+ if not framework_dict :
203+ continue
204+ framework = Framework .from_dict (framework_dict )
205+ if implementations :
206+ framework .postfixes = [
207+ postfix
208+ for postfix in framework .postfixes
209+ if postfix .postfix in implementations
210+ ]
211+
212+ if len (framework .postfixes ) == 0 :
213+ continue
214+
215+ config .frameworks .append (framework )
189216
190217
191218def read_implementation_postfixes (
@@ -249,8 +276,8 @@ def setup_init(config: Benchmark, modules: list[str]) -> None:
249276 impl_mod = importlib .import_module (config .init .package_path )
250277
251278 if not hasattr (impl_mod , config .init .func_name ):
252- print (
253- f"WARNING: could not find init function for { config .module_name } "
279+ logging . warn (
280+ f"could not find init function for { config .module_name } "
254281 )
255282
256283
@@ -284,7 +311,7 @@ def discover_module_name_and_postfix(module: str, config: Config):
284311
285312def read_benchmark_implementations (
286313 config : Benchmark ,
287- implementations : list [str ] = None ,
314+ implementations : set [str ] = None ,
288315) -> None :
289316 """Read and discover implementation modules and functions.
290317
@@ -318,8 +345,10 @@ def read_benchmark_implementations(
318345 for module in modules :
319346 module_name , postfix = discover_module_name_and_postfix (module , config )
320347
321- if (postfix not in implementations ) or (
322- config .init and config .init .module_name .endswith (module_name )
348+ if (
349+ implementations
350+ and (postfix not in implementations )
351+ or (config .init and config .init .module_name .endswith (module_name ))
323352 ):
324353 continue
325354
0 commit comments