From 8fd505d961473b0d5cd64411533a258483bc6527 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 17:26:37 -0500 Subject: [PATCH 1/8] Update discover_unit_tests.py --- codeflash/discovery/discover_unit_tests.py | 155 ++++++++++++++------- 1 file changed, 101 insertions(+), 54 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 02ae2e4c1..4b132b060 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -34,7 +34,10 @@ class TestFunction: def discover_unit_tests( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None ) -> dict[str, list[FunctionCalledInTest]]: - framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} + framework_strategies: dict[str, Callable] = { + "pytest": discover_tests_pytest, + "unittest": discover_tests_unittest, + } strategy = framework_strategies.get(cfg.test_framework, None) if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" @@ -82,7 +85,9 @@ def discover_tests_pytest( ) elif 0 <= exitcode <= 5: - logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}") + logger.warning( + f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}" + ) else: logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}") console.rule() @@ -105,7 +110,10 @@ def discover_tests_pytest( test_function=test["test_function"], test_type=test_type, ) - if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests: + if ( + discover_only_these_tests + and test_obj.test_file not in discover_only_these_tests + ): continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations @@ -130,7 +138,8 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: _test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py") _test_module_path = tests_root / _test_module_path if not _test_module_path.exists() or ( - discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests + discover_only_these_tests + and str(_test_module_path) not in discover_only_these_tests ): return None if "__replay_test" in str(_test_module_path): @@ -157,7 +166,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"): for test_2 in test._tests: if not hasattr(test_2, "_testMethodName"): - logger.warning(f"Didn't find tests for {test_2}") # it goes deeper? + logger.warning( + f"Didn't find tests for {test_2}" + ) # it goes deeper? continue details = get_test_details(test_2) if details is not None: @@ -182,8 +193,9 @@ def process_test_files( ) -> dict[str, list[FunctionCalledInTest]]: project_root_path = cfg.project_root_path test_framework = cfg.test_framework - function_to_test_map = defaultdict(list) + function_to_test_map = defaultdict(set) jedi_project = jedi.Project(path=project_root_path) + goto_cache = {} for test_file, functions in file_to_test_map.items(): try: @@ -194,8 +206,12 @@ def process_test_files( all_defs = script.get_names(all_scopes=True, definitions=True) all_names_top = script.get_names(all_scopes=True) - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + top_level_functions = { + name.name: name for name in all_names_top if name.type == "function" + } + top_level_classes = { + name.name: name for name in all_names_top if name.type == "class" + } except Exception as e: logger.debug(f"Failed to get jedi script for {test_file}: {e}") continue @@ -207,11 +223,21 @@ def process_test_files( parameters = re.split(r"[\[\]]", function.test_function)[1] if function_name in top_level_functions: test_functions.add( - TestFunction(function_name, function.test_class, parameters, function.test_type) + TestFunction( + function_name, + function.test_class, + parameters, + function.test_type, + ) ) elif function.test_function in top_level_functions: test_functions.add( - TestFunction(function.test_function, function.test_class, None, function.test_type) + TestFunction( + function.test_function, + function.test_class, + None, + function.test_type, + ) ) elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function): # Try to match parameterized unittest functions here, although we can't get the parameters. @@ -229,7 +255,7 @@ def process_test_files( elif test_framework == "unittest": functions_to_search = [elem.test_function for elem in functions] - test_suites = [elem.test_class for elem in functions] + test_suites = {elem.test_class for elem in functions} matching_names = test_suites & top_level_classes.keys() for matched_name in matching_names: @@ -240,7 +266,9 @@ def process_test_files( and f".{matched_name}." in def_name.full_name ): for function in functions_to_search: - (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) + (is_parameterized, new_function, parameters) = ( + discover_parameters_unittest(function) + ) if is_parameterized and new_function == def_name.name: test_functions.add( @@ -264,53 +292,72 @@ def process_test_files( test_functions_list = list(test_functions) test_functions_raw = [elem.function_name for elem in test_functions_list] + test_functions_by_name = defaultdict(list) + for i, func_name in enumerate(test_functions_raw): + test_functions_by_name[func_name].append(i) + for name in all_names: if name.full_name is None: continue m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name) if not m: continue + scope = m.group(1) - indices = [i for i, x in enumerate(test_functions_raw) if x == scope] - for index in indices: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type - try: - definition = name.goto(follow_imports=True, follow_builtin_imports=False) - except Exception as e: - logger.debug(str(e)) - continue - if definition and definition[0].type == "function": - definition_path = str(definition[0].module_path) - # The definition is part of this project and not defined within the original function - if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None - ): - if scope_parameters is not None: - if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - function_to_test_map[qualified_name_with_modules_from_root].append( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - ), - position=CodePosition(line_no=name.line, col_no=name.column), - ) + if scope not in test_functions_by_name: + continue + + cache_key = (name.full_name, name.module_name) + try: + if cache_key in goto_cache: + definition = goto_cache[cache_key] + else: + definition = name.goto( + follow_imports=True, follow_builtin_imports=False + ) + goto_cache[cache_key] = definition + except Exception as e: + logger.debug(str(e)) + continue + + if not definition or definition[0].type != "function": + continue + + definition_path = str(definition[0].module_path) + if ( + definition_path.startswith(str(project_root_path) + os.sep) + and definition[0].module_name != name.module_name + and definition[0].full_name is not None + ): + for index in test_functions_by_name[scope]: + scope_test_function = test_functions_list[index].function_name + scope_test_class = test_functions_list[index].test_class + scope_parameters = test_functions_list[index].parameters + test_type = test_functions_list[index].test_type + + if scope_parameters is not None: + if test_framework == "pytest": + scope_test_function += "[" + scope_parameters + "]" + if test_framework == "unittest": + scope_test_function += "_" + scope_parameters + + full_name_without_module_prefix = definition[0].full_name.replace( + definition[0].module_name + ".", "", 1 + ) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + + function_to_test_map[qualified_name_with_modules_from_root].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=scope_test_class, + test_function=scope_test_function, + test_type=test_type, + ), + position=CodePosition( + line_no=name.line, col_no=name.column + ), ) - deduped_function_to_test_map = {} - for function, tests in function_to_test_map.items(): - deduped_function_to_test_map[function] = list(set(tests)) - return deduped_function_to_test_map + ) + + return {function: list(tests) for function, tests in function_to_test_map.items()} From baa5176aed0eebe6aa03f7e6cfa988a9d33b4776 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 17:28:40 -0500 Subject: [PATCH 2/8] pre-compile regex --- codeflash/discovery/discover_unit_tests.py | 28 ++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 4b132b060..dcf5e41bf 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -31,6 +31,13 @@ class TestFunction: test_type: TestType +ERROR_PATTERN = re.compile(r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)") +PYTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"[\[\]]") +UNITTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"^test_\w+_\d+(?:_\w+)*") +UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX = re.compile(r"_\d+(?:_\w+)*$") +FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$") + + def discover_unit_tests( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None ) -> dict[str, list[FunctionCalledInTest]]: @@ -76,8 +83,7 @@ def discover_tests_pytest( if exitcode != 0: if exitcode == 2 and "ERROR collecting" in result.stdout: # Pattern matches "===== ERRORS =====" (any number of =) and captures everything after - error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)" - match = re.search(error_pattern, result.stdout) + match = ERROR_PATTERN.search(result.stdout) error_section = match.group(1) if match else result.stdout logger.warning( @@ -219,8 +225,12 @@ def process_test_files( if test_framework == "pytest": for function in functions: if "[" in function.test_function: - function_name = re.split(r"[\[\]]", function.test_function)[0] - parameters = re.split(r"[\[\]]", function.test_function)[1] + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( + function.test_function + )[0] + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( + function.test_function + )[1] if function_name in top_level_functions: test_functions.add( TestFunction( @@ -239,10 +249,14 @@ def process_test_files( function.test_type, ) ) - elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function): + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( + function.test_function + ): # Try to match parameterized unittest functions here, although we can't get the parameters. # Extract base name by removing the numbered suffix and any additional descriptions - base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function) + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub( + "", function.test_function + ) if base_name in top_level_functions: test_functions.add( TestFunction( @@ -299,7 +313,7 @@ def process_test_files( for name in all_names: if name.full_name is None: continue - m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name) + m = FUNCTION_NAME_REGEX.search(name.full_name) if not m: continue From 1a9625b82f5e1e469018ffa85d432abf26b50576 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 18:08:52 -0500 Subject: [PATCH 3/8] add progress bar --- codeflash/discovery/discover_unit_tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index dcf5e41bf..44febd932 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass from pytest import ExitCode -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile @@ -49,7 +49,8 @@ def discover_unit_tests( if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) - return strategy(cfg, discover_only_these_tests) + with progress_bar("Discovering unit tests…", transient=True): + return strategy(cfg, discover_only_these_tests) def discover_tests_pytest( From e9347dcbedf92e09768bb44f9e55f3a5589d7d51 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 18:27:20 -0500 Subject: [PATCH 4/8] move progress bar --- codeflash/discovery/discover_unit_tests.py | 6 +++--- codeflash/optimization/optimizer.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 44febd932..92bb54462 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass from pytest import ExitCode -from codeflash.cli_cmds.console import console, logger, progress_bar +from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile @@ -49,8 +49,8 @@ def discover_unit_tests( if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) - with progress_bar("Discovering unit tests…", transient=True): - return strategy(cfg, discover_only_these_tests) + + return strategy(cfg, discover_only_these_tests) def discover_tests_pytest( diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 7e6848010..22443eade 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import normalize_code, normalize_node from codeflash.code_utils.code_utils import get_run_tmp_file @@ -95,9 +95,8 @@ def run(self) -> None: return console.rule() - logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}…") - console.rule() - function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) + with progress_bar(f"Discovering existing unit tests in {self.test_cfg.tests_root}…", transient=True): + function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) console.rule() logger.info(f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}") From 5b18d98fce3c9f7921f6030f96b5dd8faa793884 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 23:40:28 -0500 Subject: [PATCH 5/8] make pb real --- codeflash/cli_cmds/console.py | 55 +++- codeflash/discovery/discover_unit_tests.py | 319 +++++++++++---------- codeflash/optimization/optimizer.py | 3 +- 3 files changed, 213 insertions(+), 164 deletions(-) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 45959ded2..b3396eca2 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -7,7 +7,16 @@ from rich.console import Console from rich.logging import RichHandler -from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) from codeflash.cli_cmds.console_constants import SPINNER_TYPES from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT @@ -22,15 +31,26 @@ console = Console() logging.basicConfig( level=logging.INFO, - handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)], + handlers=[ + RichHandler( + rich_tracebacks=True, + markup=False, + console=console, + show_path=False, + show_time=False, + ) + ], format=BARE_LOGGING_FORMAT, ) logger = logging.getLogger("rich") -logging.getLogger('parso').setLevel(logging.WARNING) +logging.getLogger("parso").setLevel(logging.WARNING) + def paneled_text( - text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None + text: str, + panel_args: dict[str, str | bool] | None = None, + text_args: dict[str, str] | None = None, ) -> None: """Print text in a panel.""" from rich.panel import Panel @@ -57,7 +77,9 @@ def code_print(code_str: str) -> None: @contextmanager -def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]: +def progress_bar( + message: str, *, transient: bool = False +) -> Generator[TaskID, None, None]: """Display a progress bar with a spinner and elapsed time.""" progress = Progress( SpinnerColumn(next(spinners)), @@ -69,3 +91,26 @@ def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, task = progress.add_task(message, total=None) with progress: yield task + + +@contextmanager +def test_files_progress_bar( + total: int, description: str +) -> Generator[tuple[Progress, TaskID], None, None]: + """Progress bar for test files.""" + with Progress( + SpinnerColumn(next(spinners)), + TextColumn("[progress.description]{task.description}"), + BarColumn( + complete_style="cyan", + finished_style="green", + pulse_style="yellow", + ), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + transient=True, + ) as progress: + task_id = progress.add_task(description, total=total) + yield progress, task_id diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 92bb54462..812b656c7 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass from pytest import ExitCode -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import console, logger, test_files_progress_bar from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile @@ -204,175 +204,180 @@ def process_test_files( jedi_project = jedi.Project(path=project_root_path) goto_cache = {} - for test_file, functions in file_to_test_map.items(): - try: - script = jedi.Script(path=test_file, project=jedi_project) - test_functions = set() - - all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) - all_names_top = script.get_names(all_scopes=True) - - top_level_functions = { - name.name: name for name in all_names_top if name.type == "function" - } - top_level_classes = { - name.name: name for name in all_names_top if name.type == "class" - } - except Exception as e: - logger.debug(f"Failed to get jedi script for {test_file}: {e}") - continue + with test_files_progress_bar( + total=len(file_to_test_map), description="Processing test files" + ) as (progress, task_id): - if test_framework == "pytest": - for function in functions: - if "[" in function.test_function: - function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( - function.test_function - )[0] - parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( - function.test_function - )[1] - if function_name in top_level_functions: + for test_file, functions in file_to_test_map.items(): + try: + script = jedi.Script(path=test_file, project=jedi_project) + test_functions = set() + + all_names = script.get_names(all_scopes=True, references=True) + all_defs = script.get_names(all_scopes=True, definitions=True) + all_names_top = script.get_names(all_scopes=True) + + top_level_functions = { + name.name: name for name in all_names_top if name.type == "function" + } + top_level_classes = { + name.name: name for name in all_names_top if name.type == "class" + } + except Exception as e: + logger.debug(f"Failed to get jedi script for {test_file}: {e}") + progress.advance(task_id) + continue + + if test_framework == "pytest": + for function in functions: + if "[" in function.test_function: + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( + function.test_function + )[0] + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( + function.test_function + )[1] + if function_name in top_level_functions: + test_functions.add( + TestFunction( + function_name, + function.test_class, + parameters, + function.test_type, + ) + ) + elif function.test_function in top_level_functions: test_functions.add( TestFunction( - function_name, + function.test_function, function.test_class, - parameters, + None, function.test_type, ) ) - elif function.test_function in top_level_functions: - test_functions.add( - TestFunction( - function.test_function, - function.test_class, - None, - function.test_type, - ) - ) - elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( - function.test_function - ): - # Try to match parameterized unittest functions here, although we can't get the parameters. - # Extract base name by removing the numbered suffix and any additional descriptions - base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub( - "", function.test_function - ) - if base_name in top_level_functions: - test_functions.add( - TestFunction( - function_name=base_name, - test_class=function.test_class, - parameters=function.test_function, - test_type=function.test_type, - ) - ) - - elif test_framework == "unittest": - functions_to_search = [elem.test_function for elem in functions] - test_suites = {elem.test_class for elem in functions} - - matching_names = test_suites & top_level_classes.keys() - for matched_name in matching_names: - for def_name in all_defs: - if ( - def_name.type == "function" - and def_name.full_name is not None - and f".{matched_name}." in def_name.full_name + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( + function.test_function ): - for function in functions_to_search: - (is_parameterized, new_function, parameters) = ( - discover_parameters_unittest(function) + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub( + "", function.test_function + ) + if base_name in top_level_functions: + test_functions.add( + TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, + ) ) - if is_parameterized and new_function == def_name.name: - test_functions.add( - TestFunction( - function_name=def_name.name, - test_class=matched_name, - parameters=parameters, - test_type=functions[0].test_type, - ) # A test file must not have more than one test type - ) - elif function == def_name.name: - test_functions.add( - TestFunction( - function_name=def_name.name, - test_class=matched_name, - parameters=None, - test_type=functions[0].test_type, - ) + elif test_framework == "unittest": + functions_to_search = [elem.test_function for elem in functions] + test_suites = {elem.test_class for elem in functions} + + matching_names = test_suites & top_level_classes.keys() + for matched_name in matching_names: + for def_name in all_defs: + if ( + def_name.type == "function" + and def_name.full_name is not None + and f".{matched_name}." in def_name.full_name + ): + for function in functions_to_search: + (is_parameterized, new_function, parameters) = ( + discover_parameters_unittest(function) ) - test_functions_list = list(test_functions) - test_functions_raw = [elem.function_name for elem in test_functions_list] - - test_functions_by_name = defaultdict(list) - for i, func_name in enumerate(test_functions_raw): - test_functions_by_name[func_name].append(i) - - for name in all_names: - if name.full_name is None: - continue - m = FUNCTION_NAME_REGEX.search(name.full_name) - if not m: - continue - - scope = m.group(1) - if scope not in test_functions_by_name: - continue - - cache_key = (name.full_name, name.module_name) - try: - if cache_key in goto_cache: - definition = goto_cache[cache_key] - else: - definition = name.goto( - follow_imports=True, follow_builtin_imports=False - ) - goto_cache[cache_key] = definition - except Exception as e: - logger.debug(str(e)) - continue - - if not definition or definition[0].type != "function": - continue + if is_parameterized and new_function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=parameters, + test_type=functions[0].test_type, + ) + ) + elif function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=None, + test_type=functions[0].test_type, + ) + ) - definition_path = str(definition[0].module_path) - if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None - ): - for index in test_functions_by_name[scope]: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type - - if scope_parameters is not None: - if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - - function_to_test_map[qualified_name_with_modules_from_root].add( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - ), - position=CodePosition( - line_no=name.line, col_no=name.column - ), + test_functions_list = list(test_functions) + test_functions_raw = [elem.function_name for elem in test_functions_list] + + test_functions_by_name = defaultdict(list) + for i, func_name in enumerate(test_functions_raw): + test_functions_by_name[func_name].append(i) + + for name in all_names: + if name.full_name is None: + continue + m = FUNCTION_NAME_REGEX.search(name.full_name) + if not m: + continue + + scope = m.group(1) + if scope not in test_functions_by_name: + continue + + cache_key = (name.full_name, name.module_name) + try: + if cache_key in goto_cache: + definition = goto_cache[cache_key] + else: + definition = name.goto( + follow_imports=True, follow_builtin_imports=False ) - ) + goto_cache[cache_key] = definition + except Exception as e: + logger.debug(str(e)) + continue + + if not definition or definition[0].type != "function": + continue + + definition_path = str(definition[0].module_path) + if ( + definition_path.startswith(str(project_root_path) + os.sep) + and definition[0].module_name != name.module_name + and definition[0].full_name is not None + ): + for index in test_functions_by_name[scope]: + scope_test_function = test_functions_list[index].function_name + scope_test_class = test_functions_list[index].test_class + scope_parameters = test_functions_list[index].parameters + test_type = test_functions_list[index].test_type + + if scope_parameters is not None: + if test_framework == "pytest": + scope_test_function += "[" + scope_parameters + "]" + if test_framework == "unittest": + scope_test_function += "_" + scope_parameters + + full_name_without_module_prefix = definition[ + 0 + ].full_name.replace(definition[0].module_name + ".", "", 1) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + + function_to_test_map[qualified_name_with_modules_from_root].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=scope_test_class, + test_function=scope_test_function, + test_type=test_type, + ), + position=CodePosition( + line_no=name.line, col_no=name.column + ), + ) + ) + + progress.advance(task_id) return {function: list(tests) for function, tests in function_to_test_map.items()} diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 22443eade..c0011d436 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -95,8 +95,7 @@ def run(self) -> None: return console.rule() - with progress_bar(f"Discovering existing unit tests in {self.test_cfg.tests_root}…", transient=True): - function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) + function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) console.rule() logger.info(f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}") From 8a3ea9130970868f07b8646c064a92e5a53a996e Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 23:47:53 -0500 Subject: [PATCH 6/8] Update console.py --- codeflash/cli_cmds/console.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index b3396eca2..b4bfda3ff 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -109,7 +109,6 @@ def test_files_progress_bar( MofNCompleteColumn(), TimeElapsedColumn(), TimeRemainingColumn(), - console=console, transient=True, ) as progress: task_id = progress.add_task(description, total=total) From ad9b3063b1a6ff39ae70ba3f4a72d1a4f32841ee Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 26 Mar 2025 19:03:11 -0500 Subject: [PATCH 7/8] implement suggestions Revert "implement suggestions" This reverts commit 8bd8068eebb3ae1f8fef0df0e126eb8511e8bbd1. first pass --- codeflash/discovery/discover_unit_tests.py | 105 +++++++++------------ 1 file changed, 46 insertions(+), 59 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 812b656c7..ac6c96dd9 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -211,7 +211,7 @@ def process_test_files( for test_file, functions in file_to_test_map.items(): try: script = jedi.Script(path=test_file, project=jedi_project) - test_functions = set() + test_functions_by_name: dict[str, TestFunction] = {} all_names = script.get_names(all_scopes=True, references=True) all_defs = script.get_names(all_scopes=True, definitions=True) @@ -238,22 +238,18 @@ def process_test_files( function.test_function )[1] if function_name in top_level_functions: - test_functions.add( - TestFunction( - function_name, - function.test_class, - parameters, - function.test_type, - ) - ) - elif function.test_function in top_level_functions: - test_functions.add( - TestFunction( - function.test_function, + test_functions_by_name[function_name] = TestFunction( + function_name, function.test_class, - None, + parameters, function.test_type, ) + elif function.test_function in top_level_functions: + test_functions_by_name[function.test_function] = TestFunction( + function.test_function, + function.test_class, + None, + function.test_type, ) elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( function.test_function @@ -262,13 +258,11 @@ def process_test_files( "", function.test_function ) if base_name in top_level_functions: - test_functions.add( - TestFunction( - function_name=base_name, - test_class=function.test_class, - parameters=function.test_function, - test_type=function.test_type, - ) + test_functions_by_name[base_name] = TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, ) elif test_framework == "unittest": @@ -289,7 +283,7 @@ def process_test_files( ) if is_parameterized and new_function == def_name.name: - test_functions.add( + test_functions_by_name[def_name.name] = ( TestFunction( function_name=def_name.name, test_class=matched_name, @@ -298,7 +292,7 @@ def process_test_files( ) ) elif function == def_name.name: - test_functions.add( + test_functions_by_name[def_name.name] = ( TestFunction( function_name=def_name.name, test_class=matched_name, @@ -307,13 +301,6 @@ def process_test_files( ) ) - test_functions_list = list(test_functions) - test_functions_raw = [elem.function_name for elem in test_functions_list] - - test_functions_by_name = defaultdict(list) - for i, func_name in enumerate(test_functions_raw): - test_functions_by_name[func_name].append(i) - for name in all_names: if name.full_name is None: continue @@ -347,36 +334,36 @@ def process_test_files( and definition[0].module_name != name.module_name and definition[0].full_name is not None ): - for index in test_functions_by_name[scope]: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type - - if scope_parameters is not None: - if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - - full_name_without_module_prefix = definition[ - 0 - ].full_name.replace(definition[0].module_name + ".", "", 1) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - - function_to_test_map[qualified_name_with_modules_from_root].add( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - ), - position=CodePosition( - line_no=name.line, col_no=name.column - ), - ) + test_function = test_functions_by_name[scope] + scope_test_function = test_function.function_name + scope_test_class = test_function.test_class + scope_parameters = test_function.parameters + test_type = test_function.test_type + + if scope_parameters is not None: + if test_framework == "pytest": + scope_test_function += "[" + scope_parameters + "]" + if test_framework == "unittest": + scope_test_function += "_" + scope_parameters + + full_name_without_module_prefix = definition[0].full_name.replace( + definition[0].module_name + ".", "", 1 + ) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + + function_to_test_map[qualified_name_with_modules_from_root].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=scope_test_class, + test_function=scope_test_function, + test_type=test_type, + ), + position=CodePosition( + line_no=name.line, col_no=name.column + ), ) + ) progress.advance(task_id) From 27784fe8cfeb87e27bb031ca3722f39348cd1f3d Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 28 Mar 2025 19:09:10 -0500 Subject: [PATCH 8/8] Revert "implement suggestions" This reverts commit ad9b3063b1a6ff39ae70ba3f4a72d1a4f32841ee. --- codeflash/discovery/discover_unit_tests.py | 105 ++++++++++++--------- 1 file changed, 59 insertions(+), 46 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index ac6c96dd9..812b656c7 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -211,7 +211,7 @@ def process_test_files( for test_file, functions in file_to_test_map.items(): try: script = jedi.Script(path=test_file, project=jedi_project) - test_functions_by_name: dict[str, TestFunction] = {} + test_functions = set() all_names = script.get_names(all_scopes=True, references=True) all_defs = script.get_names(all_scopes=True, definitions=True) @@ -238,18 +238,22 @@ def process_test_files( function.test_function )[1] if function_name in top_level_functions: - test_functions_by_name[function_name] = TestFunction( - function_name, + test_functions.add( + TestFunction( + function_name, + function.test_class, + parameters, + function.test_type, + ) + ) + elif function.test_function in top_level_functions: + test_functions.add( + TestFunction( + function.test_function, function.test_class, - parameters, + None, function.test_type, ) - elif function.test_function in top_level_functions: - test_functions_by_name[function.test_function] = TestFunction( - function.test_function, - function.test_class, - None, - function.test_type, ) elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( function.test_function @@ -258,11 +262,13 @@ def process_test_files( "", function.test_function ) if base_name in top_level_functions: - test_functions_by_name[base_name] = TestFunction( - function_name=base_name, - test_class=function.test_class, - parameters=function.test_function, - test_type=function.test_type, + test_functions.add( + TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, + ) ) elif test_framework == "unittest": @@ -283,7 +289,7 @@ def process_test_files( ) if is_parameterized and new_function == def_name.name: - test_functions_by_name[def_name.name] = ( + test_functions.add( TestFunction( function_name=def_name.name, test_class=matched_name, @@ -292,7 +298,7 @@ def process_test_files( ) ) elif function == def_name.name: - test_functions_by_name[def_name.name] = ( + test_functions.add( TestFunction( function_name=def_name.name, test_class=matched_name, @@ -301,6 +307,13 @@ def process_test_files( ) ) + test_functions_list = list(test_functions) + test_functions_raw = [elem.function_name for elem in test_functions_list] + + test_functions_by_name = defaultdict(list) + for i, func_name in enumerate(test_functions_raw): + test_functions_by_name[func_name].append(i) + for name in all_names: if name.full_name is None: continue @@ -334,36 +347,36 @@ def process_test_files( and definition[0].module_name != name.module_name and definition[0].full_name is not None ): - test_function = test_functions_by_name[scope] - scope_test_function = test_function.function_name - scope_test_class = test_function.test_class - scope_parameters = test_function.parameters - test_type = test_function.test_type - - if scope_parameters is not None: - if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - - function_to_test_map[qualified_name_with_modules_from_root].add( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - ), - position=CodePosition( - line_no=name.line, col_no=name.column - ), + for index in test_functions_by_name[scope]: + scope_test_function = test_functions_list[index].function_name + scope_test_class = test_functions_list[index].test_class + scope_parameters = test_functions_list[index].parameters + test_type = test_functions_list[index].test_type + + if scope_parameters is not None: + if test_framework == "pytest": + scope_test_function += "[" + scope_parameters + "]" + if test_framework == "unittest": + scope_test_function += "_" + scope_parameters + + full_name_without_module_prefix = definition[ + 0 + ].full_name.replace(definition[0].module_name + ".", "", 1) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + + function_to_test_map[qualified_name_with_modules_from_root].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=scope_test_class, + test_function=scope_test_function, + test_type=test_type, + ), + position=CodePosition( + line_no=name.line, col_no=name.column + ), + ) ) - ) progress.advance(task_id)