@@ -35,6 +35,17 @@ def test_function_eligible_for_optimization() -> None:
3535 assert len (functions_found [Path (f .name )]) == 0
3636
3737
38+ # we want to trigger an error in the function discovery
39+ function = """def test_invalid_code():"""
40+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
41+ f .write (function )
42+ f .flush ()
43+ functions_found = find_all_functions_in_file (Path (f .name ))
44+ assert functions_found == {}
45+
46+
47+
48+
3849def test_find_top_level_function_or_method ():
3950 with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
4051 f .write (
@@ -82,6 +93,15 @@ def non_classmethod_function(cls, name):
8293 ).is_top_level
8394 # needed because this will be traced with a class_name being passed
8495
96+ # we want to write invalid code to ensure that the function discovery does not crash
97+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
98+ f .write (
99+ """def functionA():
100+ """
101+ )
102+ f .flush ()
103+ path_obj_name = Path (f .name )
104+ assert not inspect_top_level_functions_or_methods (path_obj_name , "functionA" )
85105
86106def test_class_method_discovery ():
87107 with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
@@ -152,6 +172,133 @@ def functionA():
152172 assert functions [file ][0 ].function_name == "functionA"
153173
154174
175+ def test_nested_function ():
176+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
177+ f .write (
178+ """
179+ import copy
180+
181+ def propagate_attributes(
182+ nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
183+ ) -> dict[str, dict]:
184+ modified_nodes = copy.deepcopy(nodes)
185+
186+ # Build an adjacency list for faster traversal
187+ adjacency = {}
188+ for edge in edges:
189+ src = edge["source"]
190+ tgt = edge["target"]
191+ if src not in adjacency:
192+ adjacency[src] = []
193+ adjacency[src].append(tgt)
194+
195+ # Track visited nodes to avoid cycles
196+ visited = set()
197+
198+ def traverse(node_id):
199+ if node_id in visited:
200+ return
201+ visited.add(node_id)
202+
203+ # Propagate attribute from source node
204+ if (
205+ node_id != source_node_id
206+ and source_node_id in modified_nodes
207+ and attribute in modified_nodes[source_node_id]
208+ ):
209+ if node_id in modified_nodes:
210+ modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
211+ attribute
212+ ]
213+
214+ # Continue propagation to neighbors
215+ for neighbor in adjacency.get(node_id, []):
216+ traverse(neighbor)
217+
218+ traverse(source_node_id)
219+ return modified_nodes
220+ """
221+ )
222+ f .flush ()
223+ test_config = TestConfig (
224+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
225+ )
226+ path_obj_name = Path (f .name )
227+ functions , functions_count = get_functions_to_optimize (
228+ optimize_all = None ,
229+ replay_test = None ,
230+ file = path_obj_name ,
231+ test_cfg = test_config ,
232+ only_get_this_function = None ,
233+ ignore_paths = [Path ("/bruh/" )],
234+ project_root = path_obj_name .parent ,
235+ module_root = path_obj_name .parent ,
236+ )
237+
238+ assert len (functions ) == 1
239+ assert functions_count == 1
240+
241+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
242+ f .write (
243+ """
244+ def outer_function():
245+ def inner_function():
246+ pass
247+
248+ return inner_function
249+ """
250+ )
251+ f .flush ()
252+ test_config = TestConfig (
253+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
254+ )
255+ path_obj_name = Path (f .name )
256+ functions , functions_count = get_functions_to_optimize (
257+ optimize_all = None ,
258+ replay_test = None ,
259+ file = path_obj_name ,
260+ test_cfg = test_config ,
261+ only_get_this_function = None ,
262+ ignore_paths = [Path ("/bruh/" )],
263+ project_root = path_obj_name .parent ,
264+ module_root = path_obj_name .parent ,
265+ )
266+
267+ assert len (functions ) == 1
268+ assert functions_count == 1
269+
270+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
271+ f .write (
272+ """
273+ def outer_function():
274+ def inner_function():
275+ pass
276+
277+ def another_inner_function():
278+ pass
279+ return inner_function, another_inner_function
280+ """
281+ )
282+ f .flush ()
283+ test_config = TestConfig (
284+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
285+ )
286+ path_obj_name = Path (f .name )
287+ functions , functions_count = get_functions_to_optimize (
288+ optimize_all = None ,
289+ replay_test = None ,
290+ file = path_obj_name ,
291+ test_cfg = test_config ,
292+ only_get_this_function = None ,
293+ ignore_paths = [Path ("/bruh/" )],
294+ project_root = path_obj_name .parent ,
295+ module_root = path_obj_name .parent ,
296+ )
297+
298+ assert len (functions ) == 1
299+ assert functions_count == 1
300+
301+
155302def test_filter_files_optimized ():
156303 tests_root = Path ("tests" ).resolve ()
157304 module_root = Path ().resolve ()
0 commit comments