From 1f7124adcde1579028d3cee11963ad5a094c514f Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 27 Apr 2025 18:04:14 -0700 Subject: [PATCH 01/32] WIP --- codeflash/api/cfapi.py | 19 ++++++++++ codeflash/optimization/function_optimizer.py | 37 ++++++++++++++++---- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 00d324db9..b1e2dbb03 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import json import os import sys @@ -14,6 +15,7 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number from codeflash.code_utils.git_utils import get_repo_owner_and_name +from codeflash.models.models import CodeOptimizationContext from codeflash.version import __version__ if TYPE_CHECKING: @@ -200,3 +202,20 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]: return {} return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()} + + +def is_function_being_optimized_again(code_context: CodeOptimizationContext) -> bool: + """Check if the function being optimized is being optimized again.""" + pr_number = get_pr_number() + if pr_number is None: + # Only want to do this check during GH Actions + return False + owner, repo = get_repo_owner_and_name() + + rw_context_hash = hashlib.sha256(str(code_context).encode()).hexdigest() + + payload = {"owner": owner, "repo": repo, "pullNumber": pr_number, "code_hash": rw_context_hash} + response = make_cfapi_request(endpoint="/is-function-being-optimized-again", method="POST", payload=payload) + if not response.ok or response.text != "true": + logger.error(f"Error: {response.text}") + return False diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 6ad25bc0e..1b71a4205 100644 --- 
a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -19,6 +19,7 @@ from rich.tree import Tree from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient +from codeflash.api.cfapi import is_function_being_optimized_again from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils @@ -144,6 +145,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") + if is_function_being_optimized_again(code_context=code_context): + return Failure("This code has already been optimized earlier") + code_print(code_context.read_writable_code) generated_test_paths = [ get_test_file_path( @@ -242,7 +246,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # request for new optimizations but don't block execution, check for completion later # adding to control and experiment set but with same traceid best_optimization = None - for _u, (candidates, exp_type) in enumerate(zip([optimizations_set.control, optimizations_set.experiment],["EXP0","EXP1"])): + for _u, (candidates, exp_type) in enumerate( + zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]) + ): if candidates is None: continue @@ -254,7 +260,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: file_path_to_helper_classes=file_path_to_helper_classes, exp_type=exp_type, ) - ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id}) + ph( + "cli-optimize-function-finished", + { + "function_trace_id": self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id + }, + ) generated_tests = 
remove_functions_from_generated_tests( generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove @@ -324,7 +337,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: explanation=explanation, existing_tests_source=existing_tests, generated_original_test_source=generated_tests_str, - function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + function_trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, coverage_message=coverage_message, git_remote=self.args.git_remote, ) @@ -379,7 +394,7 @@ def determine_best_candidate( # Start a new thread for AI service request, start loop in main thread # check if aiservice request is complete, when it is complete, append result to the candidates list with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - ai_service_client = self.aiservice_client if exp_type=="EXP0" else self.local_aiservice_client + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client future_line_profile_results = executor.submit( ai_service_client.optimize_python_code_line_profiler, source_code=code_context.read_writable_code, @@ -387,7 +402,11 @@ def determine_best_candidate( trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, line_profiler_results=original_code_baseline.line_profile_results["str_out"], num_candidates=10, - experiment_metadata=ExperimentMetadata(id=self.experiment_id, group= "control" if exp_type == "EXP0" else "experiment") if self.experiment_id else None, + experiment_metadata=ExperimentMetadata( + id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" + ) + if self.experiment_id + else None, ) try: candidate_index = 0 @@ -528,7 +547,9 @@ def determine_best_candidate( ) return best_optimization - def log_successful_optimization(self, explanation: Explanation, generated_tests: 
GeneratedTestsList, exp_type: str) -> None: + def log_successful_optimization( + self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str + ) -> None: explanation_panel = Panel( f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n" f"📈 {explanation.perf_improvement_line}\n" @@ -555,7 +576,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: ph( "cli-optimize-success", { - "function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + "function_trace_id": self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, "speedup_x": explanation.speedup_x, "speedup_pct": explanation.speedup_pct, "best_runtime": explanation.best_runtime_ns, From 41378a0b24da5292dd36308e315e744ab1399873 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 27 Apr 2025 19:09:56 -0700 Subject: [PATCH 02/32] WIP --- codeflash/api/cfapi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index b1e2dbb03..15500dd00 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -211,8 +211,8 @@ def is_function_being_optimized_again(code_context: CodeOptimizationContext) -> # Only want to do this check during GH Actions return False owner, repo = get_repo_owner_and_name() - - rw_context_hash = hashlib.sha256(str(code_context).encode()).hexdigest() + # TODO: Add file paths + rw_context_hash = hashlib.sha256(str(code_context.read_writable_code).encode()).hexdigest() payload = {"owner": owner, "repo": repo, "pullNumber": pr_number, "code_hash": rw_context_hash} response = make_cfapi_request(endpoint="/is-function-being-optimized-again", method="POST", payload=payload) From e9746c920a3940c99d8005d9fb0477a611001e5f Mon Sep 17 00:00:00 2001 From: dasarchan Date: Mon, 2 Jun 2025 18:56:12 -0400 Subject: [PATCH 03/32] batch code hash check 
--- codeflash/discovery/functions_to_optimize.py | 217 +++++++++++++++---- 1 file changed, 177 insertions(+), 40 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 8adfb4e00..add615f9a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import hashlib import os import random import warnings @@ -145,15 +146,56 @@ def qualified_name(self) -> str: def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" + def get_code_context_hash(self) -> str: + """Generate a SHA-256 hash representing the code context of this function. + + This hash includes the function's code content, file path, and qualified name + to uniquely identify the function for optimization tracking. + """ + try: + with open(self.file_path, 'r', encoding='utf-8') as f: + file_content = f.read() + + # Extract the function's code content + lines = file_content.splitlines() + if self.starting_line is not None and self.ending_line is not None: + # Use line numbers if available (1-indexed to 0-indexed) + function_content = '\n'.join(lines[self.starting_line - 1:self.ending_line]) + else: + # Fallback: use the entire file content if line numbers aren't available + function_content = file_content + + # Create a context string that includes: + # - File path (relative to make it portable) + # - Qualified function name + # - Function code content + context_parts = [ + str(self.file_path.name), # Just filename for portability + self.qualified_name, + function_content.strip() + ] + + context_string = '\n---\n'.join(context_parts) + + # Generate SHA-256 hash + return hashlib.sha256(context_string.encode('utf-8')).hexdigest() + + except (OSError, IOError) as e: + logger.warning(f"Could not read file 
{self.file_path} for hashing: {e}") + # Fallback hash using available metadata + fallback_string = f"{self.file_path.name}:{self.qualified_name}" + return hashlib.sha256(fallback_string.encode('utf-8')).hexdigest() + + def get_functions_to_optimize( - optimize_all: str | None, - replay_test: str | None, - file: Path | None, - only_get_this_function: str | None, - test_cfg: TestConfig, - ignore_paths: list[Path], - project_root: Path, - module_root: Path, + optimize_all: str | None, + replay_test: str | None, + file: Path | None, + only_get_this_function: str | None, + test_cfg: TestConfig, + ignore_paths: list[Path], + project_root: Path, + module_root: Path, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( "Only one of optimize_all, replay_test, or file should be provided" @@ -186,7 +228,7 @@ def get_functions_to_optimize( found_function = None for fn in functions.get(file, []): if only_function_name == fn.function_name and ( - class_name is None or class_name == fn.top_level_parent_name + class_name is None or class_name == fn.top_level_parent_name ): found_function = fn if found_function is None: @@ -224,8 +266,8 @@ def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]: function_to_optimize for function_to_optimize in function_lines.functions if (start_line := function_to_optimize.starting_line) is not None - and (end_line := function_to_optimize.ending_line) is not None - and any(start_line <= line <= end_line for line in modified_lines[path_str]) + and (end_line := function_to_optimize.ending_line) is not None + and any(start_line <= line <= end_line for line in modified_lines[path_str]) ] return modified_functions @@ -258,7 +300,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( - replay_test: Path, test_cfg: TestConfig, project_root_path: Path + replay_test: Path, test_cfg: TestConfig, 
project_root_path: Path ) -> dict[Path, list[FunctionToOptimize]]: function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) # Get the absolute file paths for each function, excluding class name if present @@ -273,7 +315,7 @@ def get_all_replay_test_functions( class_name = ( module_path_parts[-1] if module_path_parts - and is_class_defined_in_file( + and is_class_defined_in_file( module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py") ) else None @@ -323,7 +365,8 @@ def ignored_submodule_paths(module_root: str) -> list[str]: class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): def __init__( - self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None + self, file_name: Path, function_or_method_name: str, class_name: str | None = None, + line_no: int | None = None ) -> None: self.file_name = file_name self.class_name = class_name @@ -354,13 +397,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name: self.is_top_level = True if any( - isinstance(decorator, ast.Name) and decorator.id == "classmethod" - for decorator in body_node.decorator_list + isinstance(decorator, ast.Name) and decorator.id == "classmethod" + for decorator in body_node.decorator_list ): self.is_classmethod = True elif any( - isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - for decorator in body_node.decorator_list + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list ): self.is_staticmethod = True return @@ -369,13 +412,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: # This way, if we don't have the class name, we can still find the static method for body_node in node.body: if ( - isinstance(body_node, ast.FunctionDef) - and body_node.name == self.function_name - and body_node.lineno in {self.line_no, 
self.line_no + 1} - and any( - isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - for decorator in body_node.decorator_list - ) + isinstance(body_node, ast.FunctionDef) + and body_node.name == self.function_name + and body_node.lineno in {self.line_no, self.line_no + 1} + and any( + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list + ) ): self.is_staticmethod = True self.is_top_level = True @@ -386,7 +429,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: def inspect_top_level_functions_or_methods( - file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None + file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None ) -> FunctionProperties: with open(file_name, encoding="utf8") as file: try: @@ -408,13 +451,93 @@ def inspect_top_level_functions_or_methods( ) +def check_optimization_status( + functions_by_file: dict[Path, list[FunctionToOptimize]], + owner: str, + repo: str, + pr_number: int +) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + """Check which functions have already been optimized and filter them out. + + This function calls the optimization API to: + 1. Check which functions are already optimized + 2. Log new function hashes to the database + 3. 
Return only functions that need optimization + + Args: + functions_by_file: Dictionary mapping file paths to lists of functions + owner: Repository owner + repo: Repository name + pr_number: Pull request number + + Returns: + Tuple of (filtered_functions_dict, remaining_count) + """ + import requests + + # Build the code_contexts dictionary for the API call + code_contexts = {} + path_to_function_map = {} + + for file_path, functions in functions_by_file.items(): + for func in functions: + func_hash = func.get_code_context_hash() + # Use a unique path identifier that includes function info + path_key = f"{file_path}:{func.qualified_name}" + code_contexts[path_key] = func_hash + path_to_function_map[path_key] = (file_path, func) + + if not code_contexts: + return {}, 0 + + try: + # Call the optimization check API + response = requests.post( + "http://your-api-endpoint/is_code_being_optimized_again", # Replace with actual endpoint + json={ + "owner": owner, + "repo": repo, + "pr_number": str(pr_number), + "code_contexts": code_contexts + }, + timeout=30 + ) + response.raise_for_status() + + result = response.json() + already_optimized_paths = set(result.get("already_optimized_paths", [])) + + logger.info(f"Found {len(already_optimized_paths)} already optimized functions") + + # Filter out already optimized functions + filtered_functions = defaultdict(list) + remaining_count = 0 + + for path_key, (file_path, func) in path_to_function_map.items(): + if path_key not in already_optimized_paths: + filtered_functions[file_path].append(func) + remaining_count += 1 + + return dict(filtered_functions), remaining_count + + except Exception as e: + logger.warning(f"Failed to check optimization status: {e}") + logger.info("Proceeding with all functions (optimization check failed)") + # Return all functions if API call fails + total_count = sum(len(funcs) for funcs in functions_by_file.values()) + return functions_by_file, total_count + + def filter_functions( - 
modified_functions: dict[Path, list[FunctionToOptimize]], - tests_root: Path, - ignore_paths: list[Path], - project_root: Path, - module_root: Path, - disable_logs: bool = False, + modified_functions: dict[Path, list[FunctionToOptimize]], + tests_root: Path, + ignore_paths: list[Path], + project_root: Path, + module_root: Path, + disable_logs: bool = False, + owner: str | None = None, + repo: str | None = None, + pr_number: int | None = None, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: blocklist_funcs = get_blocklisted_functions() # Remove any function that we don't want to optimize @@ -432,6 +555,7 @@ def filter_functions( submodule_ignored_paths_count: int = 0 tests_root_str = str(tests_root) module_root_str = str(module_root) + # We desperately need Python 3.10+ only support to make this code readable with structural pattern matching for file_path_path, functions in modified_functions.items(): file_path = str(file_path_path) @@ -439,12 +563,12 @@ def filter_functions( test_functions_removed_count += len(functions) continue if file_path in ignore_paths or any( - file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths + file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths ): ignore_paths_removed_count += 1 continue if file_path in submodule_paths or any( - file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths + file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths ): submodule_ignored_paths_count += 1 continue @@ -464,13 +588,25 @@ def filter_functions( function for function in functions if not ( - function.file_path.name in blocklist_funcs - and function.qualified_name in blocklist_funcs[function.file_path.name] + function.file_path.name in blocklist_funcs + and function.qualified_name in blocklist_funcs[function.file_path.name] ) ] filtered_modified_functions[file_path] = functions functions_count += len(functions) + # Convert to 
Path keys for optimization check + path_based_functions = {Path(k): v for k, v in filtered_modified_functions.items() if v} + + # Check optimization status if repository info is provided + already_optimized_count = 0 + if owner and repo and pr_number is not None: + path_based_functions, functions_count = check_optimization_status( + path_based_functions, owner, repo, pr_number + ) + initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values()) + already_optimized_count = initial_count - functions_count + if not disable_logs: log_info = { f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count, @@ -479,13 +615,14 @@ def filter_functions( f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count, f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count, f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count, + f"{already_optimized_count} already optimized function{'s' if already_optimized_count != 1 else ''}": already_optimized_count, } log_string = "\n".join([k for k, v in log_info.items() if v > 0]) if log_string: logger.info(f"Ignoring: {log_string}") console.rule() - return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count + return path_based_functions, functions_count def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool: @@ -505,8 +642,8 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list if submodule_paths is None: submodule_paths = ignored_submodule_paths(module_root) return not ( - file_path in submodule_paths - or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) + 
file_path in submodule_paths + or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) ) @@ -515,4 +652,4 @@ def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool: - return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list) + return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list) \ No newline at end of file From 5760316fa952d8cdd107ff7d28df2445ffd9cf73 Mon Sep 17 00:00:00 2001 From: dasarchan Date: Tue, 3 Jun 2025 14:02:12 -0400 Subject: [PATCH 04/32] implemented hash check into filter_functions --- codeflash/discovery/functions_to_optimize.py | 30 +++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index add615f9a..70f470db9 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -15,14 +15,15 @@ import libcst as cst from pydantic.dataclasses import dataclass -from codeflash.api.cfapi import get_blocklisted_functions +from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request from codeflash.cli_cmds.console import DEBUG_MODE, console, logger from codeflash.code_utils.code_utils import ( is_class_defined_in_file, module_name_from_file_path, path_belongs_to_site_packages, ) -from codeflash.code_utils.git_utils import get_git_diff +from codeflash.code_utils.env_utils import get_pr_number +from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.models.models import FunctionParent from codeflash.telemetry.posthog_cf import ph @@ -473,8 +474,7 @@ def check_optimization_status( Returns: Tuple of (filtered_functions_dict, 
remaining_count) """ - import requests - + logger.info("entering function") # Build the code_contexts dictionary for the API call code_contexts = {} path_to_function_map = {} @@ -492,18 +492,18 @@ def check_optimization_status( try: # Call the optimization check API - response = requests.post( - "http://your-api-endpoint/is_code_being_optimized_again", # Replace with actual endpoint - json={ + logger.info("Checking status") + response = make_cfapi_request( + "/is-already-optimized", + "POST", + { "owner": owner, "repo": repo, - "pr_number": str(pr_number), + "pr_number": pr_number, "code_contexts": code_contexts - }, - timeout=30 + } ) response.raise_for_status() - result = response.json() already_optimized_paths = set(result.get("already_optimized_paths", [])) @@ -535,10 +535,8 @@ def filter_functions( project_root: Path, module_root: Path, disable_logs: bool = False, - owner: str | None = None, - repo: str | None = None, - pr_number: int | None = None, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + logger.info("filtering functions boogaloo") blocklist_funcs = get_blocklisted_functions() # Remove any function that we don't want to optimize @@ -600,6 +598,10 @@ def filter_functions( # Check optimization status if repository info is provided already_optimized_count = 0 + repository = git.Repo(Path.cwd(), search_parent_directories=True) + owner, repo = get_repo_owner_and_name(repository) + pr_number = get_pr_number() + print(owner, repo, pr_number) if owner and repo and pr_number is not None: path_based_functions, functions_count = check_optimization_status( path_based_functions, owner, repo, pr_number From 2367160d3970d9c20874d09f303669cc6e965677 Mon Sep 17 00:00:00 2001 From: dasarchan Date: Thu, 5 Jun 2025 16:31:48 -0400 Subject: [PATCH 05/32] removed prints, added cfapi.py func --- codeflash/api/cfapi.py | 29 ++++++++++---------- codeflash/discovery/functions_to_optimize.py | 22 ++------------- codeflash/optimization/function_optimizer.py | 3 -- 3 files 
changed, 16 insertions(+), 38 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 15500dd00..69b128d83 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -6,7 +6,7 @@ import sys from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Dict import requests import sentry_sdk @@ -204,18 +204,17 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]: return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()} -def is_function_being_optimized_again(code_context: CodeOptimizationContext) -> bool: +def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, code_contexts: dict[str, str]) -> Dict: """Check if the function being optimized is being optimized again.""" - pr_number = get_pr_number() - if pr_number is None: - # Only want to do this check during GH Actions - return False - owner, repo = get_repo_owner_and_name() - # TODO: Add file paths - rw_context_hash = hashlib.sha256(str(code_context.read_writable_code).encode()).hexdigest() - - payload = {"owner": owner, "repo": repo, "pullNumber": pr_number, "code_hash": rw_context_hash} - response = make_cfapi_request(endpoint="/is-function-being-optimized-again", method="POST", payload=payload) - if not response.ok or response.text != "true": - logger.error(f"Error: {response.text}") - return False + response = make_cfapi_request( + "/is-already-optimized", + "POST", + { + "owner": owner, + "repo": repo, + "pr_number": pr_number, + "code_contexts": code_contexts + } + ) + response.raise_for_status() + return response.json() diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 70f470db9..417ff1ece 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -15,7 +15,7 @@ import libcst as cst from 
pydantic.dataclasses import dataclass -from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request +from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request, is_function_being_optimized_again from codeflash.cli_cmds.console import DEBUG_MODE, console, logger from codeflash.code_utils.code_utils import ( is_class_defined_in_file, @@ -474,7 +474,6 @@ def check_optimization_status( Returns: Tuple of (filtered_functions_dict, remaining_count) """ - logger.info("entering function") # Build the code_contexts dictionary for the API call code_contexts = {} path_to_function_map = {} @@ -491,23 +490,9 @@ def check_optimization_status( return {}, 0 try: - # Call the optimization check API - logger.info("Checking status") - response = make_cfapi_request( - "/is-already-optimized", - "POST", - { - "owner": owner, - "repo": repo, - "pr_number": pr_number, - "code_contexts": code_contexts - } - ) - response.raise_for_status() - result = response.json() + result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts) already_optimized_paths = set(result.get("already_optimized_paths", [])) - logger.info(f"Found {len(already_optimized_paths)} already optimized functions") # Filter out already optimized functions filtered_functions = defaultdict(list) @@ -522,7 +507,6 @@ def check_optimization_status( except Exception as e: logger.warning(f"Failed to check optimization status: {e}") - logger.info("Proceeding with all functions (optimization check failed)") # Return all functions if API call fails total_count = sum(len(funcs) for funcs in functions_by_file.values()) return functions_by_file, total_count @@ -536,7 +520,6 @@ def filter_functions( module_root: Path, disable_logs: bool = False, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: - logger.info("filtering functions boogaloo") blocklist_funcs = get_blocklisted_functions() # Remove any function that we don't want to optimize @@ -601,7 +584,6 @@ def filter_functions( 
repository = git.Repo(Path.cwd(), search_parent_directories=True) owner, repo = get_repo_owner_and_name(repository) pr_number = get_pr_number() - print(owner, repo, pr_number) if owner and repo and pr_number is not None: path_based_functions, functions_count = check_optimization_status( path_based_functions, owner, repo, pr_number diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 1b71a4205..b6ce38ee6 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -145,9 +145,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") - if is_function_being_optimized_again(code_context=code_context): - return Failure("This code has already been optimized earlier") - code_print(code_context.read_writable_code) generated_test_paths = [ get_test_file_path( From c1fb089af0c737e3b93519f8fead6d8b718afdcc Mon Sep 17 00:00:00 2001 From: dasarchan Date: Thu, 5 Jun 2025 16:35:48 -0400 Subject: [PATCH 06/32] removed unused import --- codeflash/optimization/function_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 720d3a081..12aeff3fa 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -18,7 +18,6 @@ from rich.tree import Tree from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient -from codeflash.api.cfapi import is_function_being_optimized_again from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils From eb3d30592ed37547df20d7bdf6762c619ace0130 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Fri, 6 
Jun 2025 15:01:31 -0700 Subject: [PATCH 07/32] fix no git error --- codeflash/discovery/functions_to_optimize.py | 77 ++++++++++---------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 51cd9f52e..a4d04b651 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -15,7 +15,7 @@ import libcst as cst from pydantic.dataclasses import dataclass -from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request, is_function_being_optimized_again +from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again from codeflash.cli_cmds.console import DEBUG_MODE, console, logger from codeflash.code_utils.code_utils import ( is_class_defined_in_file, @@ -153,14 +153,14 @@ def get_code_context_hash(self) -> str: to uniquely identify the function for optimization tracking. """ try: - with open(self.file_path, 'r', encoding='utf-8') as f: + with open(self.file_path, encoding="utf-8") as f: file_content = f.read() # Extract the function's code content lines = file_content.splitlines() if self.starting_line is not None and self.ending_line is not None: # Use line numbers if available (1-indexed to 0-indexed) - function_content = '\n'.join(lines[self.starting_line - 1:self.ending_line]) + function_content = "\n".join(lines[self.starting_line - 1 : self.ending_line]) else: # Fallback: use the entire file content if line numbers aren't available function_content = file_content @@ -172,19 +172,20 @@ def get_code_context_hash(self) -> str: context_parts = [ str(self.file_path.name), # Just filename for portability self.qualified_name, - function_content.strip() + function_content.strip(), ] - context_string = '\n---\n'.join(context_parts) + context_string = "\n---\n".join(context_parts) # Generate SHA-256 hash - return hashlib.sha256(context_string.encode('utf-8')).hexdigest() + 
return hashlib.sha256(context_string.encode("utf-8")).hexdigest() - except (OSError, IOError) as e: + except OSError as e: logger.warning(f"Could not read file {self.file_path} for hashing: {e}") # Fallback hash using available metadata fallback_string = f"{self.file_path.name}:{self.qualified_name}" - return hashlib.sha256(fallback_string.encode('utf-8')).hexdigest() + return hashlib.sha256(fallback_string.encode("utf-8")).hexdigest() + def get_functions_to_optimize( optimize_all: str | None, @@ -228,7 +229,7 @@ def get_functions_to_optimize( found_function = None for fn in functions.get(file, []): if only_function_name == fn.function_name and ( - class_name is None or class_name == fn.top_level_parent_name + class_name is None or class_name == fn.top_level_parent_name ): found_function = fn if found_function is None: @@ -307,7 +308,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( - replay_test: Path, test_cfg: TestConfig, project_root_path: Path + replay_test: Path, test_cfg: TestConfig, project_root_path: Path ) -> dict[Path, list[FunctionToOptimize]]: function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) # Get the absolute file paths for each function, excluding class name if present @@ -322,7 +323,7 @@ def get_all_replay_test_functions( class_name = ( module_path_parts[-1] if module_path_parts - and is_class_defined_in_file( + and is_class_defined_in_file( module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py") ) else None @@ -374,8 +375,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]: class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): def __init__( - self, file_name: Path, function_or_method_name: str, class_name: str | None = None, - line_no: int | None = None + self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None ) -> None: self.file_name = 
file_name self.class_name = class_name @@ -406,13 +406,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name: self.is_top_level = True if any( - isinstance(decorator, ast.Name) and decorator.id == "classmethod" - for decorator in body_node.decorator_list + isinstance(decorator, ast.Name) and decorator.id == "classmethod" + for decorator in body_node.decorator_list ): self.is_classmethod = True elif any( - isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - for decorator in body_node.decorator_list + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list ): self.is_staticmethod = True return @@ -421,13 +421,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: # This way, if we don't have the class name, we can still find the static method for body_node in node.body: if ( - isinstance(body_node, ast.FunctionDef) - and body_node.name == self.function_name - and body_node.lineno in {self.line_no, self.line_no + 1} - and any( - isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - for decorator in body_node.decorator_list - ) + isinstance(body_node, ast.FunctionDef) + and body_node.name == self.function_name + and body_node.lineno in {self.line_no, self.line_no + 1} + and any( + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list + ) ): self.is_staticmethod = True self.is_top_level = True @@ -460,10 +460,7 @@ def inspect_top_level_functions_or_methods( def check_optimization_status( - functions_by_file: dict[Path, list[FunctionToOptimize]], - owner: str, - repo: str, - pr_number: int + functions_by_file: dict[Path, list[FunctionToOptimize]], owner: str, repo: str, pr_number: int ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: """Check which functions have already been optimized and filter them out. 
@@ -480,6 +477,7 @@ def check_optimization_status( Returns: Tuple of (filtered_functions_dict, remaining_count) + """ # Build the code_contexts dictionary for the API call code_contexts = {} @@ -500,7 +498,6 @@ def check_optimization_status( result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts) already_optimized_paths = set(result.get("already_optimized_paths", [])) - # Filter out already optimized functions filtered_functions = defaultdict(list) remaining_count = 0 @@ -556,12 +553,12 @@ def filter_functions( test_functions_removed_count += len(_functions) continue if file_path in ignore_paths or any( - file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths + file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths ): ignore_paths_removed_count += 1 continue if file_path in submodule_paths or any( - file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths + file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths ): submodule_ignored_paths_count += 1 continue @@ -607,13 +604,15 @@ def filter_functions( # Check optimization status if repository info is provided already_optimized_count = 0 - repository = git.Repo(Path.cwd(), search_parent_directories=True) - owner, repo = get_repo_owner_and_name(repository) + try: + repository = git.Repo(Path.cwd(), search_parent_directories=True) + owner, repo = get_repo_owner_and_name(repository) + except git.exc.InvalidGitRepositoryError: + logger.warning("No git repository found") + owner, repo = None, None pr_number = get_pr_number() if owner and repo and pr_number is not None: - path_based_functions, functions_count = check_optimization_status( - path_based_functions, owner, repo, pr_number - ) + path_based_functions, functions_count = check_optimization_status(path_based_functions, owner, repo, pr_number) initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values()) 
already_optimized_count = initial_count - functions_count @@ -654,8 +653,8 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list if submodule_paths is None: submodule_paths = ignored_submodule_paths(module_root) return not ( - file_path in submodule_paths - or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) + file_path in submodule_paths + or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) ) @@ -664,4 +663,4 @@ def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool: - return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list) \ No newline at end of file + return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list) From c862b4d5c93c092628849c9be3d1571d7baa2766 Mon Sep 17 00:00:00 2001 From: dasarchan Date: Fri, 6 Jun 2025 19:03:52 -0400 Subject: [PATCH 08/32] add low prob of repeating optimization --- codeflash/code_utils/config_consts.py | 1 + codeflash/discovery/functions_to_optimize.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 3e0acafcb..83ddc95f3 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -9,3 +9,4 @@ TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget COVERAGE_THRESHOLD = 60.0 MIN_TESTCASE_PASSED_THRESHOLD = 6 +REPEAT_OPTIMIZATION_PROBABILITY = 0.1 \ No newline at end of file diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index a4d04b651..37600ca44 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -28,6 +28,7 @@ from codeflash.discovery.discover_unit_tests import 
discover_unit_tests from codeflash.models.models import FunctionParent from codeflash.telemetry.posthog_cf import ph +from codeflash.code_utils.config_consts import REPEAT_OPTIMIZATION_PROBABILITY if TYPE_CHECKING: from libcst import CSTNode @@ -506,6 +507,10 @@ def check_optimization_status( if path_key not in already_optimized_paths: filtered_functions[file_path].append(func) remaining_count += 1 + else: + if random.random() < REPEAT_OPTIMIZATION_PROBABILITY: + filtered_functions[file_path].append(func) + remaining_count += 1 return dict(filtered_functions), remaining_count From 96ee580d7bafb59fb57f522850c633a3ce0a7f5a Mon Sep 17 00:00:00 2001 From: dasarchan Date: Fri, 6 Jun 2025 20:31:30 -0400 Subject: [PATCH 09/32] changes to cli for code context hash --- codeflash/discovery/functions_to_optimize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 37600ca44..0df9c4d74 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -509,6 +509,7 @@ def check_optimization_status( remaining_count += 1 else: if random.random() < REPEAT_OPTIMIZATION_PROBABILITY: + logger.info(f"Attempting more optimization on {path_key}") filtered_functions[file_path].append(func) remaining_count += 1 From 87fe0868952a6dfe50b79cd576544deb5a506336 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Fri, 6 Jun 2025 18:15:24 -0700 Subject: [PATCH 10/32] update the cli --- codeflash/api/cfapi.py | 11 +-- codeflash/discovery/functions_to_optimize.py | 78 +++++++++----------- 2 files changed, 37 insertions(+), 52 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index e07e0212c..5a5b5516a 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -1,12 +1,11 @@ from __future__ import annotations -import hashlib import json import os import sys from functools import lru_cache from pathlib import Path -from typing 
import TYPE_CHECKING, Any, Optional, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional import requests import sentry_sdk @@ -15,7 +14,6 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number from codeflash.code_utils.git_utils import get_repo_owner_and_name -from codeflash.models.models import CodeOptimizationContext from codeflash.version import __version__ if TYPE_CHECKING: @@ -200,12 +198,7 @@ def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, cod response = make_cfapi_request( "/is-already-optimized", "POST", - { - "owner": owner, - "repo": repo, - "pr_number": pr_number, - "code_contexts": code_contexts - } + {"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts}, ) response.raise_for_status() return response.json() diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 0df9c4d74..c36a61c36 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -22,13 +22,13 @@ module_name_from_file_path, path_belongs_to_site_packages, ) +from codeflash.code_utils.config_consts import REPEAT_OPTIMIZATION_PROBABILITY from codeflash.code_utils.env_utils import get_pr_number from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name from codeflash.code_utils.time_utils import humanize_runtime from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.models.models import FunctionParent from codeflash.telemetry.posthog_cf import ph -from codeflash.code_utils.config_consts import REPEAT_OPTIMIZATION_PROBABILITY if TYPE_CHECKING: from libcst import CSTNode @@ -460,9 +460,7 @@ def inspect_top_level_functions_or_methods( ) -def check_optimization_status( - functions_by_file: dict[Path, list[FunctionToOptimize]], owner: str, repo: str, pr_number: int 
-) -> tuple[dict[Path, list[FunctionToOptimize]], int]: +def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptimize]]) -> list[tuple[str, str]]: """Check which functions have already been optimized and filter them out. This function calls the optimization API to: @@ -480,7 +478,19 @@ def check_optimization_status( Tuple of (filtered_functions_dict, remaining_count) """ - # Build the code_contexts dictionary for the API call + # Check optimization status if repository info is provided + # already_optimized_count = 0 + try: + repository = git.Repo(search_parent_directories=True) + owner, repo = get_repo_owner_and_name(repository) + except git.exc.InvalidGitRepositoryError: + logger.warning("No git repository found") + owner, repo = None, None + pr_number = get_pr_number() + + if not owner or not repo or pr_number is None: + return [] + code_contexts = {} path_to_function_map = {} @@ -497,29 +507,13 @@ def check_optimization_status( try: result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts) - already_optimized_paths = set(result.get("already_optimized_paths", [])) - - # Filter out already optimized functions - filtered_functions = defaultdict(list) - remaining_count = 0 - - for path_key, (file_path, func) in path_to_function_map.items(): - if path_key not in already_optimized_paths: - filtered_functions[file_path].append(func) - remaining_count += 1 - else: - if random.random() < REPEAT_OPTIMIZATION_PROBABILITY: - logger.info(f"Attempting more optimization on {path_key}") - filtered_functions[file_path].append(func) - remaining_count += 1 - - return dict(filtered_functions), remaining_count + already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_paths", []) + return already_optimized_paths except Exception as e: logger.warning(f"Failed to check optimization status: {e}") # Return all functions if API call fails - total_count = sum(len(funcs) for funcs in functions_by_file.values()) - return 
functions_by_file, total_count + return [] def filter_functions( @@ -531,22 +525,25 @@ def filter_functions( previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None, disable_logs: bool = False, # noqa: FBT001, FBT002 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {} blocklist_funcs = get_blocklisted_functions() logger.debug(f"Blocklisted functions: {blocklist_funcs}") # Remove any function that we don't want to optimize + already_optimized_paths = check_optimization_status(modified_functions) # Ignore files with submodule path, cache the submodule paths submodule_paths = ignored_submodule_paths(module_root) - filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {} functions_count: int = 0 test_functions_removed_count: int = 0 non_modules_removed_count: int = 0 site_packages_removed_count: int = 0 ignore_paths_removed_count: int = 0 malformed_paths_count: int = 0 + already_optimized_count: int = 0 submodule_ignored_paths_count: int = 0 blocklist_funcs_removed_count: int = 0 + already_optimized_paths_removed_count: int = 0 previous_checkpoint_functions_removed_count: int = 0 tests_root_str = str(tests_root) module_root_str = str(module_root) @@ -579,6 +576,7 @@ def filter_functions( except SyntaxError: malformed_paths_count += 1 continue + if blocklist_funcs: functions_tmp = [] for function in _functions: @@ -592,6 +590,17 @@ def filter_functions( # This function is NOT in blocklist. 
we can keep it functions_tmp.append(function) _functions = functions_tmp + functions_tmp = [] + for function in _functions: + if ( + function.file_path.name, + function.qualified_name, + ) in already_optimized_paths and random.random() > REPEAT_OPTIMIZATION_PROBABILITY: + # This function is in blocklist, we can skip it with a probability + already_optimized_paths_removed_count += 1 + continue + functions_tmp.append(function) + _functions = functions_tmp if previous_checkpoint_functions: functions_tmp = [] @@ -605,23 +614,6 @@ def filter_functions( filtered_modified_functions[file_path] = _functions functions_count += len(_functions) - # Convert to Path keys for optimization check - path_based_functions = {Path(k): v for k, v in filtered_modified_functions.items() if v} - - # Check optimization status if repository info is provided - already_optimized_count = 0 - try: - repository = git.Repo(Path.cwd(), search_parent_directories=True) - owner, repo = get_repo_owner_and_name(repository) - except git.exc.InvalidGitRepositoryError: - logger.warning("No git repository found") - owner, repo = None, None - pr_number = get_pr_number() - if owner and repo and pr_number is not None: - path_based_functions, functions_count = check_optimization_status(path_based_functions, owner, repo, pr_number) - initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values()) - already_optimized_count = initial_count - functions_count - if not disable_logs: log_info = { f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count, @@ -639,7 +631,7 @@ def filter_functions( logger.info(f"Ignoring: {log_string}") console.rule() - return path_based_functions, functions_count + return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool: From 
4cb823e70150bac6816efed997a12df293effe7c Mon Sep 17 00:00:00 2001 From: dasarchan Date: Fri, 6 Jun 2025 21:16:20 -0400 Subject: [PATCH 11/32] added separate write route, changed return format for api route --- codeflash/api/cfapi.py | 14 ++++++++++++ codeflash/discovery/functions_to_optimize.py | 24 +++++++++++--------- codeflash/optimization/function_optimizer.py | 17 ++++++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index e07e0212c..777f4bc02 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -209,3 +209,17 @@ def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, cod ) response.raise_for_status() return response.json() + +def add_code_context_hash(owner: str, repo: str, pr_number: int, code_context_hash: str) -> Response: + """Add code context to the DB cache""" + response = make_cfapi_request( + "/add-code-hash", + "POST", + { + "owner": owner, + "repo": repo, + "pr_number": pr_number, + "code_context_hash": code_context_hash + } + ) + return response diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 0df9c4d74..19026a105 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -532,6 +532,17 @@ def filter_functions( disable_logs: bool = False, # noqa: FBT001, FBT002 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: blocklist_funcs = get_blocklisted_functions() + already_optimized_count = 0 + path_based_functions = {Path(k): v for k, v in modified_functions.items() if v} + try: + repository = git.Repo(Path.cwd(), search_parent_directories=True) + owner, repo = get_repo_owner_and_name(repository) + except git.exc.InvalidGitRepositoryError: + logger.warning("No git repository found") + owner, repo = None, None + pr_number = get_pr_number() + if owner and repo and pr_number is not None: + path_based_functions, functions_count = 
check_optimization_status(path_based_functions, owner, repo, pr_number) logger.debug(f"Blocklisted functions: {blocklist_funcs}") # Remove any function that we don't want to optimize @@ -606,19 +617,10 @@ def filter_functions( functions_count += len(_functions) # Convert to Path keys for optimization check - path_based_functions = {Path(k): v for k, v in filtered_modified_functions.items() if v} + # Check optimization status if repository info is provided - already_optimized_count = 0 - try: - repository = git.Repo(Path.cwd(), search_parent_directories=True) - owner, repo = get_repo_owner_and_name(repository) - except git.exc.InvalidGitRepositoryError: - logger.warning("No git repository found") - owner, repo = None, None - pr_number = get_pr_number() - if owner and repo and pr_number is not None: - path_based_functions, functions_count = check_optimization_status(path_based_functions, owner, repo, pr_number) + initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values()) already_optimized_count = initial_count - functions_count diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5922d6c1c..4977daaa0 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import git import concurrent.futures import os import subprocess @@ -18,6 +19,7 @@ from rich.tree import Tree from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient +from codeflash.api.cfapi import add_code_context_hash from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils @@ -50,6 +52,8 @@ from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from 
codeflash.code_utils.time_utils import humanize_runtime +from codeflash.code_utils.env_utils import get_pr_number +from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.context import code_context_extractor from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions from codeflash.either import Failure, Success, is_successful @@ -370,6 +374,19 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ) self.log_successful_optimization(explanation, generated_tests, exp_type) + # Add function to code context hash if in gh actions + try: + repository = git.Repo(Path.cwd(), search_parent_directories=True) + owner, repo = get_repo_owner_and_name(repository) + except git.exc.InvalidGitRepositoryError: + logger.warning("No git repository found") + owner, repo = None, None + pr_number = get_pr_number() + + if owner and repo and pr_number is not None: + code_context_hash = self.function_to_optimize.get_code_context_hash() + add_code_context_hash(owner, repo, pr_number, code_context_hash) + if self.args.override_fixtures: restore_conftest(original_conftest_content) if not best_optimization: From dd8dceb8d1a1a93e3b4cb8e6c54f10b0146bb3ad Mon Sep 17 00:00:00 2001 From: dasarchan Date: Sat, 7 Jun 2025 00:00:08 -0400 Subject: [PATCH 12/32] removed empty test file --- codeflash/api/cfapi.py | 4 ++-- codeflash/discovery/functions_to_optimize.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index b9b47897f..50788f2b9 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -5,7 +5,7 @@ import sys from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, List import requests import sentry_sdk @@ -193,7 +193,7 @@ def get_blocklisted_functions() -> dict[str, set[str]] | 
dict[str, Any]: return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()} -def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, code_contexts: dict[str, str]) -> Dict: +def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, code_contexts: List[Dict[str, str]]) -> Dict: """Check if the function being optimized is being optimized again.""" response = make_cfapi_request( "/is-already-optimized", diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index c36a61c36..4900b972d 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -491,8 +491,7 @@ def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptim if not owner or not repo or pr_number is None: return [] - code_contexts = {} - path_to_function_map = {} + code_contexts = [] for file_path, functions in functions_by_file.items(): for func in functions: @@ -529,7 +528,8 @@ def filter_functions( blocklist_funcs = get_blocklisted_functions() logger.debug(f"Blocklisted functions: {blocklist_funcs}") # Remove any function that we don't want to optimize - already_optimized_paths = check_optimization_status(modified_functions) + already_optimized_paths = check_optimization_status(modified_functions, project_root) + # Ignore files with submodule path, cache the submodule paths submodule_paths = ignored_submodule_paths(module_root) @@ -593,7 +593,7 @@ def filter_functions( functions_tmp = [] for function in _functions: if ( - function.file_path.name, + function.file_path, function.qualified_name, ) in already_optimized_paths and random.random() > REPEAT_OPTIMIZATION_PROBABILITY: # This function is in blocklist, we can skip it with a probability From 5989b26050cb0ece165a64097fb451ac05f5b626 Mon Sep 17 00:00:00 2001 From: dasarchan Date: Sat, 7 Jun 2025 00:09:17 -0400 Subject: [PATCH 13/32] updates --- 
codeflash/api/cfapi.py | 38 +++++++++++++------- codeflash/discovery/functions_to_optimize.py | 17 ++++----- codeflash/optimization/function_optimizer.py | 13 ++----- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 50788f2b9..beb298f9d 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -3,6 +3,7 @@ import json import os import sys +import git from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, List @@ -13,8 +14,8 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number -from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.version import __version__ +from codeflash.code_utils.git_utils import get_repo_owner_and_name if TYPE_CHECKING: from requests import Response @@ -203,16 +204,27 @@ def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, cod response.raise_for_status() return response.json() -def add_code_context_hash(owner: str, repo: str, pr_number: int, code_context_hash: str) -> Response: +def add_code_context_hash( code_context_hash: str): """Add code context to the DB cache""" - response = make_cfapi_request( - "/add-code-hash", - "POST", - { - "owner": owner, - "repo": repo, - "pr_number": pr_number, - "code_context_hash": code_context_hash - } - ) - return response + pr_number = get_pr_number() + if pr_number is None: + return + try: + owner, repo = get_repo_owner_and_name() + pr_number = get_pr_number() + except git.exc.InvalidGitRepositoryError: + return + + + if owner and repo and pr_number is not None: + make_cfapi_request( + "/add-code-hash", + "POST", + { + "owner": owner, + "repo": repo, + "pr_number": pr_number, + "code_context_hash": code_context_hash + } + ) + diff --git a/codeflash/discovery/functions_to_optimize.py 
b/codeflash/discovery/functions_to_optimize.py index 4900b972d..7a3b122e5 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -159,6 +159,7 @@ def get_code_context_hash(self) -> str: # Extract the function's code content lines = file_content.splitlines() + print("starting and ending line ", self.starting_line, self.ending_line) if self.starting_line is not None and self.ending_line is not None: # Use line numbers if available (1-indexed to 0-indexed) function_content = "\n".join(lines[self.starting_line - 1 : self.ending_line]) @@ -460,7 +461,7 @@ def inspect_top_level_functions_or_methods( ) -def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptimize]]) -> list[tuple[str, str]]: +def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptimize]], project_root_path: Path) -> set[tuple[str, str]]: """Check which functions have already been optimized and filter them out. This function calls the optimization API to: @@ -497,22 +498,21 @@ def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptim for func in functions: func_hash = func.get_code_context_hash() # Use a unique path identifier that includes function info - path_key = f"{file_path}:{func.qualified_name}" - code_contexts[path_key] = func_hash - path_to_function_map[path_key] = (file_path, func) + code_contexts.append({"file_path": Path(file_path).relative_to(project_root_path), + "function_name": func.qualified_name, "code_hash": func_hash}) if not code_contexts: - return {}, 0 + return set(tuple()) try: result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts) - already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_paths", []) - return already_optimized_paths + already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_tuples", []) + return set(( project_root_path / Path(path[0]), path[1]) for path in 
already_optimized_paths) except Exception as e: logger.warning(f"Failed to check optimization status: {e}") # Return all functions if API call fails - return [] + return set(tuple()) def filter_functions( @@ -625,6 +625,7 @@ def filter_functions( f"{already_optimized_count} already optimized function{'s' if already_optimized_count != 1 else ''}": already_optimized_count, f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count, f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count, + f"{already_optimized_paths_removed_count} function{'s' if already_optimized_paths_removed_count != 1 else ''} as previously attempted optimization": already_optimized_paths_removed_count, } log_string = "\n".join([k for k, v in log_info.items() if v > 0]) if log_string: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 4977daaa0..5ed16d92d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -375,17 +375,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 self.log_successful_optimization(explanation, generated_tests, exp_type) # Add function to code context hash if in gh actions - try: - repository = git.Repo(Path.cwd(), search_parent_directories=True) - owner, repo = get_repo_owner_and_name(repository) - except git.exc.InvalidGitRepositoryError: - logger.warning("No git repository found") - owner, repo = None, None - pr_number = get_pr_number() - - if owner and repo and pr_number is not None: - code_context_hash = self.function_to_optimize.get_code_context_hash() - add_code_context_hash(owner, repo, pr_number, code_context_hash) + + add_code_context_hash(self.function_to_optimize.get_code_context_hash()) if 
self.args.override_fixtures: restore_conftest(original_conftest_content) From 5c0a028dc2676af149e43f63a96c3cc04a177116 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sat, 7 Jun 2025 17:39:24 -0700 Subject: [PATCH 14/32] Add a first version of hashing code context --- codeflash/context/code_context_extractor.py | 143 +++++++++++++++---- codeflash/models/models.py | 46 +++--- codeflash/optimization/function_optimizer.py | 8 +- 3 files changed, 145 insertions(+), 52 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 934d3053b..336f8bc8d 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -73,6 +73,13 @@ def get_code_optimization_context( remove_docstrings=False, code_context_type=CodeContextType.READ_ONLY, ) + hashing_code_context = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + code_context_type=CodeContextType.HASHING, + ) # Handle token limits final_read_writable_tokens = encoded_tokens_len(final_read_writable_code) @@ -130,6 +137,7 @@ def get_code_optimization_context( testgen_context_code=testgen_context_code, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, + hashing_code_context=hashing_code_context.markdown, helper_functions=helpers_of_fto_list, preexisting_objects=preexisting_objects, ) @@ -309,20 +317,21 @@ def extract_code_markdown_context_from_files( logger.debug(f"Error while getting read-only code: {e}") continue if code_context.strip(): - code_context_with_imports = CodeString( - code=add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list( - helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()) + if code_context_type != 
CodeContextType.HASHING: + code_context = ( + add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=code_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list( + helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()) + ), ), - ), - file_path=file_path.relative_to(project_root_path), - ) - code_context_markdown.code_strings.append(code_context_with_imports) + ) + code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path)) + code_context_markdown.code_strings.append(code_string_context) # Extract code from file paths containing helpers of helpers for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): try: @@ -343,18 +352,19 @@ def extract_code_markdown_context_from_files( continue if code_context.strip(): - code_context_with_imports = CodeString( - code=add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), - ), - file_path=file_path.relative_to(project_root_path), - ) - code_context_markdown.code_strings.append(code_context_with_imports) + if code_context_type != CodeContextType.HASHING: + code_context = ( + add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=code_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), + ), + ) + code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path)) + code_context_markdown.code_strings.append(code_string_context) return code_context_markdown @@ -492,6 +502,8 @@ def parse_code_and_prune_cst( filtered_node, found_target = prune_cst_for_testgen_code( module, target_functions, 
helpers_of_helper_functions, remove_docstrings=remove_docstrings ) + elif code_context_type == CodeContextType.HASHING: + filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions) else: raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102 @@ -583,6 +595,87 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 return (node.with_changes(**updates) if updates else node), True +def prune_cst_for_code_hashing( # noqa: PLR0911 + node: cst.CSTNode, target_functions: set[str], prefix: str = "" +) -> tuple[cst.CSTNode | None, bool]: + """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. + + Returns + ------- + (filtered_node, found_target): + filtered_node: The modified CST node or None if it should be removed. + found_target: True if a target function was found in this node's subtree. + + """ + if isinstance(node, (cst.Import, cst.ImportFrom)): + return None, False + + if isinstance(node, cst.FunctionDef): + qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value + if qualified_name in target_functions: + new_body = remove_docstring_from_body(node.body) + return node.with_changes(body=new_body), True + return None, False + + if isinstance(node, cst.ClassDef): + # Do not recurse into nested classes + if prefix: + return None, False + # Assuming always an IndentedBlock + if not isinstance(node.body, cst.IndentedBlock): + raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 + class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value + new_body = [] + found_target = False + + for stmt in node.body.body: + if isinstance(stmt, cst.FunctionDef): + qualified_name = f"{class_prefix}.{stmt.name.value}" + if qualified_name in target_functions: + new_body.append(stmt) + found_target = True + # If no target functions found, remove the class entirely + if not new_body or not found_target: 
+ return None, False + return node.with_changes( + body=remove_docstring_from_body(node.body.with_changes(body=new_body)) + ) if new_body else None, True + + # For other nodes, we preserve them only if they contain target functions in their children. + section_names = get_section_names(node) + if not section_names: + return node, False + + updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} + found_any_target = False + + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + new_children = [] + section_found_target = False + for child in original_content: + filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix) + if filtered: + new_children.append(filtered) + section_found_target |= found_target + + if section_found_target: + found_any_target = True + updates[section] = new_children + elif original_content is not None: + filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix) + if found_target: + found_any_target = True + if filtered: + updates[section] = filtered + + if not found_any_target: + return None, False + + return (node.with_changes(**updates) if updates else node), True + + def prune_cst_for_read_only_code( # noqa: PLR0911 node: cst.CSTNode, target_functions: set[str], diff --git a/codeflash/models/models.py b/codeflash/models/models.py index b250d2474..4f45c41b5 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -16,7 +16,7 @@ from enum import Enum, IntEnum from pathlib import Path from re import Pattern -from typing import Annotated, Optional, cast +from typing import Annotated, cast from jedi.api.classes import Name from pydantic import AfterValidator, BaseModel, ConfigDict, Field @@ -77,10 +77,10 @@ class BestOptimization(BaseModel): candidate: OptimizedCandidate helper_functions: list[FunctionSource] runtime: int - replay_performance_gain: Optional[dict[BenchmarkKey, float]] 
= None + replay_performance_gain: dict[BenchmarkKey, float] | None = None winning_behavioral_test_results: TestResults winning_benchmarking_test_results: TestResults - winning_replay_benchmarking_test_results: Optional[TestResults] = None + winning_replay_benchmarking_test_results: TestResults | None = None @dataclass(frozen=True) @@ -136,7 +136,7 @@ def to_dict(self) -> dict[str, list[dict[str, any]]]: class CodeString(BaseModel): code: Annotated[str, AfterValidator(validate_python_code)] - file_path: Optional[Path] = None + file_path: Path | None = None class CodeStringsMarkdown(BaseModel): @@ -157,6 +157,7 @@ class CodeOptimizationContext(BaseModel): testgen_context_code: str = "" read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" + hashing_code_context: str = "" helper_functions: list[FunctionSource] preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] @@ -165,6 +166,7 @@ class CodeContextType(str, Enum): READ_WRITABLE = "READ_WRITABLE" READ_ONLY = "READ_ONLY" TESTGEN = "TESTGEN" + HASHING = "HASHING" class OptimizedCandidateResult(BaseModel): @@ -172,7 +174,7 @@ class OptimizedCandidateResult(BaseModel): best_test_runtime: int behavior_test_results: TestResults benchmarking_test_results: TestResults - replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None + replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None optimization_candidate_index: int total_candidate_timing: int @@ -192,10 +194,10 @@ class GeneratedTestsList(BaseModel): class TestFile(BaseModel): instrumented_behavior_file_path: Path benchmarking_file_path: Path = None - original_file_path: Optional[Path] = None - original_source: Optional[str] = None + original_file_path: Path | None = None + original_source: str | None = None test_type: TestType - tests_in_file: Optional[list[TestsInFile]] = None + tests_in_file: list[TestsInFile] | None = None class TestFiles(BaseModel): @@ -238,13 +240,13 @@ def 
__len__(self) -> int: class OptimizationSet(BaseModel): control: list[OptimizedCandidate] - experiment: Optional[list[OptimizedCandidate]] + experiment: list[OptimizedCandidate] | None @dataclass(frozen=True) class TestsInFile: test_file: Path - test_class: Optional[str] + test_class: str | None test_function: str test_type: TestType @@ -277,10 +279,10 @@ class FunctionParent: class OriginalCodeBaseline(BaseModel): behavioral_test_results: TestResults benchmarking_test_results: TestResults - replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None + replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None line_profile_results: dict runtime: int - coverage_results: Optional[CoverageData] + coverage_results: CoverageData | None class CoverageStatus(Enum): @@ -299,7 +301,7 @@ class CoverageData: graph: dict[str, dict[str, Collection[object]]] code_context: CodeOptimizationContext main_func_coverage: FunctionCoverage - dependent_func_coverage: Optional[FunctionCoverage] + dependent_func_coverage: FunctionCoverage | None status: CoverageStatus blank_re: Pattern[str] = re.compile(r"\s*(#|$)") else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)") @@ -407,10 +409,10 @@ def to_name(self) -> str: @dataclass(frozen=True) class InvocationId: test_module_path: str # The fully qualified name of the test module - test_class_name: Optional[str] # The name of the class where the test is defined - test_function_name: Optional[str] # The name of the test_function. Does not include the components of the file_name + test_class_name: str | None # The name of the class where the test is defined + test_function_name: str | None # The name of the test_function. 
Does not include the components of the file_name function_getting_tested: str - iteration_id: Optional[str] + iteration_id: str | None # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id def id(self) -> str: @@ -421,7 +423,7 @@ def id(self) -> str: ) @staticmethod - def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId: + def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId: components = string_id.split(":") assert len(components) == 4 second_components = components[1].split(".") @@ -446,13 +448,13 @@ class FunctionTestInvocation: id: InvocationId # The fully qualified name of the function invocation (id) file_name: Path # The file where the test is defined did_pass: bool # Whether the test this function invocation was part of, passed or failed - runtime: Optional[int] # Time in nanoseconds + runtime: int | None # Time in nanoseconds test_framework: str # unittest or pytest test_type: TestType - return_value: Optional[object] # The return value of the function invocation - timed_out: Optional[bool] - verification_type: Optional[str] = VerificationType.FUNCTION_CALL - stdout: Optional[str] = None + return_value: object | None # The return value of the function invocation + timed_out: bool | None + verification_type: str | None = VerificationType.FUNCTION_CALL + stdout: str | None = None @property def unique_invocation_loop_id(self) -> str: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5ed16d92d..ecd1350d6 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import git import concurrent.futures import os import subprocess @@ -52,8 +51,6 @@ from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.static_analysis import 
get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.code_utils.env_utils import get_pr_number -from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.context import code_context_extractor from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions from codeflash.either import Failure, Success, is_successful @@ -265,7 +262,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 # adding to control and experiment set but with same traceid best_optimization = None for _u, (candidates, exp_type) in enumerate( - zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]) + zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"], strict=False) ): if candidates is None: continue @@ -687,6 +684,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: testgen_context_code=new_code_ctx.testgen_context_code, read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, + hashing_code_context=new_code_ctx.hashing_code_context, helper_functions=new_code_ctx.helper_functions, # only functions that are read writable preexisting_objects=new_code_ctx.preexisting_objects, ) @@ -1283,7 +1281,7 @@ def generate_and_instrument_tests( test_perf_path, ) for test_index, (test_path, test_perf_path) in enumerate( - zip(generated_test_paths, generated_perf_test_paths) + zip(generated_test_paths, generated_perf_test_paths, strict=False) ) ] From 2686682bdbd73f05c46c5993a99fdc49c7539d77 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sat, 7 Jun 2025 18:40:55 -0700 Subject: [PATCH 15/32] Might work? 
--- codeflash/code_utils/git_utils.py | 2 + codeflash/discovery/functions_to_optimize.py | 98 ++++---------------- codeflash/optimization/function_optimizer.py | 4 + 3 files changed, 26 insertions(+), 78 deletions(-) diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 8333c1099..875b261cd 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -5,6 +5,7 @@ import sys import tempfile import time +from functools import cache from io import StringIO from pathlib import Path from typing import TYPE_CHECKING @@ -79,6 +80,7 @@ def get_git_remotes(repo: Repo) -> list[str]: return [remote.name for remote in repository.remotes] +@cache def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]: remote_url = get_remote_url(repo, git_remote) # call only once remote_url = remote_url.removesuffix(".git") if remote_url.endswith(".git") else remote_url diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 7a3b122e5..d94c089a7 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -22,7 +22,6 @@ module_name_from_file_path, path_belongs_to_site_packages, ) -from codeflash.code_utils.config_consts import REPEAT_OPTIMIZATION_PROBABILITY from codeflash.code_utils.env_utils import get_pr_number from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name from codeflash.code_utils.time_utils import humanize_runtime @@ -34,6 +33,7 @@ from libcst import CSTNode from libcst.metadata import CodeRange + from codeflash.models.models import CodeOptimizationContext from codeflash.verification.verification_utils import TestConfig @@ -127,8 +127,8 @@ class FunctionToOptimize: function_name: str file_path: Path parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef] - starting_line: Optional[int] = None - ending_line: 
Optional[int] = None + starting_line: int | None = None + ending_line: int | None = None @property def top_level_parent_name(self) -> str: @@ -147,47 +147,6 @@ def qualified_name(self) -> str: def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" - def get_code_context_hash(self) -> str: - """Generate a SHA-256 hash representing the code context of this function. - - This hash includes the function's code content, file path, and qualified name - to uniquely identify the function for optimization tracking. - """ - try: - with open(self.file_path, encoding="utf-8") as f: - file_content = f.read() - - # Extract the function's code content - lines = file_content.splitlines() - print("starting and ending line ", self.starting_line, self.ending_line) - if self.starting_line is not None and self.ending_line is not None: - # Use line numbers if available (1-indexed to 0-indexed) - function_content = "\n".join(lines[self.starting_line - 1 : self.ending_line]) - else: - # Fallback: use the entire file content if line numbers aren't available - function_content = file_content - - # Create a context string that includes: - # - File path (relative to make it portable) - # - Qualified function name - # - Function code content - context_parts = [ - str(self.file_path.name), # Just filename for portability - self.qualified_name, - function_content.strip(), - ] - - context_string = "\n---\n".join(context_parts) - - # Generate SHA-256 hash - return hashlib.sha256(context_string.encode("utf-8")).hexdigest() - - except OSError as e: - logger.warning(f"Could not read file {self.file_path} for hashing: {e}") - # Fallback hash using available metadata - fallback_string = f"{self.file_path.name}:{self.qualified_name}" - return hashlib.sha256(fallback_string.encode("utf-8")).hexdigest() - def get_functions_to_optimize( optimize_all: str | None, @@ -461,7 +420,7 @@ def 
inspect_top_level_functions_or_methods( ) -def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptimize]], project_root_path: Path) -> set[tuple[str, str]]: +def check_optimization_status(function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext) -> bool: """Check which functions have already been optimized and filter them out. This function calls the optimization API to: @@ -469,12 +428,6 @@ def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptim 2. Log new function hashes to the database 3. Return only functions that need optimization - Args: - functions_by_file: Dictionary mapping file paths to lists of functions - owner: Repository owner - repo: Repository name - pr_number: Pull request number - Returns: Tuple of (filtered_functions_dict, remaining_count) @@ -482,37 +435,40 @@ def check_optimization_status(functions_by_file: dict[Path, list[FunctionToOptim # Check optimization status if repository info is provided # already_optimized_count = 0 try: - repository = git.Repo(search_parent_directories=True) - owner, repo = get_repo_owner_and_name(repository) + owner, repo = get_repo_owner_and_name() except git.exc.InvalidGitRepositoryError: logger.warning("No git repository found") owner, repo = None, None pr_number = get_pr_number() if not owner or not repo or pr_number is None: - return [] + return False code_contexts = [] - for file_path, functions in functions_by_file.items(): - for func in functions: - func_hash = func.get_code_context_hash() - # Use a unique path identifier that includes function info - code_contexts.append({"file_path": Path(file_path).relative_to(project_root_path), - "function_name": func.qualified_name, "code_hash": func_hash}) + func_hash = hashlib.sha256(code_context.hashing_code_context.encode("utf-8")).hexdigest() + # Use a unique path identifier that includes function info + + code_contexts.append( + { + "file_path": function_to_optimize.file_path, + 
"function_name": function_to_optimize.qualified_name, + "code_hash": func_hash, + } + ) if not code_contexts: - return set(tuple()) + return False try: result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts) already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_tuples", []) - return set(( project_root_path / Path(path[0]), path[1]) for path in already_optimized_paths) + return len(already_optimized_paths) > 0 except Exception as e: logger.warning(f"Failed to check optimization status: {e}") # Return all functions if API call fails - return set(tuple()) + return False def filter_functions( @@ -528,8 +484,7 @@ def filter_functions( blocklist_funcs = get_blocklisted_functions() logger.debug(f"Blocklisted functions: {blocklist_funcs}") # Remove any function that we don't want to optimize - already_optimized_paths = check_optimization_status(modified_functions, project_root) - + # already_optimized_paths = check_optimization_status(modified_functions, project_root) # Ignore files with submodule path, cache the submodule paths submodule_paths = ignored_submodule_paths(module_root) @@ -543,7 +498,6 @@ def filter_functions( already_optimized_count: int = 0 submodule_ignored_paths_count: int = 0 blocklist_funcs_removed_count: int = 0 - already_optimized_paths_removed_count: int = 0 previous_checkpoint_functions_removed_count: int = 0 tests_root_str = str(tests_root) module_root_str = str(module_root) @@ -590,17 +544,6 @@ def filter_functions( # This function is NOT in blocklist. 
we can keep it functions_tmp.append(function) _functions = functions_tmp - functions_tmp = [] - for function in _functions: - if ( - function.file_path, - function.qualified_name, - ) in already_optimized_paths and random.random() > REPEAT_OPTIMIZATION_PROBABILITY: - # This function is in blocklist, we can skip it with a probability - already_optimized_paths_removed_count += 1 - continue - functions_tmp.append(function) - _functions = functions_tmp if previous_checkpoint_functions: functions_tmp = [] @@ -625,7 +568,6 @@ def filter_functions( f"{already_optimized_count} already optimized function{'s' if already_optimized_count != 1 else ''}": already_optimized_count, f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count, f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count, - f"{already_optimized_paths_removed_count} function{'s' if already_optimized_paths_removed_count != 1 else ''} as previously attempted optimization": already_optimized_paths_removed_count, } log_string = "\n".join([k for k, v in log_info.items() if v > 0]) if log_string: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index ecd1350d6..5e4269222 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -53,6 +53,7 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions +from codeflash.discovery.functions_to_optimize import check_optimization_status from codeflash.either import Failure, Success, is_successful from codeflash.models.ExperimentMetadata import ExperimentMetadata from 
codeflash.models.models import ( @@ -151,8 +152,11 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 with helper_function_path.open(encoding="utf8") as f: helper_code = f.read() original_helper_code[helper_function_path] = helper_code + if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") + if check_optimization_status(self.function_to_optimize, code_context): + return Failure("This function has already been optimized, skipping.") code_print(code_context.read_writable_code) generated_test_paths = [ From 4f39794868b74e15c1fa5c52380db2aa9c91f204 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sat, 7 Jun 2025 19:47:14 -0700 Subject: [PATCH 16/32] get it working --- codeflash/api/cfapi.py | 24 ++++++++------------ codeflash/context/code_context_extractor.py | 6 ++++- codeflash/discovery/functions_to_optimize.py | 3 +-- codeflash/models/models.py | 1 + codeflash/optimization/function_optimizer.py | 5 ++-- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index beb298f9d..16c645f92 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -3,19 +3,19 @@ import json import os import sys -import git from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, List +from typing import TYPE_CHECKING, Any, Optional +import git import requests import sentry_sdk from pydantic.json import pydantic_encoder from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number -from codeflash.version import __version__ from codeflash.code_utils.git_utils import get_repo_owner_and_name +from codeflash.version import __version__ if TYPE_CHECKING: from requests import Response @@ -194,7 +194,9 @@ def get_blocklisted_functions() -> dict[str, 
set[str]] | dict[str, Any]: return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()} -def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, code_contexts: List[Dict[str, str]]) -> Dict: +def is_function_being_optimized_again( + owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]] +) -> dict: """Check if the function being optimized is being optimized again.""" response = make_cfapi_request( "/is-already-optimized", @@ -204,8 +206,9 @@ def is_function_being_optimized_again(owner: str, repo: str, pr_number: int, cod response.raise_for_status() return response.json() -def add_code_context_hash( code_context_hash: str): - """Add code context to the DB cache""" + +def add_code_context_hash(code_context_hash: str) -> None: + """Add code context to the DB cache.""" pr_number = get_pr_number() if pr_number is None: return @@ -215,16 +218,9 @@ def add_code_context_hash( code_context_hash: str): except git.exc.InvalidGitRepositoryError: return - if owner and repo and pr_number is not None: make_cfapi_request( "/add-code-hash", "POST", - { - "owner": owner, - "repo": repo, - "pr_number": pr_number, - "code_context_hash": code_context_hash - } + {"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash}, ) - diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 336f8bc8d..f355eeacb 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import os from collections import defaultdict from itertools import chain @@ -132,12 +133,15 @@ def get_code_optimization_context( testgen_context_code_tokens = encoded_tokens_len(testgen_context_code) if testgen_context_code_tokens > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + code_hash_context = 
hashing_code_context.markdown + code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() return CodeOptimizationContext( testgen_context_code=testgen_context_code, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, - hashing_code_context=hashing_code_context.markdown, + hashing_code_context=code_hash_context, + hashing_code_context_hash=code_hash, helper_functions=helpers_of_fto_list, preexisting_objects=preexisting_objects, ) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index d94c089a7..4ccc656eb 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import hashlib import os import random import warnings @@ -446,7 +445,7 @@ def check_optimization_status(function_to_optimize: FunctionToOptimize, code_con code_contexts = [] - func_hash = hashlib.sha256(code_context.hashing_code_context.encode("utf-8")).hexdigest() + func_hash = code_context.hashing_code_context_hash # Use a unique path identifier that includes function info code_contexts.append( diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 4f45c41b5..e0704215e 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -158,6 +158,7 @@ class CodeOptimizationContext(BaseModel): read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" hashing_code_context: str = "" + hashing_code_context_hash: str = "" helper_functions: list[FunctionSource] preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5e4269222..079bb3cd3 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -156,7 +156,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: 
# noqa: PLR0911 if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") if check_optimization_status(self.function_to_optimize, code_context): - return Failure("This function has already been optimized, skipping.") + return Failure("This function has previously been optimized, skipping.") code_print(code_context.read_writable_code) generated_test_paths = [ @@ -377,7 +377,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 # Add function to code context hash if in gh actions - add_code_context_hash(self.function_to_optimize.get_code_context_hash()) + add_code_context_hash(code_context.hashing_code_context_hash) if self.args.override_fixtures: restore_conftest(original_conftest_content) @@ -689,6 +689,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, hashing_code_context=new_code_ctx.hashing_code_context, + hashing_code_context_hash=new_code_ctx.hashing_code_context_hash, helper_functions=new_code_ctx.helper_functions, # only functions that are read writable preexisting_objects=new_code_ctx.preexisting_objects, ) From 50f4c333b24296a1a27d98da9fc7c192ba219819 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sat, 7 Jun 2025 21:19:43 -0700 Subject: [PATCH 17/32] 10% chance of optimizing again --- codeflash/optimization/function_optimizer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 079bb3cd3..c0b4f46f8 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -3,6 +3,7 @@ import ast import concurrent.futures import os +import random import subprocess import time import uuid @@ -40,6 +41,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT, 
N_CANDIDATES, N_TESTS_TO_GENERATE, + REPEAT_OPTIMIZATION_PROBABILITY, TOTAL_LOOPING_TIME, ) from codeflash.code_utils.edit_generated_tests import ( @@ -155,7 +157,13 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") - if check_optimization_status(self.function_to_optimize, code_context): + # Random here means that we still attempt optimization with a fractional chance to see if + # last time we could not find an optimization, maybe this time we do. + # Random is before as a performance optimization, swapping the two 'and' statements has the same effect + if ( + random.random() > REPEAT_OPTIMIZATION_PROBABILITY # noqa: S311 + and check_optimization_status(self.function_to_optimize, code_context) + ): return Failure("This function has previously been optimized, skipping.") code_print(code_context.read_writable_code) From c856f1e1fed6ef803a43144018e96d53427c9f37 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sat, 7 Jun 2025 21:37:05 -0700 Subject: [PATCH 18/32] fix a bug --- codeflash/context/code_context_extractor.py | 34 +++++++++------------ tests/test_git_utils.py | 5 +++ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index f355eeacb..c65787f65 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -322,16 +322,14 @@ def extract_code_markdown_context_from_files( continue if code_context.strip(): if code_context_type != CodeContextType.HASHING: - code_context = ( - add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list( - helpers_of_fto.get(file_path, set()) | 
helpers_of_helpers.get(file_path, set()) - ), + code_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=code_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list( + helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()) ), ) code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path)) @@ -357,15 +355,13 @@ def extract_code_markdown_context_from_files( if code_context.strip(): if code_context_type != CodeContextType.HASHING: - code_context = ( - add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), - ), + code_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=code_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), ) code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path)) code_context_markdown.code_strings.append(code_string_context) diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index f456a0d90..293ad5c9e 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -11,30 +11,35 @@ class TestGitUtils(unittest.TestCase): def test_test_get_repo_owner_and_name(self, mock_get_remote_url): # Test with a standard GitHub HTTPS URL mock_get_remote_url.return_value = "https://github.com/owner/repo.git" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "owner" assert repo_name == "repo" # Test with a GitHub SSH URL mock_get_remote_url.return_value = "git@github.com:owner/repo.git" + get_repo_owner_and_name.cache_clear() owner, repo_name = 
get_repo_owner_and_name() assert owner == "owner" assert repo_name == "repo" # Test with another GitHub SSH URL mock_get_remote_url.return_value = "git@github.com:codeflash-ai/posthog.git" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "codeflash-ai" assert repo_name == "posthog" # Test with a URL without the .git suffix mock_get_remote_url.return_value = "https://github.com/owner/repo" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "owner" assert repo_name == "repo" # Test with another GitHub SSH URL mock_get_remote_url.return_value = "git@github.com:codeflash-ai/posthog/" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "codeflash-ai" assert repo_name == "posthog" From b48ed5c89d537c7aeb9b4edc58b5d4c9688e2561 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sat, 7 Jun 2025 21:38:31 -0700 Subject: [PATCH 19/32] ruff fix --- codeflash/code_utils/config_consts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 83ddc95f3..0b8f54204 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -9,4 +9,4 @@ TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget COVERAGE_THRESHOLD = 60.0 MIN_TESTCASE_PASSED_THRESHOLD = 6 -REPEAT_OPTIMIZATION_PROBABILITY = 0.1 \ No newline at end of file +REPEAT_OPTIMIZATION_PROBABILITY = 0.1 From 9e14cfe7a098c74faa3708aeaee5145092092369 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 00:30:47 -0700 Subject: [PATCH 20/32] fix bugs with docstring removal --- codeflash/context/code_context_extractor.py | 9 +- tests/test_code_context_extractor.py | 365 +++++++++++++++++++- 2 files changed, 356 insertions(+), 18 deletions(-) diff --git a/codeflash/context/code_context_extractor.py 
b/codeflash/context/code_context_extractor.py index c65787f65..ab8bb730c 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -613,7 +613,7 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 if isinstance(node, cst.FunctionDef): qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value if qualified_name in target_functions: - new_body = remove_docstring_from_body(node.body) + new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body return node.with_changes(body=new_body), True return None, False @@ -632,14 +632,13 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 if isinstance(stmt, cst.FunctionDef): qualified_name = f"{class_prefix}.{stmt.name.value}" if qualified_name in target_functions: - new_body.append(stmt) + stmt_with_changes = stmt.with_changes(body=remove_docstring_from_body(stmt.body)) + new_body.append(stmt_with_changes) found_target = True # If no target functions found, remove the class entirely if not new_body or not found_target: return None, False - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_body)) - ) if new_body else None, True + return node.with_changes(body=cst.IndentedBlock(new_body)) if new_body else None, found_target # For other nodes, we preserve them only if they contain target functions in their children. 
section_names = get_section_names(node) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 90356ac10..8402be449 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent @@ -30,6 +29,7 @@ def __init__(self, name): def nested_method(self): return self.name + def main_method(): return "hello" @@ -81,8 +81,9 @@ def test_code_replacement10() -> None: code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) qualified_names = {func.qualified_name for func in code_ctx.helper_functions} - assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here + assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ from __future__ import annotations @@ -106,8 +107,26 @@ def main_method(self): expected_read_only_context = """ """ + expected_hashing_context = f""" +```python:{file_path.relative_to(file_path.parent)} +class HelperClass: + + def helper_method(self): + return self.name + + +class MainClass: + + def main_method(self): + self.name = HelperClass.NestedClass("test").nested_method() + return HelperClass(self.name).helper_method() +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_class_method_dependencies() -> None: file_path = 
Path(__file__).resolve() @@ -122,6 +141,8 @@ def test_class_method_dependencies() -> None: code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve()) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + expected_read_write_context = """ from __future__ import annotations from collections import defaultdict @@ -153,8 +174,36 @@ def topologicalSort(self): """ expected_read_only_context = "" + + expected_hashing_context = f""" +```python:{file_path.relative_to(file_path.parent.resolve())} +class Graph: + + def topologicalSortUtil(self, v, visited, stack): + visited[v] = True + + for i in self.graph[v]: + if visited[i] == False: + self.topologicalSortUtil(i, visited, stack) + + stack.insert(0, v) + + def topologicalSort(self): + visited = [False] * self.V + stack = [] + + for i in range(self.V): + if visited[i] == False: + self.topologicalSortUtil(i, visited, stack) + + # Print contents of stack + return stack +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_bubble_sort_helper() -> None: @@ -176,6 +225,7 @@ def test_bubble_sort_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math @@ -196,8 +246,24 @@ def sort_from_another_file(arr): """ expected_read_only_context = "" + expected_hashing_context = """ +```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py +def sorter(arr): + arr.sort() + x = math.sqrt(2) + print(x) + return arr +``` 
+```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py +def sort_from_another_file(arr): + sorted_arr = sorter(arr) + return sorted_arr +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_flavio_typed_code_helper() -> None: @@ -366,7 +432,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: lifespan=self.__duration__, ) ''' - with tempfile.NamedTemporaryFile(mode="w") as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: f.write(code) f.flush() file_path = Path(f.name).resolve() @@ -391,6 +457,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -543,8 +610,67 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]): __backend__: _CacheBackendT ``` ''' + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): + + def get_cache_or_call( + self, + *, + func: Callable[_P, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + lifespan: datetime.timedelta, + ) -> Any: # noqa: ANN401 + if os.environ.get("NO_CACHE"): + return func(*args, **kwargs) + + try: + key = self.hash_key(func=func, args=args, kwargs=kwargs) + except: # noqa: E722 + # If we can't create a cache key, we should just call the function. 
+ logging.warning("Failed to hash cache key for function: %s", func) + return func(*args, **kwargs) + result_pair = self.get(key=key) + + if result_pair is not None: + cached_time, result = result_pair + if not os.environ.get("RE_CACHE") and ( + datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005 + ): + try: + return self.decode(data=result) + except CacheBackendDecodeError as e: + logging.warning("Failed to decode cache data: %s", e) + # If decoding fails we will treat this as a cache miss. + # This might happens if underlying class definition of the data changes. + self.delete(key=key) + result = func(*args, **kwargs) + try: + self.put(key=key, data=self.encode(data=result)) + except CacheBackendEncodeError as e: + logging.warning("Failed to encode cache data: %s", e) + # If encoding fails, we should still return the result. + return result + + +class _PersistentCache(Generic[_P, _R, _CacheBackendT]): + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + if "NO_CACHE" in os.environ: + return self.__wrapped__(*args, **kwargs) + os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True) + return self.__backend__.get_cache_or_call( + func=self.__wrapped__, + args=args, + kwargs=kwargs, + lifespan=self.__duration__, + ) +``` +""" assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class() -> None: @@ -592,6 +718,8 @@ def helper_method(self): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + expected_read_write_context = """ class MyClass: def __init__(self): @@ -618,8 +746,21 @@ def __repr__(self): return "HelperClass" + str(self.x) ``` """ + expected_hashing_context = f""" 
+```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + def helper_method(self): + return self.x +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class_token_limit_1() -> None: @@ -672,6 +813,7 @@ def helper_method(self): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. expected_read_write_context = """ class MyClass: @@ -697,9 +839,21 @@ class HelperClass: def __repr__(self): return "HelperClass" + str(self.x) ``` +""" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + def helper_method(self): + return self.x +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class_token_limit_2() -> None: @@ -752,6 +906,7 @@ def helper_method(self): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. 
expected_read_write_context = """ class MyClass: @@ -769,8 +924,20 @@ def helper_method(self): return self.x """ expected_read_only_context = "" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + def helper_method(self): + return self.x +``` +""" assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class_token_limit_3() -> None: @@ -823,6 +990,7 @@ def helper_method(self): with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + def test_example_class_token_limit_4() -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." 
for _ in range(1000)] @@ -875,6 +1043,7 @@ def helper_method(self): with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + def test_repo_helper() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" path_to_file = project_root / "main.py" @@ -889,6 +1058,7 @@ def test_repo_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math import requests @@ -938,9 +1108,38 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def process_data(self, raw_data: str) -> str: + \"\"\"Process raw data by converting it to uppercase.\"\"\" + return raw_data.upper() + + def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: + \"\"\"Add a prefix to the processed data.\"\"\" + return prefix + data +``` +```python:{path_to_file.relative_to(project_root)} +def fetch_and_process_data(): + # Use the global variable for the request + response = requests.get(API_URL) + response.raise_for_status() + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + processed = processor.add_prefix(processed) + + return processed +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() 
def test_repo_helper_of_helper() -> None: @@ -958,6 +1157,7 @@ def test_repo_helper_of_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1014,10 +1214,38 @@ def transform(self, data): self.data = data return self.data ``` +""" + expected_hashing_context = f""" +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def process_data(self, raw_data: str) -> str: + \"\"\"Process raw data by converting it to uppercase.\"\"\" + return raw_data.upper() + + def transform_data(self, data: str) -> str: + \"\"\"Transform the processed data\"\"\" + return DataTransformer().transform(data) +``` +```python:{path_to_file.relative_to(project_root)} +def fetch_and_transform_data(): + # Use the global variable for the request + response = requests.get(API_URL) + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + transformed = processor.transform_data(processed) + + return transformed +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_of_helper_same_class() -> None: @@ -1034,6 +1262,7 @@ def test_repo_helper_of_helper_same_class() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1078,10 +1307,20 @@ def 
__repr__(self) -> str: return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def transform_data_own_method(self, data: str) -> str: + \"\"\"Transform the processed data using own method\"\"\" + return DataTransformer().transform_using_own_method(data) +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_of_helper_same_file() -> None: @@ -1098,6 +1337,7 @@ def test_repo_helper_of_helper_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1137,10 +1377,20 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def transform_data_same_file_function(self, data: str) -> str: + \"\"\"Transform the processed data using a function from the same file\"\"\" + return DataTransformer().transform_using_same_file_function(data) +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_all_same_file() -> None: @@ -1156,6 +1406,7 @@ def test_repo_helper_all_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, 
read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ class DataTransformer: def __init__(self): @@ -1181,10 +1432,27 @@ def transform(self, data): return self.data ``` +""" + expected_hashing_context = f""" +```python:{path_to_transform_utils.relative_to(project_root)} +class DataTransformer: + + def transform_using_own_method(self, data): + return self.transform(data) + + def transform_data_all_same_file(self, data): + new_data = update_data(data) + return self.transform_using_own_method(new_data) + + +def update_data(data): + return data + " updated" +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_circular_dependency() -> None: @@ -1201,6 +1469,7 @@ def test_repo_helper_circular_dependency() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1240,10 +1509,26 @@ def __repr__(self) -> str: return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:utils.py +class DataProcessor: + + def circular_dependency(self, data: str) -> str: + return DataTransformer().circular_dependency(data) +``` +```python:{path_to_transform_utils.relative_to(project_root)} +class DataTransformer: + + def circular_dependency(self, data): + return DataProcessor().circular_dependency(data) +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert 
hashing_context.strip() == expected_hashing_context.strip() + def test_indirect_init_helper() -> None: code = """ @@ -1282,6 +1567,7 @@ def outside_method(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ class MyClass: def __init__(self): @@ -1295,9 +1581,18 @@ def target_method(self): def outside_method(): return 1 ``` +""" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + return self.x + self.y +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_direct_module_import() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" @@ -1311,9 +1606,9 @@ def test_direct_module_import() -> None: ending_line=None, ) - code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_only_context = """ ```python:utils.py @@ -1336,6 +1631,26 @@ def transform_data(self, data: str) -> str: \"\"\"Transform the processed data\"\"\" return DataTransformer().transform(data) ```""" + expected_hashing_context = """ +```python:main.py +def fetch_and_transform_data(): + # Use the global variable for the request + response = requests.get(API_URL) + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + transformed = processor.transform_data(processed) + + return transformed 
+``` +```python:import_test.py +def function_to_optimize(): + return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() +``` +""" expected_read_write_context = """ import requests from globals import API_URL @@ -1362,9 +1677,11 @@ def function_to_optimize(): """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_module_import_optimization() -> None: - main_code = ''' + main_code = """ import utility_module class Calculator: @@ -1391,9 +1708,9 @@ def calculate(self, operation, x, y): return self.subtract(x, y) else: return None -''' +""" - utility_module_code = ''' + utility_module_code = """ import sys import platform import logging @@ -1466,7 +1783,7 @@ def get_system_details(): "default_precision": DEFAULT_PRECISION, "python_version": sys.version } -''' +""" # Create a temporary directory for the test with tempfile.TemporaryDirectory() as temp_dir: @@ -1515,6 +1832,7 @@ def get_system_details(): # Get the code optimization context code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # The expected contexts expected_read_write_context = """ import utility_module @@ -1579,13 +1897,34 @@ def select_precision(precision, fallback_precision): else: return DEFAULT_PRECISION ``` +""" + expected_hashing_context = """ +```python:main_module.py +class Calculator: + + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b + + def calculate(self, operation, x, y): + if operation == "add": + return self.add(x, y) + elif operation == "subtract": + return self.subtract(x, y) + else: + return None +``` """ # Verify the contexts match the expected values assert 
read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_module_import_init_fto() -> None: - main_code = ''' + main_code = """ import utility_module class Calculator: @@ -1612,9 +1951,9 @@ def calculate(self, operation, x, y): return self.subtract(x, y) else: return None -''' +""" - utility_module_code = ''' + utility_module_code = """ import sys import platform import logging @@ -1687,7 +2026,7 @@ def get_system_details(): "default_precision": DEFAULT_PRECISION, "python_version": sys.version } -''' +""" # Create a temporary directory for the test with tempfile.TemporaryDirectory() as temp_dir: @@ -1791,4 +2130,4 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): ``` """ assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() \ No newline at end of file + assert read_only_context.strip() == expected_read_only_context.strip() From 5d4870f803c40432cf7639a2302817b83f9f2784 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 00:33:29 -0700 Subject: [PATCH 21/32] fix a type --- codeflash/models/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index e0704215e..393d92316 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from rich.tree import Tree @@ -136,7 +136,7 @@ def to_dict(self) -> dict[str, list[dict[str, any]]]: class CodeString(BaseModel): code: Annotated[str, AfterValidator(validate_python_code)] - file_path: Path | None = None + file_path: Optional[Path] = None class CodeStringsMarkdown(BaseModel): 
From 2c1314dc91ebfde6cac89bdf232505dad786c09e Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 00:41:25 -0700 Subject: [PATCH 22/32] fix more tests --- tests/test_code_context_extractor.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 8402be449..0066c0971 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -1114,11 +1114,9 @@ def __repr__(self) -> str: class DataProcessor: def process_data(self, raw_data: str) -> str: - \"\"\"Process raw data by converting it to uppercase.\"\"\" return raw_data.upper() def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: - \"\"\"Add a prefix to the processed data.\"\"\" return prefix + data ``` ```python:{path_to_file.relative_to(project_root)} @@ -1220,11 +1218,9 @@ def transform(self, data): class DataProcessor: def process_data(self, raw_data: str) -> str: - \"\"\"Process raw data by converting it to uppercase.\"\"\" return raw_data.upper() def transform_data(self, data: str) -> str: - \"\"\"Transform the processed data\"\"\" return DataTransformer().transform(data) ``` ```python:{path_to_file.relative_to(project_root)} @@ -1309,11 +1305,16 @@ def __repr__(self) -> str: """ expected_hashing_context = f""" +```python:transform_utils.py +class DataTransformer: + + def transform_using_own_method(self, data): + return self.transform(data) +``` ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: def transform_data_own_method(self, data: str) -> str: - \"\"\"Transform the processed data using own method\"\"\" return DataTransformer().transform_using_own_method(data) ``` """ @@ -1379,11 +1380,16 @@ def __repr__(self) -> str: ``` """ expected_hashing_context = f""" +```python:transform_utils.py +class DataTransformer: + + def transform_using_same_file_function(self, data): + return update_data(data) +``` 
```python:{path_to_utils.relative_to(project_root)} class DataProcessor: def transform_data_same_file_function(self, data: str) -> str: - \"\"\"Transform the processed data using a function from the same file\"\"\" return DataTransformer().transform_using_same_file_function(data) ``` """ From 32a80013bff26210a4480bcddbdf01454b162f01 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 00:50:35 -0700 Subject: [PATCH 23/32] fix types for python 3.9 --- codeflash/discovery/functions_to_optimize.py | 4 +- codeflash/models/models.py | 42 ++++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 4e568c4eb..3fca62806 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -126,8 +126,8 @@ class FunctionToOptimize: function_name: str file_path: Path parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef] - starting_line: int | None = None - ending_line: int | None = None + starting_line: Optional[int] = None + ending_line: Optional[int] = None @property def top_level_parent_name(self) -> str: diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 393d92316..02db2d0b6 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from rich.tree import Tree @@ -16,7 +16,7 @@ from enum import Enum, IntEnum from pathlib import Path from re import Pattern -from typing import Annotated, cast +from typing import Annotated, Optional, cast from jedi.api.classes import Name from pydantic import AfterValidator, BaseModel, ConfigDict, Field @@ -77,10 +77,10 @@ class BestOptimization(BaseModel): candidate: OptimizedCandidate helper_functions: list[FunctionSource] runtime: int - 
replay_performance_gain: dict[BenchmarkKey, float] | None = None + replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None winning_behavioral_test_results: TestResults winning_benchmarking_test_results: TestResults - winning_replay_benchmarking_test_results: TestResults | None = None + winning_replay_benchmarking_test_results: Optional[TestResults] = None @dataclass(frozen=True) @@ -175,7 +175,7 @@ class OptimizedCandidateResult(BaseModel): best_test_runtime: int behavior_test_results: TestResults benchmarking_test_results: TestResults - replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None + replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None optimization_candidate_index: int total_candidate_timing: int @@ -195,10 +195,10 @@ class GeneratedTestsList(BaseModel): class TestFile(BaseModel): instrumented_behavior_file_path: Path benchmarking_file_path: Path = None - original_file_path: Path | None = None - original_source: str | None = None + original_file_path: Optional[Path] = None + original_source: Optional[str] = None test_type: TestType - tests_in_file: list[TestsInFile] | None = None + tests_in_file: Optional[list[TestsInFile]] = None class TestFiles(BaseModel): @@ -241,13 +241,13 @@ def __len__(self) -> int: class OptimizationSet(BaseModel): control: list[OptimizedCandidate] - experiment: list[OptimizedCandidate] | None + experiment: Optional[list[OptimizedCandidate]] @dataclass(frozen=True) class TestsInFile: test_file: Path - test_class: str | None + test_class: Optional[str] test_function: str test_type: TestType @@ -280,10 +280,10 @@ class FunctionParent: class OriginalCodeBaseline(BaseModel): behavioral_test_results: TestResults benchmarking_test_results: TestResults - replay_benchmarking_test_results: dict[BenchmarkKey, TestResults] | None = None + replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None line_profile_results: dict runtime: int - 
coverage_results: CoverageData | None + coverage_results: Optional[CoverageData] class CoverageStatus(Enum): @@ -302,7 +302,7 @@ class CoverageData: graph: dict[str, dict[str, Collection[object]]] code_context: CodeOptimizationContext main_func_coverage: FunctionCoverage - dependent_func_coverage: FunctionCoverage | None + dependent_func_coverage: Optional[FunctionCoverage] status: CoverageStatus blank_re: Pattern[str] = re.compile(r"\s*(#|$)") else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)") @@ -410,10 +410,10 @@ def to_name(self) -> str: @dataclass(frozen=True) class InvocationId: test_module_path: str # The fully qualified name of the test module - test_class_name: str | None # The name of the class where the test is defined - test_function_name: str | None # The name of the test_function. Does not include the components of the file_name + test_class_name: Optional[str] # The name of the class where the test is defined + test_function_name: Optional[str] # The name of the test_function. 
Does not include the components of the file_name function_getting_tested: str - iteration_id: str | None + iteration_id: Optional[str] # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id def id(self) -> str: @@ -449,13 +449,13 @@ class FunctionTestInvocation: id: InvocationId # The fully qualified name of the function invocation (id) file_name: Path # The file where the test is defined did_pass: bool # Whether the test this function invocation was part of, passed or failed - runtime: int | None # Time in nanoseconds + runtime: Optional[int] # Time in nanoseconds test_framework: str # unittest or pytest test_type: TestType - return_value: object | None # The return value of the function invocation - timed_out: bool | None - verification_type: str | None = VerificationType.FUNCTION_CALL - stdout: str | None = None + return_value: Optional[object] # The return value of the function invocation + timed_out: Optional[bool] + verification_type: Optional[str] = VerificationType.FUNCTION_CALL + stdout: Optional[str] = None @property def unique_invocation_loop_id(self) -> str: From e2f1ba0790ce94240d08cc6162c566af7aaac2ee Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 00:54:06 -0700 Subject: [PATCH 24/32] clearer message --- codeflash/optimization/function_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0929e2ce6..56dd61b27 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -169,7 +169,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 random.random() > REPEAT_OPTIMIZATION_PROBABILITY # noqa: S311 and check_optimization_status(self.function_to_optimize, code_context) ): - return Failure("This function has previously been optimized, skipping.") + return Failure("Function optimization previously attempted, 
skipping.") code_print(code_context.read_writable_code) generated_test_paths = [ From f6b32750f94123a265ad428923880e8af70ad09b Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 01:03:55 -0700 Subject: [PATCH 25/32] fix mypy types --- codeflash/api/cfapi.py | 2 +- codeflash/context/code_context_extractor.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 16c645f92..87c54b148 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -196,7 +196,7 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]: def is_function_being_optimized_again( owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]] -) -> dict: +) -> Any: # noqa: ANN401 """Check if the function being optimized is being optimized again.""" response = make_cfapi_request( "/is-already-optimized", diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index ab8bb730c..e0eed8321 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -4,7 +4,7 @@ import os from collections import defaultdict from itertools import chain -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import libcst as cst @@ -625,20 +625,24 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 if not isinstance(node.body, cst.IndentedBlock): raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value - new_body = [] + new_class_body: list[cst.CSTNode] = [] found_target = False for stmt in node.body.body: if isinstance(stmt, cst.FunctionDef): qualified_name = f"{class_prefix}.{stmt.name.value}" if qualified_name in target_functions: - stmt_with_changes = stmt.with_changes(body=remove_docstring_from_body(stmt.body)) - new_body.append(stmt_with_changes) + stmt_with_changes = 
stmt.with_changes( + body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body)) + ) + new_class_body.append(stmt_with_changes) found_target = True # If no target functions found, remove the class entirely - if not new_body or not found_target: + if not new_class_body or not found_target: return None, False - return node.with_changes(body=cst.IndentedBlock(new_body)) if new_body else None, found_target + return node.with_changes( + body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body)) + ) if new_class_body else None, found_target # For other nodes, we preserve them only if they contain target functions in their children. section_names = get_section_names(node) From 6ed93871b48e3e7cb05729752ce8e8c5fef2f5e3 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 01:13:17 -0700 Subject: [PATCH 26/32] add more tests --- tests/test_code_context_extractor.py | 335 +++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 0066c0971..a5b7590bd 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -2137,3 +2137,338 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + + +from __future__ import annotations + + +def test_hashing_code_context_removes_imports_docstrings_and_init() -> None: + """Test that hashing context removes imports, docstrings, and __init__ methods properly.""" + code = ''' +import os +import sys +from pathlib import Path + +class MyClass: + """A class with a docstring.""" + def __init__(self, value): + """Initialize with a value.""" + self.value = value + + def target_method(self): + """Target method with docstring.""" + result = self.helper_method() + helper_cls = HelperClass() + data = 
helper_cls.process_data() + return self.value * 2 + + def helper_method(self): + """Helper method with docstring.""" + return self.value + 1 + +class HelperClass: + """Helper class docstring.""" + def __init__(self): + """Helper init method.""" + self.data = "test" + + def process_data(self): + """Process data method.""" + return self.data.upper() + +def standalone_function(): + """Standalone function.""" + return "standalone" +''' + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = code_ctx.hashing_code_context + + # Expected behavior based on current implementation: + # - Should not contain imports + # - Should remove docstrings from target functions (but currently doesn't - this is a bug) + # - Should not contain __init__ methods + # - Should contain target function and helper methods that are actually called + # - Should be formatted as markdown + + # Test that it's formatted as markdown + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + + # Test basic structure requirements + assert "import" not in hashing_context # Should not contain imports + assert "__init__" not in hashing_context # Should not contain __init__ methods + assert "target_method" in hashing_context # Should contain target function + assert "standalone_function" not in hashing_context # Should not contain unused functions + + # Test that 
helper functions are included when they're called + assert "helper_method" in hashing_context # Should contain called helper method + assert "process_data" in hashing_context # Should contain called helper method + + # Test for docstring removal (this should pass when implementation is fixed) + # Currently this will fail because docstrings are not being removed properly + assert '"""Target method with docstring."""' not in hashing_context, ( + "Docstrings should be removed from target functions" + ) + assert '"""Helper method with docstring."""' not in hashing_context, ( + "Docstrings should be removed from helper functions" + ) + assert '"""Process data method."""' not in hashing_context, ( + "Docstrings should be removed from helper class methods" + ) + + +def test_hashing_code_context_with_nested_classes() -> None: + """Test that hashing context handles nested classes properly (should exclude them).""" + code = ''' +class OuterClass: + """Outer class docstring.""" + def __init__(self): + """Outer init.""" + self.value = 1 + + def target_method(self): + """Target method.""" + return self.NestedClass().nested_method() + + class NestedClass: + """Nested class - should be excluded.""" + def __init__(self): + self.nested_value = 2 + + def nested_method(self): + return self.nested_value +''' + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="OuterClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = 
code_ctx.hashing_code_context + + # Test basic requirements + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + assert "target_method" in hashing_context + assert "__init__" not in hashing_context # Should not contain __init__ methods + + # Verify nested classes are excluded from the hashing context + # The prune_cst_for_code_hashing function should not recurse into nested classes + assert "class NestedClass:" not in hashing_context # Nested class definition should not be present + + # The target method will reference NestedClass, but the actual nested class definition should not be included + # The call to self.NestedClass().nested_method() should be in the target method but the nested class itself excluded + target_method_call_present = "self.NestedClass().nested_method()" in hashing_context + assert target_method_call_present, "The target method should contain the call to nested class" + + # But the actual nested method definition should not be present + nested_method_definition_present = "def nested_method(self):" in hashing_context + assert not nested_method_definition_present, "Nested method definition should not be present in hashing context" + + +def test_hashing_code_context_hash_consistency() -> None: + """Test that the same code produces the same hash.""" + code = """ +class TestClass: + def target_method(self): + return "test" +""" + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + # Generate context 
twice + code_ctx1 = get_code_optimization_context(function_to_optimize, opt.args.project_root) + code_ctx2 = get_code_optimization_context(function_to_optimize, opt.args.project_root) + + # Hash should be consistent + assert code_ctx1.hashing_code_context_hash == code_ctx2.hashing_code_context_hash + assert code_ctx1.hashing_code_context == code_ctx2.hashing_code_context + + # Hash should be valid SHA256 + import hashlib + + expected_hash = hashlib.sha256(code_ctx1.hashing_code_context.encode("utf-8")).hexdigest() + assert code_ctx1.hashing_code_context_hash == expected_hash + + +def test_hashing_code_context_different_code_different_hash() -> None: + """Test that different code produces different hashes.""" + code1 = """ +class TestClass: + def target_method(self): + return "test1" +""" + code2 = """ +class TestClass: + def target_method(self): + return "test2" +""" + + with tempfile.NamedTemporaryFile(mode="w") as f1, tempfile.NamedTemporaryFile(mode="w") as f2: + f1.write(code1) + f1.flush() + f2.write(code2) + f2.flush() + + file_path1 = Path(f1.name).resolve() + file_path2 = Path(f2.name).resolve() + + opt1 = Optimizer( + Namespace( + project_root=file_path1.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + opt2 = Optimizer( + Namespace( + project_root=file_path2.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + + function_to_optimize1 = FunctionToOptimize( + function_name="target_method", + file_path=file_path1, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + function_to_optimize2 = FunctionToOptimize( + function_name="target_method", + file_path=file_path2, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + 
starting_line=None, + ending_line=None, + ) + + code_ctx1 = get_code_optimization_context(function_to_optimize1, opt1.args.project_root) + code_ctx2 = get_code_optimization_context(function_to_optimize2, opt2.args.project_root) + + # Different code should produce different hashes + assert code_ctx1.hashing_code_context_hash != code_ctx2.hashing_code_context_hash + assert code_ctx1.hashing_code_context != code_ctx2.hashing_code_context + + +def test_hashing_code_context_format_is_markdown() -> None: + """Test that hashing context is formatted as markdown.""" + code = """ +class SimpleClass: + def simple_method(self): + return 42 +""" + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="simple_method", + file_path=file_path, + parents=[FunctionParent(name="SimpleClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = code_ctx.hashing_code_context + + # Should be formatted as markdown code block + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + + # Should contain the relative file path in the markdown header + relative_path = file_path.relative_to(opt.args.project_root) + assert str(relative_path) in hashing_context + + # Should contain the actual code between the markdown markers + lines = hashing_context.strip().split("\n") + assert lines[0].startswith("```python:") + assert lines[-1] == "```" + + # Code should be between the markers + code_lines = lines[1:-1] + code_content = "\n".join(code_lines) + assert "class SimpleClass:" in 
code_content + assert "def simple_method(self):" in code_content + assert "return 42" in code_content From be1ef9b7d913585f768fc1dc33a28b624653c47f Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 01:14:53 -0700 Subject: [PATCH 27/32] fix for test --- tests/test_code_context_extractor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index a5b7590bd..2d4dd56cb 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -2139,9 +2139,6 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): assert read_only_context.strip() == expected_read_only_context.strip() -from __future__ import annotations - - def test_hashing_code_context_removes_imports_docstrings_and_init() -> None: """Test that hashing context removes imports, docstrings, and __init__ methods properly.""" code = ''' From 91379218d2bc93b5904dc07a7ee4ab22769d1419 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 01:24:24 -0700 Subject: [PATCH 28/32] double the context length --- codeflash/context/code_context_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index e0eed8321..2971b4e7f 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -32,8 +32,8 @@ def get_code_optimization_context( function_to_optimize: FunctionToOptimize, project_root_path: Path, - optim_token_limit: int = 8000, - testgen_token_limit: int = 8000, + optim_token_limit: int = 16000, + testgen_token_limit: int = 16000, ) -> CodeOptimizationContext: # Get FunctionSource representation of helpers of FTO helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( From 797cba3dbf2c92150d3a4154fa48a4bc84e38636 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 
14:19:03 -0700 Subject: [PATCH 29/32] ruff revert --- codeflash/optimization/function_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 56dd61b27..8688a99a8 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -279,7 +279,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 # adding to control and experiment set but with same traceid best_optimization = None for _u, (candidates, exp_type) in enumerate( - zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"], strict=False) + zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]) ): if candidates is None: continue @@ -1299,7 +1299,7 @@ def generate_and_instrument_tests( test_perf_path, ) for test_index, (test_path, test_perf_path) in enumerate( - zip(generated_test_paths, generated_perf_test_paths, strict=False) + zip(generated_test_paths, generated_perf_test_paths) ) ] From d0f84f6b3bf43a1f11597c9a835940cd1557eea1 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 14:47:51 -0700 Subject: [PATCH 30/32] improve some github actions logging --- codeflash/api/aiservice.py | 4 +- codeflash/cli_cmds/console.py | 40 ++++++++++++++------ codeflash/optimization/function_optimizer.py | 2 + codeflash/optimization/optimizer.py | 7 +++- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index ed61e8c58..f7c5a425f 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -118,7 +118,7 @@ def optimize_python_code( # noqa: D417 if response.status_code == 200: optimizations_json = response.json()["optimizations"] - logger.info(f"Generated {len(optimizations_json)} candidates.") + logger.info(f"Generated {len(optimizations_json)} candidate optimizations.") 
console.rule() end_time = time.perf_counter() logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.") @@ -189,7 +189,7 @@ def optimize_python_code_line_profiler( # noqa: D417 if response.status_code == 200: optimizations_json = response.json()["optimizations"] - logger.info(f"Generated {len(optimizations_json)} candidates.") + logger.info(f"Generated {len(optimizations_json)} candidate optimizations.") console.rule() return [ OptimizedCandidate( diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index fe2fdcdd1..34d50f268 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -66,18 +66,34 @@ def code_print(code_str: str) -> None: @contextmanager -def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]: - """Display a progress bar with a spinner and elapsed time.""" - progress = Progress( - SpinnerColumn(next(spinners)), - *Progress.get_default_columns(), - TimeElapsedColumn(), - console=console, - transient=transient, - ) - task = progress.add_task(message, total=None) - with progress: - yield task +def progress_bar( + message: str, *, transient: bool = False, revert_to_print: bool = False +) -> Generator[TaskID, None, None]: + """Display a progress bar with a spinner and elapsed time. + + If revert_to_print is True, falls back to printing a single logger.info message + instead of showing a progress bar. 
+ """ + if revert_to_print: + logger.info(message) + + # Create a fake task ID since we still need to yield something + class DummyTask: + def __init__(self) -> None: + self.id = 0 + + yield DummyTask().id + else: + progress = Progress( + SpinnerColumn(next(spinners)), + *Progress.get_default_columns(), + TimeElapsedColumn(), + console=console, + transient=transient, + ) + task = progress.add_task(message, total=None) + with progress: + yield task @contextmanager diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 8688a99a8..8538e4ccb 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -48,6 +48,7 @@ add_runtime_comments_to_generated_tests, remove_functions_from_generated_tests, ) +from codeflash.code_utils.env_utils import get_pr_number from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports @@ -188,6 +189,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 with progress_bar( f"Generating new tests and optimizations for function {self.function_to_optimize.function_name}", transient=True, + revert_to_print=bool(get_pr_number()), ): generated_results = self.generate_tests_and_optimizations( testgen_context_code=code_context.testgen_context_code, diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 55ab14c35..66ab3a0e4 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -11,6 +11,7 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils +from codeflash.code_utils.env_utils import get_pr_number from codeflash.either import 
is_successful from codeflash.models.models import ValidCode from codeflash.telemetry.posthog_cf import ph @@ -110,7 +111,11 @@ def run(self) -> None: from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table - with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True): + with progress_bar( + f"Running benchmarks in {self.args.benchmarks_root}", + transient=True, + revert_to_print=bool(get_pr_number()), + ): # Insert decorator file_path_to_source_code = defaultdict(str) for file in file_to_funcs_to_optimize: From 2d62171b235da02fceb579e88b309531f7ce5945 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 23:10:44 -0700 Subject: [PATCH 31/32] some refactor --- codeflash/discovery/functions_to_optimize.py | 4 +++- codeflash/optimization/function_optimizer.py | 7 +++---- codeflash/optimization/optimizer.py | 1 + 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3fca62806..32aadb0ce 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -419,7 +419,9 @@ def inspect_top_level_functions_or_methods( ) -def check_optimization_status(function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext) -> bool: +def was_function_previously_optimized( + function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext +) -> bool: """Check which functions have already been optimized and filter them out. 
This function calls the optimization API to: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 8538e4ccb..4edbf8974 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -56,7 +56,7 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions -from codeflash.discovery.functions_to_optimize import check_optimization_status +from codeflash.discovery.functions_to_optimize import was_function_previously_optimized from codeflash.either import Failure, Success, is_successful from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( @@ -166,9 +166,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 # Random here means that we still attempt optimization with a fractional chance to see if # last time we could not find an optimization, maybe this time we do. 
# Random is before as a performance optimization, swapping the two 'and' statements has the same effect - if ( - random.random() > REPEAT_OPTIMIZATION_PROBABILITY # noqa: S311 - and check_optimization_status(self.function_to_optimize, code_context) + if random.random() > REPEAT_OPTIMIZATION_PROBABILITY and was_function_previously_optimized( # noqa: S311 + self.function_to_optimize, code_context ): return Failure("Function optimization previously attempted, skipping.") diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 66ab3a0e4..0401efe31 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -111,6 +111,7 @@ def run(self) -> None: from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table + console.rule() with progress_bar( f"Running benchmarks in {self.args.benchmarks_root}", transient=True, From 226acd772795798957c931dc99c9974c9eeda3cb Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Sun, 8 Jun 2025 23:16:20 -0700 Subject: [PATCH 32/32] remove unnecessary line --- codeflash/discovery/functions_to_optimize.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 32aadb0ce..c50a0ad49 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -496,7 +496,6 @@ def filter_functions( site_packages_removed_count: int = 0 ignore_paths_removed_count: int = 0 malformed_paths_count: int = 0 - already_optimized_count: int = 0 submodule_ignored_paths_count: int = 0 blocklist_funcs_removed_count: int = 0 previous_checkpoint_functions_removed_count: int = 0 @@ -566,7 +565,6 @@
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count, f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count, - f"{already_optimized_count} already optimized function{'s' if already_optimized_count != 1 else ''}": already_optimized_count, f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count, f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count, }