Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import os
import shutil
import site
from functools import lru_cache
from pathlib import Path
Expand Down Expand Up @@ -118,4 +119,8 @@ def has_any_async_functions(code: str) -> bool:

def cleanup_paths(paths: list[Path]) -> None:
for path in paths:
path.unlink(missing_ok=True)
if path and path.exists():
if path.is_dir():
shutil.rmtree(path, ignore_errors=True)
else:
path.unlink(missing_ok=True)
7 changes: 7 additions & 0 deletions codeflash/code_utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
from pathlib import Path

from platformdirs import user_config_dir

# os-independent newline
# important for any user-facing output or files we write
# make sure to use this in f-strings e.g. f"some string{LF}"
Expand All @@ -12,3 +14,8 @@
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()

IS_POSIX = os.name != "nt"


codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))

codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db"
217 changes: 150 additions & 67 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import hashlib
import os
import pickle
import re
import sqlite3
import subprocess
import unittest
from collections import defaultdict
Expand All @@ -15,7 +17,7 @@

from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType

if TYPE_CHECKING:
Expand All @@ -37,13 +39,101 @@ class TestFunction:
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")


class TestsCache:
def __init__(self) -> None:
self.connection = sqlite3.connect(codeflash_cache_db)
self.cur = self.connection.cursor()

