From 905d34e524ffc23b5202fa72267685c4161e1740 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 3 May 2025 07:47:59 -0400 Subject: [PATCH 1/5] implement test caching --- codeflash/code_utils/compat.py | 5 + codeflash/discovery/discover_unit_tests.py | 216 ++++++++++++++------- 2 files changed, 154 insertions(+), 67 deletions(-) diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index 8bdf093bb..a6b2f5805 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -2,6 +2,8 @@ import sys from pathlib import Path +from platformdirs import user_config_dir + # os-independent newline # important for any user-facing output or files we write # make sure to use this in f-strings e.g. f"some string{LF}" @@ -12,3 +14,6 @@ SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() IS_POSIX = os.name != "nt" + + +codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True)) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index e26680e1a..de1635173 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -1,8 +1,10 @@ from __future__ import annotations +import hashlib import os import pickle import re +import sqlite3 import subprocess import unittest from collections import defaultdict @@ -15,7 +17,7 @@ from codeflash.cli_cmds.console import console, logger, test_files_progress_bar from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path -from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE +from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_dir from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType if TYPE_CHECKING: @@ -37,13 +39,100 @@ class TestFunction: FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$") +class TestsCache: + def __init__(self) -> None: + self.connection = sqlite3.connect(codeflash_cache_dir / "tests_cache.db") + self.cur = self.connection.cursor() + + self.cur.execute( + """ + CREATE TABLE IF NOT EXISTS discovered_tests( + file_path TEXT, + file_hash TEXT, + qualified_name_with_modules_from_root TEXT, + function_name TEXT, + test_class TEXT, + test_function TEXT, + test_type TEXT, + line_number INTEGER, + col_number INTEGER + ) + """ + ) + self.cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash + ON discovered_tests (file_path, file_hash) + """ + ) + self._memory_cache = {} + + def insert_test( + self, + file_path: str, + file_hash: str, + qualified_name_with_modules_from_root: str, + function_name: str, + test_class: str, + test_function: str, + test_type: TestType, + line_number: int, + col_number: int, + ) -> None: + test_type_value = test_type.value if hasattr(test_type, "value") else test_type + self.cur.execute( + "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + file_path, + file_hash, + qualified_name_with_modules_from_root, + function_name, + test_class, + test_function, + test_type_value, + line_number, + col_number, + ), + ) + self.connection.commit() + + def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]: + cache_key = (file_path, file_hash) + if cache_key in self._memory_cache: + return self._memory_cache[cache_key] + self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash)) + result = [ + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6])) + ), + position=CodePosition(line_no=row[7], col_no=row[8]), + ) + for row in self.cur.fetchall() + ] + self._memory_cache[cache_key] = result + return result + + @staticmethod + def compute_file_hash(path: str) -> str: + h = hashlib.md5(usedforsecurity=False) + with Path(path).open("rb") as f: + while True: + chunk = f.read(8192) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + def close(self) -> None: + self.cur.close() + self.connection.close() + + def discover_unit_tests( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None ) -> dict[str, list[FunctionCalledInTest]]: - framework_strategies: dict[str, Callable] = { - "pytest": discover_tests_pytest, - "unittest": discover_tests_unittest, - } + framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} strategy = framework_strategies.get(cfg.test_framework, None) if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" @@ -54,7 +143,7 @@ def discover_unit_tests( def discover_tests_pytest( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None -) -> dict[str, list[FunctionCalledInTest]]: +) -> dict[Path, list[FunctionCalledInTest]]: tests_root = cfg.tests_root project_root = cfg.project_root_path @@ -91,9 +180,7 @@ def discover_tests_pytest( ) elif 0 <= exitcode <= 5: - logger.warning( - f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}" - ) + logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}") else: logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}") console.rule() @@ -101,7 +188,7 @@ def discover_tests_pytest( logger.debug(f"Pytest collection exit code: {exitcode}") if pytest_rootdir is not None: cfg.tests_project_rootdir = Path(pytest_rootdir) - file_to_test_map = defaultdict(list) + file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list) for test in tests: if "__replay_test" in test["test_file"]: test_type = TestType.REPLAY_TEST @@ -116,10 +203,7 @@ def discover_tests_pytest( test_function=test["test_function"], test_type=test_type, ) - if ( - discover_only_these_tests - and test_obj.test_file not in discover_only_these_tests - ): + if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests: continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations @@ -128,7 +212,7 @@ def discover_tests_pytest( def discover_tests_unittest( cfg: TestConfig, discover_only_these_tests: list[str] | None = None -) -> dict[str, list[FunctionCalledInTest]]: +) -> dict[Path, list[FunctionCalledInTest]]: tests_root: Path = cfg.tests_root loader: unittest.TestLoader = unittest.TestLoader() tests: unittest.TestSuite = loader.discover(str(tests_root)) @@ -144,8 +228,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: _test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py") _test_module_path = tests_root / _test_module_path if not _test_module_path.exists() or ( - discover_only_these_tests - and str(_test_module_path) not in discover_only_these_tests + discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests ): return None if "__replay_test" in str(_test_module_path): @@ -172,9 +255,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"): for test_2 in test._tests: if not hasattr(test_2, "_testMethodName"): - logger.warning( - f"Didn't find tests for {test_2}" - ) # it goes deeper? + logger.warning(f"Didn't find tests for {test_2}") # it goes deeper? continue details = get_test_details(test_2) if details is not None: @@ -195,19 +276,35 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N def process_test_files( - file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig + file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig ) -> dict[str, list[FunctionCalledInTest]]: project_root_path = cfg.project_root_path test_framework = cfg.test_framework + function_to_test_map = defaultdict(set) jedi_project = jedi.Project(path=project_root_path) goto_cache = {} + tests_cache = TestsCache() - with test_files_progress_bar( - total=len(file_to_test_map), description="Processing test files" - ) as (progress, task_id): - + with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( + progress, + task_id, + ): for test_file, functions in file_to_test_map.items(): + file_hash = TestsCache.compute_file_hash(test_file) + cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) + if cached_tests: + self_cur = tests_cache.cur + self_cur.execute( + "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?", + (str(test_file), file_hash), + ) + qualified_names = [row[0] for row in self_cur.fetchall()] + for cached, qualified_name in zip(cached_tests, qualified_names): + function_to_test_map[qualified_name].add(cached) + progress.advance(task_id) + continue + try: script = jedi.Script(path=test_file, project=jedi_project) test_functions = set() @@ -216,12 +313,8 @@ def process_test_files( all_defs = script.get_names(all_scopes=True, definitions=True) all_names_top = script.get_names(all_scopes=True) - top_level_functions = { - name.name: name for name in all_names_top if name.type == "function" - } - top_level_classes = { - name.name: name for name in all_names_top if name.type == "class" - } + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} except Exception as e: logger.debug(f"Failed to get jedi script for {test_file}: {e}") progress.advance(task_id) @@ -230,36 +323,18 @@ def process_test_files( if test_framework == "pytest": for function in functions: if "[" in function.test_function: - function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( - function.test_function - )[0] - parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( - function.test_function - )[1] + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] if function_name in top_level_functions: test_functions.add( - TestFunction( - function_name, - function.test_class, - parameters, - function.test_type, - ) + TestFunction(function_name, function.test_class, parameters, function.test_type) ) elif function.test_function in top_level_functions: test_functions.add( - TestFunction( - function.test_function, - function.test_class, - None, - function.test_type, - ) - ) - elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( - function.test_function - ): - base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub( - "", function.test_function + TestFunction(function.test_function, function.test_class, None, function.test_type) ) + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) if base_name in top_level_functions: test_functions.add( TestFunction( @@ -283,9 +358,7 @@ def process_test_files( and f".{matched_name}." in def_name.full_name ): for function in functions_to_search: - (is_parameterized, new_function, parameters) = ( - discover_parameters_unittest(function) - ) + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) if is_parameterized and new_function == def_name.name: test_functions.add( @@ -329,9 +402,7 @@ def process_test_files( if cache_key in goto_cache: definition = goto_cache[cache_key] else: - definition = name.goto( - follow_imports=True, follow_builtin_imports=False - ) + definition = name.goto(follow_imports=True, follow_builtin_imports=False) goto_cache[cache_key] = definition except Exception as e: logger.debug(str(e)) @@ -358,11 +429,23 @@ def process_test_files( if test_framework == "unittest": scope_test_function += "_" + scope_parameters - full_name_without_module_prefix = definition[ - 0 - ].full_name.replace(definition[0].module_name + ".", "", 1) + full_name_without_module_prefix = definition[0].full_name.replace( + definition[0].module_name + ".", "", 1 + ) qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + tests_cache.insert_test( + file_path=str(test_file), + file_hash=file_hash, + qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, + function_name=scope, + test_class=scope_test_class, + test_function=scope_test_function, + test_type=test_type, + line_number=name.line, + col_number=name.column, + ) + function_to_test_map[qualified_name_with_modules_from_root].add( FunctionCalledInTest( tests_in_file=TestsInFile( @@ -371,12 +454,11 @@ def process_test_files( test_function=scope_test_function, test_type=test_type, ), - position=CodePosition( - line_no=name.line, col_no=name.column - ), + position=CodePosition(line_no=name.line, col_no=name.column), ) ) progress.advance(task_id) + tests_cache.close() return {function: list(tests) for function, tests in function_to_test_map.items()} From 6addeccbdcd020b005a49cd0c1bc21fe4dfea6fd Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 4 May 2025 17:51:39 -0500 Subject: [PATCH 2/5] consolidate logic around path cleanup --- codeflash/code_utils/code_utils.py | 6 +++- codeflash/optimization/function_optimizer.py | 36 ++++++++++++-------- codeflash/optimization/optimizer.py | 20 ++--------- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 3ae28c65b..6d98194a1 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -2,6 +2,7 @@ import ast import os +import shutil import site from functools import lru_cache from pathlib import Path @@ -118,4 +119,7 @@ def has_any_async_functions(code: str) -> bool: def cleanup_paths(paths: list[Path]) -> None: for path in paths: - path.unlink(missing_ok=True) + if path.is_dir(): + shutil.rmtree(path, ignore_errors=True) + else: + path.unlink(missing_ok=True) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 786c4afb4..56124a9cb 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -345,20 +345,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: original_helper_code, self.function_to_optimize.file_path, ) - for generated_test_path in generated_test_paths: - generated_test_path.unlink(missing_ok=True) - for generated_perf_test_path in generated_perf_test_paths: - generated_perf_test_path.unlink(missing_ok=True) - for test_paths in instrumented_unittests_created_for_function: - test_paths.unlink(missing_ok=True) - for fn in function_to_concolic_tests: - for test in function_to_concolic_tests[fn]: - if not test.tests_in_file.test_file.parent.exists(): - logger.warning( - f"Concolic test directory {test.tests_in_file.test_file.parent} does not exist so could not be deleted." - ) - shutil.rmtree(test.tests_in_file.test_file.parent, ignore_errors=True) - break # need to delete only one test directory if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") @@ -1242,3 +1228,25 @@ def generate_and_instrument_tests( zip(generated_test_paths, generated_perf_test_paths) ) ] + + def cleanup_generated_files(self) -> None: + paths_to_cleanup = ( + [ + test_file.instrumented_behavior_file_path + for test_type in [ + TestType.GENERATED_REGRESSION, + TestType.EXISTING_UNIT_TEST, + TestType.CONCOLIC_COVERAGE_TEST, + ] + for test_file in self.test_files.get_by_type(test_type).test_files + ] + + [ + test_file.benchmarking_file_path + for test_type in [TestType.GENERATED_REGRESSION, TestType.EXISTING_UNIT_TEST] + for test_file in self.test_files.get_by_type(test_type).test_files + ] + + [self.test_cfg.concolic_test_root_dir] + ) + cleanup_paths(paths_to_cleanup) + if hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir.cleanup() diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 1e1f98435..10d21def5 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -2,7 +2,6 @@ import ast import os -import shutil import tempfile import time from collections import defaultdict @@ -19,7 +18,7 @@ from codeflash.code_utils import env_utils from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, ask_should_use_checkpoint_get_functions from codeflash.code_utils.code_replacer import normalize_code, normalize_node -from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import get_functions_to_optimize @@ -266,22 +265,7 @@ def run(self) -> None: logger.info("✨ All functions have been optimized! ✨") finally: if function_optimizer: - for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files: - test_file.instrumented_behavior_file_path.unlink(missing_ok=True) - test_file.benchmarking_file_path.unlink(missing_ok=True) - for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files: - test_file.instrumented_behavior_file_path.unlink(missing_ok=True) - test_file.benchmarking_file_path.unlink(missing_ok=True) - for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files: - test_file.instrumented_behavior_file_path.unlink(missing_ok=True) - if function_optimizer.test_cfg.concolic_test_root_dir: - shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True) - if self.args.benchmark: - if self.replay_tests_dir.exists(): - shutil.rmtree(self.replay_tests_dir, ignore_errors=True) - trace_file.unlink(missing_ok=True) - if hasattr(get_run_tmp_file, "tmpdir"): - get_run_tmp_file.tmpdir.cleanup() + function_optimizer.cleanup_generated_files() def run_with_args(args: Namespace) -> None: From 3d2676983eeda31296fb22f87c6b27688f2608e7 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 4 May 2025 23:33:27 -0500 Subject: [PATCH 3/5] PR review feedback --- codeflash/code_utils/compat.py | 2 ++ codeflash/discovery/discover_unit_tests.py | 7 ++++--- pyproject.toml | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index a6b2f5805..b62f93d4c 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -17,3 +17,5 @@ codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True)) + +codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db" diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index de1635173..f729b90d2 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -17,7 +17,7 @@ from codeflash.cli_cmds.console import console, logger, test_files_progress_bar from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path -from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_dir +from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType if TYPE_CHECKING: @@ -41,7 +41,7 @@ class TestFunction: class TestsCache: def __init__(self) -> None: - self.connection = sqlite3.connect(codeflash_cache_dir / "tests_cache.db") + self.connection = sqlite3.connect(codeflash_cache_db) self.cur = self.connection.cursor() self.cur.execute( @@ -79,6 +79,7 @@ def insert_test( line_number: int, col_number: int, ) -> None: + self.cur.execute("DELETE FROM discovered_tests WHERE file_path = ?", (file_path,)) test_type_value = test_type.value if hasattr(test_type, "value") else test_type self.cur.execute( "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", @@ -115,7 +116,7 @@ def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCal @staticmethod def compute_file_hash(path: str) -> str: - h = hashlib.md5(usedforsecurity=False) + h = hashlib.sha256(usedforsecurity=False) with Path(path).open("rb") as f: while True: chunk = f.read(8192) diff --git a/pyproject.toml b/pyproject.toml index c42599db9..e320feac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ lxml = ">=5.3.0" crosshair-tool = ">=0.0.78" coverage = ">=7.6.4" line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13 +platformdirs = "^4.3.7" [tool.poetry.group.dev] optional = true From e2166e310bbe5e7a017d09d09a4b6c4fb937e157 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 4 May 2025 23:54:39 -0500 Subject: [PATCH 4/5] Update code_utils.py --- codeflash/code_utils/code_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 6d98194a1..f63756d98 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -119,7 +119,8 @@ def has_any_async_functions(code: str) -> bool: def cleanup_paths(paths: list[Path]) -> None: for path in paths: - if path.is_dir(): - shutil.rmtree(path, ignore_errors=True) - else: - path.unlink(missing_ok=True) + if path and path.exists(): + if path.is_dir(): + shutil.rmtree(path, ignore_errors=True) + else: + path.unlink(missing_ok=True) From 39f1e7066061d14c77754c0c68517e61890e0c00 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Tue, 6 May 2025 07:53:04 -0700 Subject: [PATCH 5/5] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e320feac0..15dc01098 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ lxml = ">=5.3.0" crosshair-tool = ">=0.0.78" coverage = ">=7.6.4" line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13 -platformdirs = "^4.3.7" +platformdirs = ">=4.3.7" [tool.poetry.group.dev] optional = true