11from __future__ import annotations
22
3+ import hashlib
34import os
45from collections import defaultdict
56from itertools import chain
6- from typing import TYPE_CHECKING
7+ from typing import TYPE_CHECKING , cast
78
89import libcst as cst
910
3132def get_code_optimization_context (
3233 function_to_optimize : FunctionToOptimize ,
3334 project_root_path : Path ,
34- optim_token_limit : int = 8000 ,
35- testgen_token_limit : int = 8000 ,
35+ optim_token_limit : int = 16000 ,
36+ testgen_token_limit : int = 16000 ,
3637) -> CodeOptimizationContext :
3738 # Get FunctionSource representation of helpers of FTO
3839 helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi (
@@ -73,6 +74,13 @@ def get_code_optimization_context(
7374 remove_docstrings = False ,
7475 code_context_type = CodeContextType .READ_ONLY ,
7576 )
77+ hashing_code_context = extract_code_markdown_context_from_files (
78+ helpers_of_fto_dict ,
79+ helpers_of_helpers_dict ,
80+ project_root_path ,
81+ remove_docstrings = True ,
82+ code_context_type = CodeContextType .HASHING ,
83+ )
7684
7785 # Handle token limits
7886 final_read_writable_tokens = encoded_tokens_len (final_read_writable_code )
@@ -125,11 +133,15 @@ def get_code_optimization_context(
125133 testgen_context_code_tokens = encoded_tokens_len (testgen_context_code )
126134 if testgen_context_code_tokens > testgen_token_limit :
127135 raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
136+ code_hash_context = hashing_code_context .markdown
137+ code_hash = hashlib .sha256 (code_hash_context .encode ("utf-8" )).hexdigest ()
128138
129139 return CodeOptimizationContext (
130140 testgen_context_code = testgen_context_code ,
131141 read_writable_code = final_read_writable_code ,
132142 read_only_context_code = read_only_context_code ,
143+ hashing_code_context = code_hash_context ,
144+ hashing_code_context_hash = code_hash ,
133145 helper_functions = helpers_of_fto_list ,
134146 preexisting_objects = preexisting_objects ,
135147 )
@@ -309,8 +321,8 @@ def extract_code_markdown_context_from_files(
309321 logger .debug (f"Error while getting read-only code: { e } " )
310322 continue
311323 if code_context .strip ():
312- code_context_with_imports = CodeString (
313- code = add_needed_imports_from_module (
324+ if code_context_type != CodeContextType . HASHING :
325+ code_context = add_needed_imports_from_module (
314326 src_module_code = original_code ,
315327 dst_module_code = code_context ,
316328 src_path = file_path ,
@@ -319,10 +331,9 @@ def extract_code_markdown_context_from_files(
319331 helper_functions = list (
320332 helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())
321333 ),
322- ),
323- file_path = file_path .relative_to (project_root_path ),
324- )
325- code_context_markdown .code_strings .append (code_context_with_imports )
334+ )
335+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
336+ code_context_markdown .code_strings .append (code_string_context )
326337 # Extract code from file paths containing helpers of helpers
327338 for file_path , helper_function_sources in helpers_of_helpers_no_overlap .items ():
328339 try :
@@ -343,18 +354,17 @@ def extract_code_markdown_context_from_files(
343354 continue
344355
345356 if code_context .strip ():
346- code_context_with_imports = CodeString (
347- code = add_needed_imports_from_module (
357+ if code_context_type != CodeContextType . HASHING :
358+ code_context = add_needed_imports_from_module (
348359 src_module_code = original_code ,
349360 dst_module_code = code_context ,
350361 src_path = file_path ,
351362 dst_path = file_path ,
352363 project_root = project_root_path ,
353364 helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
354- ),
355- file_path = file_path .relative_to (project_root_path ),
356- )
357- code_context_markdown .code_strings .append (code_context_with_imports )
365+ )
366+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
367+ code_context_markdown .code_strings .append (code_string_context )
358368 return code_context_markdown
359369
360370
@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
492502 filtered_node , found_target = prune_cst_for_testgen_code (
493503 module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
494504 )
505+ elif code_context_type == CodeContextType .HASHING :
506+ filtered_node , found_target = prune_cst_for_code_hashing (module , target_functions )
495507 else :
496508 raise ValueError (f"Unknown code_context_type: { code_context_type } " ) # noqa: EM102
497509
@@ -583,6 +595,90 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583595 return (node .with_changes (** updates ) if updates else node ), True
584596
585597
598+ def prune_cst_for_code_hashing ( # noqa: PLR0911
599+ node : cst .CSTNode , target_functions : set [str ], prefix : str = ""
600+ ) -> tuple [cst .CSTNode | None , bool ]:
601+ """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
602+
603+ Returns
604+ -------
605+ (filtered_node, found_target):
606+ filtered_node: The modified CST node or None if it should be removed.
607+ found_target: True if a target function was found in this node's subtree.
608+
609+ """
610+ if isinstance (node , (cst .Import , cst .ImportFrom )):
611+ return None , False
612+
613+ if isinstance (node , cst .FunctionDef ):
614+ qualified_name = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
615+ if qualified_name in target_functions :
616+ new_body = remove_docstring_from_body (node .body ) if isinstance (node .body , cst .IndentedBlock ) else node .body
617+ return node .with_changes (body = new_body ), True
618+ return None , False
619+
620+ if isinstance (node , cst .ClassDef ):
621+ # Do not recurse into nested classes
622+ if prefix :
623+ return None , False
624+ # Assuming always an IndentedBlock
625+ if not isinstance (node .body , cst .IndentedBlock ):
626+ raise ValueError ("ClassDef body is not an IndentedBlock" ) # noqa: TRY004
627+ class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
628+ new_class_body : list [cst .CSTNode ] = []
629+ found_target = False
630+
631+ for stmt in node .body .body :
632+ if isinstance (stmt , cst .FunctionDef ):
633+ qualified_name = f"{ class_prefix } .{ stmt .name .value } "
634+ if qualified_name in target_functions :
635+ stmt_with_changes = stmt .with_changes (
636+ body = remove_docstring_from_body (cast ("cst.IndentedBlock" , stmt .body ))
637+ )
638+ new_class_body .append (stmt_with_changes )
639+ found_target = True
640+ # If no target functions found, remove the class entirely
641+ if not new_class_body or not found_target :
642+ return None , False
643+ return node .with_changes (
644+ body = cst .IndentedBlock (cast ("list[cst.BaseStatement]" , new_class_body ))
645+ ) if new_class_body else None , found_target
646+
647+ # For other nodes, we preserve them only if they contain target functions in their children.
648+ section_names = get_section_names (node )
649+ if not section_names :
650+ return node , False
651+
652+ updates : dict [str , list [cst .CSTNode ] | cst .CSTNode ] = {}
653+ found_any_target = False
654+
655+ for section in section_names :
656+ original_content = getattr (node , section , None )
657+ if isinstance (original_content , (list , tuple )):
658+ new_children = []
659+ section_found_target = False
660+ for child in original_content :
661+ filtered , found_target = prune_cst_for_code_hashing (child , target_functions , prefix )
662+ if filtered :
663+ new_children .append (filtered )
664+ section_found_target |= found_target
665+
666+ if section_found_target :
667+ found_any_target = True
668+ updates [section ] = new_children
669+ elif original_content is not None :
670+ filtered , found_target = prune_cst_for_code_hashing (original_content , target_functions , prefix )
671+ if found_target :
672+ found_any_target = True
673+ if filtered :
674+ updates [section ] = filtered
675+
676+ if not found_any_target :
677+ return None , False
678+
679+ return (node .with_changes (** updates ) if updates else node ), True
680+
681+
586682def prune_cst_for_read_only_code ( # noqa: PLR0911
587683 node : cst .CSTNode ,
588684 target_functions : set [str ],
0 commit comments