self.cur.execute(
"""
CREATE TABLE IF NOT EXISTS discovered_tests(
file_path TEXT,
file_hash TEXT,
qualified_name_with_modules_from_root TEXT,
function_name TEXT,
test_class TEXT,
test_function TEXT,
test_type TEXT,
line_number INTEGER,
col_number INTEGER
)
"""
)
self.cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
ON discovered_tests (file_path, file_hash)
"""
)
self._memory_cache = {}

def insert_test(
self,
file_path: str,
file_hash: str,
qualified_name_with_modules_from_root: str,
function_name: str,
test_class: str,
test_function: str,
test_type: TestType,
line_number: int,
col_number: int,
) -> None:
self.cur.execute("DELETE FROM discovered_tests WHERE file_path = ?", (file_path,))
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
self.cur.execute(
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
file_path,
file_hash,
qualified_name_with_modules_from_root,
function_name,
test_class,
test_function,
test_type_value,
line_number,
col_number,
),
)
self.connection.commit()

def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
cache_key = (file_path, file_hash)
if cache_key in self._memory_cache:
return self._memory_cache[cache_key]
self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
result = [
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
),
position=CodePosition(line_no=row[7], col_no=row[8]),
)
for row in self.cur.fetchall()
]
self._memory_cache[cache_key] = result
return result

@staticmethod
def compute_file_hash(path: str) -> str:
h = hashlib.sha256(usedforsecurity=False)
with Path(path).open("rb") as f:
while True:
chunk = f.read(8192)
if not chunk:
break
h.update(chunk)
return h.hexdigest()

def close(self) -> None:
self.cur.close()
self.connection.close()


def discover_unit_tests(
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
framework_strategies: dict[str, Callable] = {
"pytest": discover_tests_pytest,
"unittest": discover_tests_unittest,
}
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
strategy = framework_strategies.get(cfg.test_framework, None)
if not strategy:
error_message = f"Unsupported test framework: {cfg.test_framework}"
Expand All @@ -54,7 +144,7 @@ def discover_unit_tests(

def discover_tests_pytest(
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
) -> dict[Path, list[FunctionCalledInTest]]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path

Expand Down Expand Up @@ -91,17 +181,15 @@ def discover_tests_pytest(
)

elif 0 <= exitcode <= 5:
logger.warning(
f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}"
)
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
else:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
console.rule()
else:
logger.debug(f"Pytest collection exit code: {exitcode}")
if pytest_rootdir is not None:
cfg.tests_project_rootdir = Path(pytest_rootdir)
file_to_test_map = defaultdict(list)
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
for test in tests:
if "__replay_test" in test["test_file"]:
test_type = TestType.REPLAY_TEST
Expand All @@ -116,10 +204,7 @@ def discover_tests_pytest(
test_function=test["test_function"],
test_type=test_type,
)
if (
discover_only_these_tests
and test_obj.test_file not in discover_only_these_tests
):
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
continue
file_to_test_map[test_obj.test_file].append(test_obj)
# Within these test files, find the project functions they are referring to and return their names/locations
Expand All @@ -128,7 +213,7 @@ def discover_tests_pytest(

def discover_tests_unittest(
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
) -> dict[Path, list[FunctionCalledInTest]]:
tests_root: Path = cfg.tests_root
loader: unittest.TestLoader = unittest.TestLoader()
tests: unittest.TestSuite = loader.discover(str(tests_root))
Expand All @@ -144,8 +229,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
_test_module_path = tests_root / _test_module_path
if not _test_module_path.exists() or (
discover_only_these_tests
and str(_test_module_path) not in discover_only_these_tests
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
):
return None
if "__replay_test" in str(_test_module_path):
Expand All @@ -172,9 +256,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
for test_2 in test._tests:
if not hasattr(test_2, "_testMethodName"):
logger.warning(
f"Didn't find tests for {test_2}"
) # it goes deeper?
logger.warning(f"Didn't find tests for {test_2}") # it goes deeper?
continue
details = get_test_details(test_2)
if details is not None:
Expand All @@ -195,19 +277,35 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N


def process_test_files(
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework

function_to_test_map = defaultdict(set)
jedi_project = jedi.Project(path=project_root_path)
goto_cache = {}
tests_cache = TestsCache()

with test_files_progress_bar(
total=len(file_to_test_map), description="Processing test files"
) as (progress, task_id):

with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
progress,
task_id,
):
for test_file, functions in file_to_test_map.items():
file_hash = TestsCache.compute_file_hash(test_file)
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
if cached_tests:
self_cur = tests_cache.cur
self_cur.execute(
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
(str(test_file), file_hash),
)
qualified_names = [row[0] for row in self_cur.fetchall()]
for cached, qualified_name in zip(cached_tests, qualified_names):
function_to_test_map[qualified_name].add(cached)
progress.advance(task_id)
continue

try:
script = jedi.Script(path=test_file, project=jedi_project)
test_functions = set()
Expand All @@ -216,12 +314,8 @@ def process_test_files(
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)

top_level_functions = {
name.name: name for name in all_names_top if name.type == "function"
}
top_level_classes = {
name.name: name for name in all_names_top if name.type == "class"
}
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
progress.advance(task_id)
Expand All @@ -230,36 +324,18 @@ def process_test_files(
if test_framework == "pytest":
for function in functions:
if "[" in function.test_function:
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
function.test_function
)[0]
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
function.test_function
)[1]
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
if function_name in top_level_functions:
test_functions.add(
TestFunction(
function_name,
function.test_class,
parameters,
function.test_type,
)
TestFunction(function_name, function.test_class, parameters, function.test_type)
)
elif function.test_function in top_level_functions:
test_functions.add(
TestFunction(
function.test_function,
function.test_class,
None,
function.test_type,
)
)
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(
function.test_function
):
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub(
"", function.test_function
TestFunction(function.test_function, function.test_class, None, function.test_type)
)
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
if base_name in top_level_functions:
test_functions.add(
TestFunction(
Expand All @@ -283,9 +359,7 @@ def process_test_files(
and f".{matched_name}." in def_name.full_name
):
for function in functions_to_search:
(is_parameterized, new_function, parameters) = (
discover_parameters_unittest(function)
)
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)

if is_parameterized and new_function == def_name.name:
test_functions.add(
Expand Down Expand Up @@ -329,9 +403,7 @@ def process_test_files(
if cache_key in goto_cache:
definition = goto_cache[cache_key]
else:
definition = name.goto(
follow_imports=True, follow_builtin_imports=False
)
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
goto_cache[cache_key] = definition
except Exception as e:
logger.debug(str(e))
Expand All @@ -358,11 +430,23 @@ def process_test_files(
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters

full_name_without_module_prefix = definition[
0
].full_name.replace(definition[0].module_name + ".", "", 1)
full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"

tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
function_name=scope,
test_class=scope_test_class,
test_function=scope_test_function,
test_type=test_type,
line_number=name.line,
col_number=name.column,
)

function_to_test_map[qualified_name_with_modules_from_root].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
Expand All @@ -371,12 +455,11 @@ def process_test_files(
test_function=scope_test_function,
test_type=test_type,
),
position=CodePosition(
line_no=name.line, col_no=name.column
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)

progress.advance(task_id)

tests_cache.close()
return {function: list(tests) for function, tests in function_to_test_map.items()}
Loading
Loading