diff --git a/code_to_optimize/bubble_sort2.py b/code_to_optimize/bubble_sort2.py index fce9e7d77..aa88d2ae4 100644 --- a/code_to_optimize/bubble_sort2.py +++ b/code_to_optimize/bubble_sort2.py @@ -1,6 +1,3 @@ def sorter(arr): arr.sort() - return arr - - -CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()" + return arr \ No newline at end of file diff --git a/code_to_optimize/bubble_sort_deps.py b/code_to_optimize/bubble_sort_deps.py index 7928f19ac..55d7959fa 100644 --- a/code_to_optimize/bubble_sort_deps.py +++ b/code_to_optimize/bubble_sort_deps.py @@ -9,142 +9,3 @@ def sorter_deps(arr): dep2_swap(arr, j) return arr - -CACHED_TESTS = """import dill as pickle -import os -def _log__test__values(values, duration, test_name): - iteration = os.environ["CODEFLASH_TEST_ITERATION"] - with open(os.path.join( - '/var/folders/ms/1tz2l1q55w5b7pp4wpdkbjq80000gn/T/codeflash_jk4pzz3w/', - f'test_return_values_{iteration}.bin'), 'ab') as f: - return_bytes = pickle.dumps(values) - _test_name = f"{test_name}".encode("ascii") - f.write(len(_test_name).to_bytes(4, byteorder='big')) - f.write(_test_name) - f.write(duration.to_bytes(8, byteorder='big')) - f.write(len(return_bytes).to_bytes(4, byteorder='big')) - f.write(return_bytes) -import time -import gc -from code_to_optimize.bubble_sort_deps import sorter_deps -import timeout_decorator -import unittest - -def dep1_comparer(arr, j: int) -> bool: - return arr[j] > arr[j + 1] - -def dep2_swap(arr, j): - temp = arr[j] - arr[j] = arr[j + 1] - arr[j + 1] = temp - -class TestSorterDeps(unittest.TestCase): - - @timeout_decorator.timeout(15, use_signals=True) - def test_integers(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([5, 3, 2, 4, 1]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values( - return_value, duration, - 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_integers:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([10, -3, 0, 2, 7]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values( - return_value, duration, - ('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:' - 'TestSorterDeps.test_integers:sorter_deps:1')) - - @timeout_decorator.timeout(15, use_signals=True) - def test_floats(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([3.2, 1.5, 2.7, 4.1, 1.0]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([-1.1, 0.0, 3.14, 2.71, -0.5]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:1') - - @timeout_decorator.timeout(15, use_signals=True) - def test_identical_elements(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([1, 1, 1, 1, 1]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - ('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:' - 'TestSorterDeps.test_identical_elements:sorter_deps:0')) - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([3.14, 3.14, 3.14]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - ('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:' - 'TestSorterDeps.test_identical_elements:sorter_deps:1')) - - @timeout_decorator.timeout(15, use_signals=True) - def test_single_element(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([5]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([-3.2]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:1') - - @timeout_decorator.timeout(15, use_signals=True) - def test_empty_array(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_empty_array:sorter_deps:0') - - @timeout_decorator.timeout(15, use_signals=True) - def test_strings(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps(['apple', 'banana', 'cherry', 'date']) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps(['dog', 'cat', 'elephant', 'ant']) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:1') - - @timeout_decorator.timeout(15, use_signals=True) - def test_mixed_types(self): - with self.assertRaises(TypeError): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([1, 'two', 3.0, 'four']) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_mixed_types:sorter_deps:0_0') -if __name__ == '__main__': - unittest.main() - -""" diff --git a/code_to_optimize/final_test_set/bubble_sort.py b/code_to_optimize/final_test_set/bubble_sort.py index 4a5132ec3..b18994494 100644 --- a/code_to_optimize/final_test_set/bubble_sort.py +++ b/code_to_optimize/final_test_set/bubble_sort.py @@ -5,7 +5,4 @@ def sorter(arr): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp - return arr - - -CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()" + return arr \ No newline at end of file diff --git a/code_to_optimize/use_cosine_similarity_from_other_file.py b/code_to_optimize/use_cosine_similarity_from_other_file.py index a3ffa60ab..4c41b9ee6 100644 --- a/code_to_optimize/use_cosine_similarity_from_other_file.py +++ b/code_to_optimize/use_cosine_similarity_from_other_file.py @@ -9,110 +9,4 @@ def use_cosine_similarity( top_k: Optional[int] = 5, score_threshold: Optional[float] = None, ) -> Tuple[List[Tuple[int, int]], List[float]]: - return cosine_similarity_top_k(X, Y, top_k, score_threshold) - - -CACHED_TESTS = """import unittest -import numpy as np -from sklearn.metrics.pairwise import cosine_similarity -from typing import List, Optional, Tuple, Union -Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] -def cosine_similarity_top_k(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]: - \"\"\"Row-wise cosine similarity with optional top-k and score threshold filtering. - Args: - X: Matrix. - Y: Matrix, same width as X. - top_k: Max number of results to return. - score_threshold: Minimum cosine similarity of results. - Returns: - Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx), - second contains corresponding cosine similarities. - \"\"\" - if len(X) == 0 or len(Y) == 0: - return ([], []) - score_array = cosine_similarity(X, Y) - sorted_idxs = score_array.flatten().argsort()[::-1] - top_k = top_k or len(sorted_idxs) - top_idxs = sorted_idxs[:top_k] - score_threshold = score_threshold or -1.0 - top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold] - ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs] - scores = score_array.flatten()[top_idxs].tolist() - return (ret_idxs, scores) -def use_cosine_similarity(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]: - return cosine_similarity_top_k(X, Y, top_k, score_threshold) -class TestUseCosineSimilarity(unittest.TestCase): - def test_normal_scenario(self): - X = [[1, 2, 3], [4, 5, 6]] - Y = [[7, 8, 9], [10, 11, 12]] - result = use_cosine_similarity(X, Y, top_k=1, score_threshold=0.5) - self.assertEqual(result, ([(0, 1)], [0.9746318461970762])) - def test_edge_case_empty_matrices(self): - X = [] - Y = [] - result = use_cosine_similarity(X, Y) - self.assertEqual(result, ([], [])) - def test_edge_case_different_widths(self): - X = [[1, 2, 3]] - Y = [[4, 5]] - with self.assertRaises(ValueError): - use_cosine_similarity(X, Y) - def test_edge_case_negative_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - with self.assertRaises(IndexError): - use_cosine_similarity(X, Y, top_k=-1) - def test_edge_case_zero_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, top_k=0) - self.assertEqual(result, ([], [])) - def test_edge_case_negative_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, score_threshold=-1.0) - self.assertEqual(result, ([(0, 0)], [0.9746318461970762])) - def test_edge_case_large_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, score_threshold=2.0) - self.assertEqual(result, ([], [])) - def test_exceptional_case_non_matrix_X(self): - X = [1, 2, 3] - Y = [[4, 5, 6]] - with self.assertRaises(ValueError): - use_cosine_similarity(X, Y) - def test_exceptional_case_non_integer_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - with self.assertRaises(TypeError): - use_cosine_similarity(X, Y, top_k='5') - def test_exceptional_case_non_float_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - with self.assertRaises(TypeError): - use_cosine_similarity(X, Y, score_threshold='0.5') - def test_special_values_nan_in_matrices(self): - X = [[1, 2, np.nan]] - Y = [[4, 5, 6]] - with self.assertRaises(ValueError): - use_cosine_similarity(X, Y) - def test_special_values_none_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, top_k=None) - self.assertEqual(result, ([(0, 0)], [0.9746318461970762])) - def test_special_values_none_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, score_threshold=None) - self.assertEqual(result, ([(0, 0)], [0.9746318461970762])) - def test_large_inputs(self): - X = np.random.rand(1000, 1000) - Y = np.random.rand(1000, 1000) - result = use_cosine_similarity(X, Y, top_k=10, score_threshold=0.5) - self.assertEqual(len(result[0]), 10) - self.assertEqual(len(result[1]), 10) - self.assertTrue(all((score > 0.5 for score in result[1]))) -if __name__ == '__main__': - unittest.main()""" + return cosine_similarity_top_k(X, Y, top_k, score_threshold) \ No newline at end of file diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 34c5475ec..21aa06ad9 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -12,7 +12,7 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: """Extract the single dependent function from the code context excluding the main function.""" - ast_tree = ast.parse(code_context.code_to_optimize_with_helpers) + ast_tree = ast.parse(code_context.testgen_context_code) dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)} diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index b6dab0162..5d5487b98 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -15,43 +15,47 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource +from codeflash.models.models import ( + CodeContextType, + CodeOptimizationContext, + CodeString, + CodeStringsMarkdown, + FunctionSource, +) from codeflash.optimization.function_context import belongs_to_function_qualified def get_code_optimization_context( - function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000 + function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000 ) -> CodeOptimizationContext: - # Get qualified names and fully qualified names(fqn) of helpers - helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict( - {function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path - ) - - helpers_of_helpers, helpers_of_helpers_fqn, _ = get_file_path_to_helper_functions_dict( - helpers_of_fto, project_root_path - ) - - # Add function to optimize - helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name) - helpers_of_fto_fqn[function_to_optimize.file_path].add( - function_to_optimize.qualified_name_with_modules_from_root(project_root_path) - ) - - # Extract code - final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code - read_only_code_markdown = get_all_read_only_code_context( - helpers_of_fto, - helpers_of_fto_fqn, - helpers_of_helpers, - helpers_of_helpers_fqn, + # Get FunctionSource representation of helpers of FTO + helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path) + helpers_of_fto_qualified_names_dict = { + file_path: {source.qualified_name for source in sources} + for file_path, sources in helpers_of_fto_dict.items() + } + + # Get FunctionSource representation of helpers of helpers of FTO + helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path) + + # Add function to optimize into helpers of FTO dict, as they'll be processed together + fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path) + helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source) + + # Extract code context for optimization + final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code + read_only_code_markdown = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, project_root_path, remove_docstrings=False, + code_context_type=CodeContextType.READ_ONLY, ) # Handle token limits tokenizer = tiktoken.encoding_for_model("gpt-4o") final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code)) - if final_read_writable_tokens > token_limit: + if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer TODO: should remove duplicates @@ -61,170 +65,290 @@ def get_code_optimization_context( *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) ) - read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown)) + read_only_context_code = read_only_code_markdown.markdown + + read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code)) total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens - if total_tokens <= token_limit: - return CodeOptimizationContext( - code_to_optimize_with_helpers="", - read_writable_code=CodeString(code=final_read_writable_code).code, - read_only_context_code=read_only_code_markdown.markdown, - helper_functions=helpers_of_fto_obj_list, - preexisting_objects=preexisting_objects, + if total_tokens > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") + # Extract read only code without docstrings + read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, ) - - logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") - - # Extract read only code without docstrings - read_only_code_no_docstring_markdown = get_all_read_only_code_context( - helpers_of_fto, - helpers_of_fto_fqn, - helpers_of_helpers, - helpers_of_helpers_fqn, + read_only_context_code = read_only_code_no_docstring_markdown.markdown + read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code)) + total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens + if total_tokens > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing read-only code") + read_only_context_code = "" + + # Extract code context for testgen + testgen_code_markdown = extract_code_string_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, project_root_path, - remove_docstrings=True, + remove_docstrings=False, + code_context_type=CodeContextType.TESTGEN, ) - read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown)) - total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens - if total_tokens <= token_limit: - return CodeOptimizationContext( - code_to_optimize_with_helpers="", - read_writable_code=CodeString(code=final_read_writable_code).code, - read_only_context_code=read_only_code_no_docstring_markdown.markdown, - helper_functions=helpers_of_fto_obj_list, - preexisting_objects=preexisting_objects, + testgen_context_code = testgen_code_markdown.code + testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) + if testgen_context_code_tokens > testgen_token_limit: + testgen_code_markdown = extract_code_string_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + code_context_type=CodeContextType.TESTGEN, ) + testgen_context_code = testgen_code_markdown.code + testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) + if testgen_context_code_tokens > testgen_token_limit: + raise ValueError("Testgen code context has exceeded token limit, cannot proceed") - logger.debug("Code context has exceeded token limit, removing read-only code") return CodeOptimizationContext( - code_to_optimize_with_helpers="", - read_writable_code=CodeString(code=final_read_writable_code).code, - read_only_context_code="", - helper_functions=helpers_of_fto_obj_list, + testgen_context_code = testgen_context_code, + read_writable_code=final_read_writable_code, + read_only_context_code=read_only_context_code, + helper_functions=helpers_of_fto_list, preexisting_objects=preexisting_objects, ) - -def get_all_read_writable_code( - helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path +def extract_code_string_context_from_files( + helpers_of_fto: dict[Path, set[FunctionSource]], + helpers_of_helpers: dict[Path, set[FunctionSource]], + project_root_path: Path, + remove_docstrings: bool = False, + code_context_type: CodeContextType = CodeContextType.READ_ONLY, ) -> CodeString: - final_read_writable_code = "" - # Extract code from file paths that contain fto and first degree helpers - for file_path, qualified_function_names in helpers_of_fto.items(): + """Extract code context from files containing target functions and their helpers. + This function processes two sets of files: + 1. Files containing the function to optimize (fto) and their first-degree helpers + 2. Files containing only helpers of helpers (with no overlap with the first set) + + For each file, it extracts relevant code based on the specified context type, adds necessary + imports, and combines them. + + Args: + helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers + helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions + project_root_path: Root path of the project + remove_docstrings: Whether to remove docstrings from the extracted code + code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) + + Returns: + CodeString containing the extracted code context with necessary imports + + """ + # Rearrange to remove overlaps, so we only access each file path once + helpers_of_helpers_no_overlap = defaultdict(set) + for file_path in helpers_of_helpers: + if file_path in helpers_of_fto: + # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto + helpers_of_helpers[file_path] -= helpers_of_fto[file_path] + else: + helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path] + + final_code_string_context = "" + + # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files + for file_path, function_sources in helpers_of_fto.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_writable_code = get_read_writable_code(original_code, qualified_function_names) + qualified_function_names = {func.qualified_name for func in function_sources} + helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} + code_context = parse_code_and_prune_cst( + original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings + ) + except ValueError as e: - logger.debug(f"Error while getting read-writable code: {e}") + logger.debug(f"Error while getting read-only code: {e}") continue - - if read_writable_code: - final_read_writable_code += f"\n{read_writable_code}" - final_read_writable_code = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_read_writable_code, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions_fqn=helpers_of_fto_fqn[file_path], + if code_context.strip(): + final_code_string_context += f"\n{code_context}" + final_code_string_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=final_code_string_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions= list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) ) - return CodeString(code=final_read_writable_code) + if code_context_type == CodeContextType.READ_WRITABLE: + return CodeString(code=final_code_string_context) + # Extract code from file paths containing helpers of helpers + for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + continue + try: + qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + code_context = parse_code_and_prune_cst( + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + ) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") + continue + if code_context.strip(): + final_code_string_context += f"\n{code_context}" + final_code_string_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=final_code_string_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), + ) + return CodeString(code=final_code_string_context) -def get_all_read_only_code_context( - helpers_of_fto: dict[Path, set[str]], - helpers_of_fto_fqn: dict[Path, set[str]], - helpers_of_helpers: dict[Path, set[str]], - helpers_of_helpers_fqn: dict[Path, set[str]], +def extract_code_markdown_context_from_files( + helpers_of_fto: dict[Path, set[FunctionSource]], + helpers_of_helpers: dict[Path, set[FunctionSource]], project_root_path: Path, remove_docstrings: bool = False, + code_context_type: CodeContextType = CodeContextType.READ_ONLY, ) -> CodeStringsMarkdown: + """Extract code context from files containing target functions and their helpers, formatting them as markdown. + + This function processes two sets of files: + 1. Files containing the function to optimize (fto) and their first-degree helpers + 2. Files containing only helpers of helpers (with no overlap with the first set) + + For each file, it extracts relevant code based on the specified context type, adds necessary + imports, and combines them into a structured markdown format. + + Args: + helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers + helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions + project_root_path: Root path of the project + remove_docstrings: Whether to remove docstrings from the extracted code + code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) + + Returns: + CodeStringsMarkdown containing the extracted code context with necessary imports, + formatted for inclusion in markdown + + """ # Rearrange to remove overlaps, so we only access each file path once helpers_of_helpers_no_overlap = defaultdict(set) - helpers_of_helpers_no_overlap_fqn = defaultdict(set) for file_path in helpers_of_helpers: if file_path in helpers_of_fto: - # Remove duplicates, in case a helper of helper is also a helper of fto + # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto helpers_of_helpers[file_path] -= helpers_of_fto[file_path] - helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path] else: helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path] - helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path] - - read_only_code_markdown = CodeStringsMarkdown() + code_context_markdown = CodeStringsMarkdown() # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files - for file_path, qualified_function_names in helpers_of_fto.items(): + for file_path, function_sources in helpers_of_fto.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_only_code = get_read_only_code( - original_code, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings + qualified_function_names = {func.qualified_name for func in function_sources} + helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} + code_context = parse_code_and_prune_cst( + original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings ) + except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue - if read_only_code.strip(): - read_only_code_with_imports = CodeString( + if code_context.strip(): + code_context_with_imports = CodeString( code=add_needed_imports_from_module( src_module_code=original_code, - dst_module_code=read_only_code, + dst_module_code=code_context, src_path=file_path, dst_path=file_path, project_root=project_root_path, - helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path], + helper_functions=list( + helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) ), file_path=file_path.relative_to(project_root_path), ) - read_only_code_markdown.code_strings.append(read_only_code_with_imports) - + code_context_markdown.code_strings.append(code_context_with_imports) # Extract code from file paths containing helpers of helpers - for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items(): + for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_only_code = get_read_only_code( - original_code, set(), qualified_helper_function_names, remove_docstrings + qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + code_context = parse_code_and_prune_cst( + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings, ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue - if read_only_code.strip(): - read_only_code_with_imports = CodeString( + if code_context.strip(): + code_context_with_imports = CodeString( code=add_needed_imports_from_module( src_module_code=original_code, - dst_module_code=read_only_code, + dst_module_code=code_context, src_path=file_path, dst_path=file_path, project_root=project_root_path, - helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path], + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), ), file_path=file_path.relative_to(project_root_path), ) - read_only_code_markdown.code_strings.append(read_only_code_with_imports) - return read_only_code_markdown + code_context_markdown.code_strings.append(code_context_with_imports) + return code_context_markdown + + +def get_function_to_optimize_as_function_source(function_to_optimize: FunctionToOptimize, + project_root_path: Path) -> FunctionSource: + # Use jedi to find function to optimize + script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path)) + + # Get all names in the file + names = script.get_names(all_scopes=True, definitions=True, references=False) + + # Find the name that matches our function + for name in names: + if (name.type == "function" and + name.full_name and + name.name == function_to_optimize.function_name and + get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name): + + function_source = FunctionSource( + file_path=function_to_optimize.file_path, + qualified_name=function_to_optimize.qualified_name, + fully_qualified_name=name.full_name, + only_function_name=name.name, + source_code=name.get_line_code(), + jedi_definition=name, + ) + return function_source + + raise ValueError( + f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}") -def get_file_path_to_helper_functions_dict( +def get_function_sources_from_jedi( file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path -) -> tuple[dict[Path, set[str]], dict[Path, set[str]], list[FunctionSource]]: - file_path_to_helper_function_qualified_names = defaultdict(set) - file_path_to_helper_function_fqn = defaultdict(set) +) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + file_path_to_function_source = defaultdict(set) function_source_list: list[FunctionSource] = [] - for file_path in file_path_to_qualified_function_names: + for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) file_refs = script.get_names(all_scopes=True, definitions=False, references=True) - for qualified_function_name in file_path_to_qualified_function_names[file_path]: + for qualified_function_name in qualified_function_names: names = [ ref for ref in file_refs @@ -253,22 +377,18 @@ def get_file_path_to_helper_functions_dict( and definition.type == "function" and not belongs_to_function_qualified(definition, qualified_function_name) ): - file_path_to_helper_function_qualified_names[definition_path].add( - get_qualified_name(definition.module_name, definition.full_name) - ) - file_path_to_helper_function_fqn[definition_path].add(definition.full_name) - function_source_list.append( - FunctionSource( - file_path=definition_path, - qualified_name=get_qualified_name(definition.module_name, definition.full_name), - fully_qualified_name=definition.full_name, - only_function_name=definition.name, - source_code=definition.get_line_code(), - jedi_definition=definition, - ) + function_source = FunctionSource( + file_path=definition_path, + qualified_name=get_qualified_name(definition.module_name, definition.full_name), + fully_qualified_name=definition.full_name, + only_function_name=definition.name, + source_code=definition.get_line_code(), + jedi_definition=definition, ) + file_path_to_function_source[definition_path].add(function_source) + function_source_list.append(function_source) - return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn, function_source_list + return file_path_to_function_source, function_source_list def is_dunder_method(name: str) -> bool: @@ -290,6 +410,29 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode return indented_block.with_changes(body=indented_block.body[1:]) return indented_block +def parse_code_and_prune_cst( + code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), remove_docstrings: bool = False +) -> str: + """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables.""" + module = cst.parse_module(code) + if code_context_type == CodeContextType.READ_WRITABLE: + filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions) + elif code_context_type == CodeContextType.READ_ONLY: + filtered_node, found_target = prune_cst_for_read_only_code( + module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + ) + elif code_context_type == CodeContextType.TESTGEN: + filtered_node, found_target = prune_cst_for_testgen_code( + module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + ) + else: + raise ValueError(f"Unknown code_context_type: {code_context_type}") + + if not found_target: + raise ValueError("No target functions found in the provided code") + if filtered_node and isinstance(filtered_node, cst.Module): + return str(filtered_node.code) + return "" def prune_cst_for_read_writable_code( node: cst.CSTNode, target_functions: set[str], prefix: str = "" @@ -370,20 +513,6 @@ def prune_cst_for_read_writable_code( return (node.with_changes(**updates) if updates else node), True - -def get_read_writable_code(code: str, target_functions: set[str]) -> str: - """Creates a read-writable code string by parsing and filtering the code to keep only - target functions and the minimal surrounding structure. - """ - module = cst.parse_module(code) - filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions) - if not found_target: - raise ValueError("No target functions found in the provided code") - if filtered_node and isinstance(filtered_node, cst.Module): - return str(filtered_node.code) - return "" - - def prune_cst_for_read_only_code( node: cst.CSTNode, target_functions: set[str], @@ -488,18 +617,107 @@ def prune_cst_for_read_only_code( return None, False -def get_read_only_code( - code: str, target_functions: set[str], helpers_of_helper_functions: set[str], remove_docstrings: bool = False -) -> str: - """Creates a read-only version of the code by parsing and filtering the code to keep only - class contextual information, and other module scoped variables. + +def prune_cst_for_testgen_code( + node: cst.CSTNode, + target_functions: set[str], + helpers_of_helper_functions: set[str], + prefix: str = "", + remove_docstrings: bool = False, +) -> tuple[cst.CSTNode | None, bool]: + """Recursively filter the node for testgen context: + + Returns: + (filtered_node, found_target): + filtered_node: The modified CST node or None if it should be removed. + found_target: True if a target function was found in this node's subtree. + """ - module = cst.parse_module(code) - filtered_node, found_target = prune_cst_for_read_only_code( - module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings - ) - if not found_target: - raise ValueError("No target functions found in the provided code") - if filtered_node and isinstance(filtered_node, cst.Module): - return str(filtered_node.code) - return "" + if isinstance(node, (cst.Import, cst.ImportFrom)): + return None, False + + if isinstance(node, cst.FunctionDef): + qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value + # If it's a target function, remove it but mark found_target = True + if qualified_name in helpers_of_helper_functions or qualified_name in target_functions: + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + new_body = remove_docstring_from_body(node.body) + return node.with_changes(body=new_body), True + return node, True + # Keep all dunder methods + if is_dunder_method(node.name.value): + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + new_body = remove_docstring_from_body(node.body) + return node.with_changes(body=new_body), False + return node, False + return None, False + + if isinstance(node, cst.ClassDef): + # Do not recurse into nested classes + if prefix: + return None, False + # Assuming always an IndentedBlock + if not isinstance(node.body, cst.IndentedBlock): + raise ValueError("ClassDef body is not an IndentedBlock") + + class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value + + # First pass: detect if there is a target function in the class + found_in_class = False + new_class_body: list[CSTNode] = [] + for stmt in node.body.body: + filtered, found_target = prune_cst_for_testgen_code( + stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings + ) + found_in_class |= found_target + if filtered: + new_class_body.append(filtered) + + if not found_in_class: + return None, False + + if remove_docstrings: + return node.with_changes( + body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) + ) if new_class_body else None, True + return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True + + # For other nodes, keep the node and recursively filter children + section_names = get_section_names(node) + if not section_names: + return node, False + + updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} + found_any_target = False + + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + new_children = [] + section_found_target = False + for child in original_content: + filtered, found_target = prune_cst_for_testgen_code( + child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings + ) + if filtered: + new_children.append(filtered) + section_found_target |= found_target + + if section_found_target or new_children: + found_any_target |= section_found_target + updates[section] = new_children + elif original_content is not None: + filtered, found_target = prune_cst_for_testgen_code( + original_content, + target_functions, + helpers_of_helper_functions, + prefix, + remove_docstrings=remove_docstrings, + ) + found_any_target |= found_target + if filtered: + updates[section] = filtered + if updates: + return (node.with_changes(**updates), found_any_target) + + return None, False diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 7dd3d59f8..9cf51cf31 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -54,6 +54,18 @@ class FunctionSource: source_code: str jedi_definition: Name + def __eq__(self, other: object) -> bool: + if not isinstance(other, FunctionSource): + return False + return (self.file_path == other.file_path and + self.qualified_name == other.qualified_name and + self.fully_qualified_name == other.fully_qualified_name and + self.only_function_name == other.only_function_name and + self.source_code == other.source_code) + + def __hash__(self) -> int: + return hash((self.file_path, self.qualified_name, self.fully_qualified_name, + self.only_function_name, self.source_code)) class BestOptimization(BaseModel): candidate: OptimizedCandidate @@ -73,6 +85,7 @@ class CodeStringsMarkdown(BaseModel): @property def markdown(self) -> str: + """Returns the markdown representation of the code, including the file path where possible.""" return "\n".join( [ f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```" @@ -82,12 +95,17 @@ def markdown(self) -> str: class CodeOptimizationContext(BaseModel): - code_to_optimize_with_helpers: str + testgen_context_code: str = "" read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" helper_functions: list[FunctionSource] preexisting_objects: list[tuple[str, list[FunctionParent]]] +class CodeContextType(str, Enum): + READ_WRITABLE = "READ_WRITABLE" + READ_ONLY = "READ_ONLY" + TESTGEN = "TESTGEN" + class OptimizedCandidateResult(BaseModel): max_loop_count: int diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index 193327d39..4f1c892bc 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -1,28 +1,10 @@ from __future__ import annotations -import ast -import os -import re -from collections import defaultdict -from typing import TYPE_CHECKING - -import jedi -import tiktoken from jedi.api.classes import Name - -from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import get_code from codeflash.code_utils.code_utils import ( get_qualified_name, - module_name_from_file_path, - path_belongs_to_site_packages, -) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, FunctionSource - -if TYPE_CHECKING: - from pathlib import Path +) def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool: """Check if the given name belongs to the specified method.""" @@ -58,242 +40,4 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b return get_qualified_name(name.module_name, name.full_name) == qualified_function_name return False except ValueError: - return False - - -def get_type_annotation_context( - function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path -) -> tuple[list[FunctionSource], set[tuple[str, str]]]: - function_name: str = function.function_name - file_path: Path = function.file_path - file_contents: str = file_path.read_text(encoding="utf8") - try: - module: ast.Module = ast.parse(file_contents) - except SyntaxError as e: - logger.exception(f"get_type_annotation_context - Syntax error in code: {e}") - return [], set() - sources: list[FunctionSource] = [] - ast_parents: list[FunctionParent] = [] - contextual_dunder_methods = set() - - def get_annotation_source( - j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str - ) -> None: - try: - definition: list[Name] = j_script.goto( - line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False - ) - except Exception as ex: - if hasattr(name, "full_name"): - logger.exception(f"Error while getting definition for {name.full_name}: {ex}") - else: - logger.exception(f"Error while getting definition: {ex}") - definition = [] - if definition: # TODO can be multiple definitions - definition_path = definition[0].module_path - - # The definition is part of this project and not defined within the original function - if ( - str(definition_path).startswith(str(project_root_path) + os.sep) - and definition[0].full_name - and not path_belongs_to_site_packages(definition_path) - and not belongs_to_function(definition[0], function_name) - ): - source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])]) - if source_code[0]: - sources.append( - FunctionSource( - fully_qualified_name=definition[0].full_name, - jedi_definition=definition[0], - source_code=source_code[0], - file_path=definition_path, - qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."), - only_function_name=definition[0].name, - ) - ) - contextual_dunder_methods.update(source_code[1]) - - def visit_children( - node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent] - ) -> None: - child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module - for child in ast.iter_child_nodes(node): - visit(child, node_parents) - - def visit_all_annotation_children( - node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent] - ) -> None: - if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): - visit_all_annotation_children(node.left, node_parents) - visit_all_annotation_children(node.right, node_parents) - if isinstance(node, ast.Name) and hasattr(node, "id"): - name: str = node.id - line_no: int = node.lineno - col_no: int = node.col_offset - get_annotation_source(jedi_script, name, node_parents, line_no, col_no) - if isinstance(node, ast.Subscript): - if hasattr(node, "slice"): - if isinstance(node.slice, ast.Subscript): - visit_all_annotation_children(node.slice, node_parents) - elif isinstance(node.slice, ast.Tuple): - for elt in node.slice.elts: - if isinstance(elt, (ast.Name, ast.Subscript)): - visit_all_annotation_children(elt, node_parents) - elif isinstance(node.slice, ast.Name): - visit_all_annotation_children(node.slice, node_parents) - if hasattr(node, "value"): - visit_all_annotation_children(node.value, node_parents) - - def visit( - node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, - node_parents: list[FunctionParent], - ) -> None: - if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - if node.name == function_name and node_parents == function.parents: - arg: ast.arg - for arg in node.args.args: - if arg.annotation: - visit_all_annotation_children(arg.annotation, node_parents) - if node.returns: - visit_all_annotation_children(node.returns, node_parents) - - if not isinstance(node, ast.Module): - node_parents.append(FunctionParent(node.name, type(node).__name__)) - visit_children(node, node_parents) - if not isinstance(node, ast.Module): - node_parents.pop() - - visit(module, ast_parents) - - return sources, contextual_dunder_methods - - -def get_function_variables_definitions( - function_to_optimize: FunctionToOptimize, project_root_path: Path -) -> tuple[list[FunctionSource], set[tuple[str, str]]]: - function_name = function_to_optimize.function_name - file_path = function_to_optimize.file_path - script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) - sources: list[FunctionSource] = [] - contextual_dunder_methods = set() - # TODO: The function name condition can be stricter so that it does not clash with other class names etc. - # TODO: The function could have been imported as some other name, - # we should be checking for the translation as well. Also check for the original function name. - names = [] - for ref in script.get_names(all_scopes=True, definitions=False, references=True): - if ref.full_name: - if function_to_optimize.parents: - # Check if the reference belongs to the specified class when FunctionParent is provided - if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name): - names.append(ref) - elif belongs_to_function(ref, function_name): - names.append(ref) - - for name in names: - try: - definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) - except Exception as e: - try: - logger.exception(f"Error while getting definition for {name.full_name}: {e}") - except Exception as e: - # name.full_name can also throw exceptions sometimes - logger.exception(f"Error while getting definition: {e}") - definitions = [] - if definitions: - # TODO: there can be multiple definitions, see how to handle such cases - definition = definitions[0] - definition_path = definition.module_path - - # The definition is part of this project and not defined within the original function - if ( - str(definition_path).startswith(str(project_root_path) + os.sep) - and not path_belongs_to_site_packages(definition_path) - and definition.full_name - and not belongs_to_function(definition, function_name) - ): - module_name = module_name_from_file_path(definition_path, project_root_path) - m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name) - parents = [] - if m: - parents = [FunctionParent(m.group(1), "ClassDef")] - - source_code = get_code( - [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)] - ) - if source_code[0]: - sources.append( - FunctionSource( - fully_qualified_name=definition.full_name, - jedi_definition=definition, - source_code=source_code[0], - file_path=definition_path, - qualified_name=definition.full_name.removeprefix(definition.module_name + "."), - only_function_name=definition.name, - ) - ) - contextual_dunder_methods.update(source_code[1]) - annotation_sources, annotation_dunder_methods = get_type_annotation_context( - function_to_optimize, script, project_root_path - ) - sources[:0] = annotation_sources # prepend the annotation sources - contextual_dunder_methods.update(annotation_dunder_methods) - existing_fully_qualified_names = set() - no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set)) - parent_sources = set() - for source in sources: - if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names: - if not source.qualified_name.count("."): - no_parent_sources[source.file_path][source.qualified_name].add(source) - else: - parent_sources.add(source) - existing_fully_qualified_names.add(fully_qualified_name) - deduped_parent_sources = [ - source - for source in parent_sources - if source.file_path not in no_parent_sources - or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path] - ] - deduped_no_parent_sources = [ - source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2] - ] - return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods - - -MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k - - -def get_constrained_function_context_and_helper_functions( - function_to_optimize: FunctionToOptimize, - project_root_path: Path, - code_to_optimize: str, - max_tokens: int = MAX_PROMPT_TOKENS, -) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]: - helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path) - tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") - code_to_optimize_tokens = tokenizer.encode(code_to_optimize) - - if not function_to_optimize.parents: - helper_functions_sources = [function.source_code for function in helper_functions] - else: - helper_functions_sources = [ - function.source_code - for function in helper_functions - if not function.qualified_name.count(".") - or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name - ] - helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources] - - context_list = [] - context_len = len(code_to_optimize_tokens) - logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}") - logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}") - for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens): - if context_len + source_len <= max_tokens: - context_list.append(function_source) - context_len += source_len - else: - break - logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}") - helper_code: str = "\n".join(context_list) - return helper_code, helper_functions, dunder_methods + return False \ No newline at end of file diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 91f39c8f8..d5c2651b7 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -21,7 +21,6 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( cleanup_paths, @@ -49,7 +48,6 @@ BestOptimization, CodeOptimizationContext, FunctionCalledInTest, - FunctionParent, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -59,7 +57,6 @@ TestFiles, TestingMode, ) -from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions from codeflash.result.create_pr import check_create_pr, existing_tests_source_for from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic from codeflash.result.explanation import Explanation @@ -134,19 +131,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: with helper_function_path.open(encoding="utf8") as f: helper_code = f.read() original_helper_code[helper_function_path] = helper_code - if has_any_async_functions(code_context.code_to_optimize_with_helpers): + if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") code_print(code_context.read_writable_code) - for module_abspath, helper_code_source in original_helper_code.items(): - code_context.code_to_optimize_with_helpers = add_needed_imports_from_module( - helper_code_source, - code_context.code_to_optimize_with_helpers, - module_abspath, - self.function_to_optimize.file_path, - self.args.project_root, - ) - generated_test_paths = [ get_test_file_path( self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" @@ -165,7 +153,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: transient=True, ): generated_results = self.generate_tests_and_optimizations( - code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers, + testgen_context_code=code_context.testgen_context_code, read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, helper_functions=code_context.helper_functions, @@ -554,50 +542,6 @@ def replace_function_and_helpers_with_optimized_code( return did_update def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: - code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize]) - if code_to_optimize is None: - return Failure("Could not find function to optimize.") - (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions( - self.function_to_optimize, self.project_root, code_to_optimize - ) - if self.function_to_optimize.parents: - function_class = self.function_to_optimize.parents[0].name - same_class_helper_methods = [ - df - for df in helper_functions - if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class - ] - optimizable_methods = [ - FunctionToOptimize( - df.qualified_name.split(".")[-1], - df.file_path, - [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")], - None, - None, - ) - for df in same_class_helper_methods - ] + [self.function_to_optimize] - dedup_optimizable_methods = [] - added_methods = set() - for method in reversed(optimizable_methods): - if f"{method.file_path}.{method.qualified_name}" not in added_methods: - dedup_optimizable_methods.append(method) - added_methods.add(f"{method.file_path}.{method.qualified_name}") - if len(dedup_optimizable_methods) > 1: - code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods))) - if code_to_optimize is None: - return Failure("Could not find function to optimize.") - code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize - - code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module( - self.function_to_optimize_source_code, - code_to_optimize_with_helpers, - self.function_to_optimize.file_path, - self.function_to_optimize.file_path, - self.project_root, - helper_functions, - ) - try: new_code_ctx = code_context_extractor.get_code_optimization_context( self.function_to_optimize, self.project_root @@ -607,7 +551,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: return Success( CodeOptimizationContext( - code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports, + testgen_context_code=new_code_ctx.testgen_context_code, read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, helper_functions=new_code_ctx.helper_functions, # only functions that are read writable @@ -709,7 +653,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi def generate_tests_and_optimizations( self, - code_to_optimize_with_helpers: str, + testgen_context_code: str, read_writable_code: str, read_only_context_code: str, helper_functions: list[FunctionSource], @@ -724,7 +668,7 @@ def generate_tests_and_optimizations( # Submit the test generation task as future future_tests = self.generate_and_instrument_tests( executor, - code_to_optimize_with_helpers, + testgen_context_code, [definition.fully_qualified_name for definition in helper_functions], generated_test_paths, generated_perf_test_paths, diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 88b46e87c..0cea00e81 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -741,7 +741,7 @@ def helper_method(self): ending_line=None, ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. expected_read_write_context = """ @@ -814,6 +814,57 @@ def helper_method(self): with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) +def test_example_class_token_limit_4() -> None: + string_filler = " ".join( + ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] + ) + code = f""" +class MyClass: + \"\"\"A class with a helper method. \"\"\" + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() +x = '{string_filler}' + +class HelperClass: + \"\"\"A helper class for MyClass.\"\"\" + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def __repr__(self): + \"\"\"Return a string representation of the HelperClass.\"\"\" + return "HelperClass" + str(self.x) + def helper_method(self): + return self.x +""" + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + # In this scenario, the testgen code context is too long, so we abort. + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) def test_repo_helper() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 501a0583b..59bdbcc23 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -748,20 +748,23 @@ def main_method(self): def test_code_replacement10() -> None: get_code_output = """from __future__ import annotations +import os + +os.environ["CODEFLASH_API_KEY"] = "cf-test-key" + class HelperClass: def __init__(self, name): self.name = name - def innocent_bystander(self): - pass - def helper_method(self): return self.name + class MainClass: def __init__(self, name): self.name = name + def main_method(self): return HelperClass(self.name).helper_method() """ @@ -778,7 +781,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output def test_code_replacement11() -> None: diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index 8534cb803..019cc4261 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -6,7 +6,6 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import is_successful from codeflash.models.models import FunctionParent -from codeflash.optimization.function_context import get_function_variables_definitions from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -19,15 +18,6 @@ def simple_function_with_one_dep(data): return calculate_something(data) -def test_simple_dependencies() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("simple_function_with_one_dep", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something" - - def global_dependency_1(num): return num + 1 @@ -94,63 +84,12 @@ def recursive(self, num): return self.recursive(num) + num_1 -def test_multiple_classes_dependencies() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve()) - ) - - assert len(helper_functions) == 2 - assert list(map(lambda x: x.fully_qualified_name, helper_functions[0])) == [ - "test_function_dependencies.global_dependency_3", - "test_function_dependencies.C.calculate_something_3", - ] - - def recursive_dependency_1(num): if num == 0: return 0 num_1 = calculate_something(num) return recursive_dependency_1(num) + num_1 - -def test_recursive_dependency() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("recursive_dependency_1", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something" - assert helper_functions[0].fully_qualified_name == "test_function_dependencies.calculate_something" - - -@dataclass -class MyData: - MyInt: int - - -def calculate_something_ann(data): - return data + 1 - - -def simple_function_with_one_dep_ann(data: MyData): - return calculate_something_ann(data) - - -def list_comprehension_dependency(data: MyData): - return [calculate_something(data) for x in range(10)] - - -def test_simple_dependencies_ann() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 2 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData" - assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann" - - from collections import defaultdict @@ -221,13 +160,14 @@ def test_class_method_dependencies() -> None: ) assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( - code_context.code_to_optimize_with_helpers + code_context.testgen_context_code == """from collections import defaultdict class Graph: def __init__(self, vertices): self.graph = defaultdict(list) self.V = vertices # No. of vertices + def topologicalSortUtil(self, v, visited, stack): visited[v] = True @@ -236,6 +176,7 @@ def topologicalSortUtil(self, v, visited, stack): self.topologicalSortUtil(i, visited, stack) stack.insert(0, v) + def topologicalSort(self): visited = [False] * self.V stack = [] @@ -245,40 +186,9 @@ def topologicalSort(self): self.topologicalSortUtil(i, visited, stack) # Print contents of stack - return stack -""" + return stack""" ) - -def calculate_something_else(data): - return data + 1 - - -def imalittledecorator(func): - def wrapper(data): - return func(data) - - return wrapper - - -@imalittledecorator -def simple_function_with_decorator_dep(data): - return calculate_something_else(data) - - -@pytest.mark.skip(reason="no decorator dependency support") -def test_decorator_dependencies() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 2 - assert {helper_functions[0][0].definition.full_name, helper_functions[1][0].definition.full_name} == { - "test_function_dependencies.calculate_something", - "test_function_dependencies.imalittledecorator", - } - - def test_recursive_function_context() -> None: file_path = pathlib.Path(__file__).resolve() @@ -310,73 +220,14 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( - code_context.code_to_optimize_with_helpers + code_context.testgen_context_code == """class C: def calculate_something_3(self, num): return num + 1 + def recursive(self, num): if num == 0: return 0 num_1 = self.calculate_something_3(num) - return self.recursive(num) + num_1 -""" - ) - - -def test_list_comprehension_dependency() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("list_comprehension_dependency", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 2 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData" - assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something" - - -def test_function_in_method_list_comprehension() -> None: - file_path = pathlib.Path(__file__).resolve() - function_to_optimize = FunctionToOptimize( - function_name="function_in_list_comprehension", - file_path=str(file_path), - parents=[FunctionParent(name="A", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0] - - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.global_dependency_3" - - -def test_method_in_method_list_comprehension() -> None: - file_path = pathlib.Path(__file__).resolve() - function_to_optimize = FunctionToOptimize( - function_name="method_in_list_comprehension", - file_path=str(file_path), - parents=[FunctionParent(name="A", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0] - - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two" - - -def test_nested_method() -> None: - file_path = pathlib.Path(__file__).resolve() - function_to_optimize = FunctionToOptimize( - function_name="nested_function", - file_path=str(file_path), - parents=[FunctionParent(name="A", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0] - - # The nested function should be included in the helper functions - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two" + return self.recursive(num) + num_1""" + ) \ No newline at end of file diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index b7dde84a4..36359d3e3 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -217,6 +217,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: f.write(code) f.flush() file_path = Path(f.name).resolve() + project_root_path = file_path.parent.resolve() function_to_optimize = FunctionToOptimize( function_name="__call__", file_path=file_path, @@ -227,7 +228,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: test_config = TestConfig( tests_root="tests", tests_project_rootdir=Path.cwd(), - project_root_path=file_path.parent.resolve(), + project_root_path=project_root_path, test_framework="pytest", pytest_cmd="pytest", ) @@ -239,13 +240,36 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: pytest.fail() code_context = ctx_result.unwrap() assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" - assert ( - code_context.code_to_optimize_with_helpers - == '''_R = TypeVar("_R") - + code_context.testgen_context_code + == f'''_P = ParamSpec("_P") +_KEY_T = TypeVar("_KEY_T") +_STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): + """Interface for cache backends used by the persistent cache decorator.""" + def __init__(self) -> None: ... + + def hash_key( + self, + *, + func: Callable[_P, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[str, _KEY_T]: ... + + def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401 + ... + + def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401 + ... + + def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ... + + def delete(self, *, key: tuple[str, _KEY_T]) -> None: ... + + def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ... + def get_cache_or_call( self, *, @@ -300,7 +324,33 @@ def get_cache_or_call( # If encoding fails, we should still return the result. return result +_P = ParamSpec("_P") +_R = TypeVar("_R") +_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend) + + class _PersistentCache(Generic[_P, _R, _CacheBackendT]): + """ + A decorator class that provides persistent caching functionality for a function. + + Args: + ---- + func (Callable[_P, _R]): The function to be decorated. + duration (datetime.timedelta): The duration for which the cached results should be considered valid. + backend (_backend): The backend storage for the cached results. + + Attributes: + ---------- + __wrapped__ (Callable[_P, _R]): The wrapped function. + __duration__ (datetime.timedelta): The duration for which the cached results should be considered valid. + __backend__ (_backend): The backend storage for the cached results. + + """ # noqa: E501 + + __wrapped__: Callable[_P, _R] + __duration__: datetime.timedelta + __backend__: _CacheBackendT + def __init__( self, func: Callable[_P, _R], @@ -310,6 +360,7 @@ def __init__( self.__duration__ = duration self.__backend__ = AbstractCacheBackend() functools.update_wrapper(self, func) + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """ Calls the wrapped function, either using the cache or bypassing it based on environment variables. @@ -358,8 +409,11 @@ def test_bubble_sort_deps() -> None: pytest.fail() code_context = ctx_result.unwrap() assert ( - code_context.code_to_optimize_with_helpers - == """def dep1_comparer(arr, j: int) -> bool: + code_context.testgen_context_code + == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap + +def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] def dep2_swap(arr, j): @@ -367,12 +421,15 @@ def dep2_swap(arr, j): arr[j] = arr[j + 1] arr[j + 1] = temp + + def sorter_deps(arr): for i in range(len(arr)): for j in range(len(arr) - 1): if dep1_comparer(arr, j): dep2_swap(arr, j) return arr + """ ) assert len(code_context.helper_functions) == 2 diff --git a/tests/test_get_read_only_code.py b/tests/test_get_read_only_code.py index 0c71d9b6c..f6e975d5f 100644 --- a/tests/test_get_read_only_code.py +++ b/tests/test_get_read_only_code.py @@ -2,7 +2,8 @@ import pytest -from codeflash.context.code_context_extractor import get_read_only_code +from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.models.models import CodeContextType def test_basic_class() -> None: @@ -22,7 +23,7 @@ class TestClass: class_var = "value" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -46,7 +47,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -72,7 +73,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True) assert dedent(expected).strip() == output.strip() @@ -97,7 +98,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True) assert dedent(expected).strip() == output.strip() @@ -124,7 +125,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True) assert dedent(expected).strip() == output.strip() @@ -142,7 +143,7 @@ def target_method(self): """ with pytest.raises(ValueError, match="No target functions found in the provided code"): - get_read_only_code(dedent(code), {"Outer.Inner.target_method"}, set()) + parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"Outer.Inner.target_method"}, set()) def test_docstrings() -> None: @@ -164,7 +165,7 @@ class TestClass: \"\"\"Class docstring.\"\"\" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -183,7 +184,7 @@ def class_method(cls, param: int = 42) -> None: expected = """""" - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -203,7 +204,7 @@ def __init__(self): expected = """ """ - output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()) assert dedent(expected).strip() == output.strip() @@ -223,7 +224,7 @@ class TestClass: var2: str """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -245,7 +246,7 @@ class TestClass: var2: str """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -271,7 +272,7 @@ class TestClass: continue """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -307,7 +308,7 @@ class TestClass: var2: str """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -322,7 +323,7 @@ def some_function(): expected = """""" - output = get_read_only_code(dedent(code), {"target_function"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()) assert dedent(expected).strip() == output.strip() @@ -341,7 +342,7 @@ def some_function(): x = 5 """ - output = get_read_only_code(dedent(code), {"target_function"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()) assert dedent(expected).strip() == output.strip() @@ -352,7 +353,7 @@ def target_function(self) -> None: if y: x = 5 - else: + else: z = 10 def some_function(): print("wow") @@ -364,11 +365,11 @@ def some_function(): expected = """ if y: x = 5 - else: + else: z = 10 """ - output = get_read_only_code(dedent(code), {"target_function"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()) assert dedent(expected).strip() == output.strip() @@ -403,7 +404,7 @@ class PlatformClass: platform = "other" """ - output = get_read_only_code(dedent(code), {"PlatformClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -462,7 +463,7 @@ class TestClass: error_type = "cleanup" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -515,7 +516,7 @@ class TestClass: context = "cleanup" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -564,7 +565,7 @@ class TestClass: status = "cancelled" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -664,7 +665,7 @@ def __str__(self) -> str: pass """ - output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -751,7 +752,7 @@ def __str__(self) -> str: pass """ - output = get_read_only_code( - dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True + output = parse_code_and_prune_cst( + dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True ) assert dedent(expected).strip() == output.strip() diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index 5040eabe2..d1eeb6e99 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -1,8 +1,8 @@ from textwrap import dedent import pytest - -from codeflash.context.code_context_extractor import get_read_writable_code +from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.models.models import CodeContextType def test_simple_function() -> None: @@ -12,7 +12,7 @@ def target_function(): y = 2 return x + y """ - result = get_read_writable_code(dedent(code), {"target_function"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"}) expected = dedent(""" def target_function(): @@ -31,7 +31,7 @@ def target_function(self): y = 2 return x + y """ - result = get_read_writable_code(dedent(code), {"MyClass.target_function"}) + result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}) expected = dedent(""" class MyClass: @@ -55,7 +55,7 @@ def target_method(self): def other_method(self): print("this should be excluded") """ - result = get_read_writable_code(dedent(code), {"MyClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"}) expected = dedent(""" class MyClass: @@ -79,7 +79,7 @@ class Inner: def not_findable(self): return 42 """ - result = get_read_writable_code(dedent(code), {"Outer.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"Outer.target_method"}) expected = dedent(""" class Outer: @@ -99,7 +99,7 @@ def method1(self): def target_function(): return 42 """ - result = get_read_writable_code(dedent(code), {"target_function"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"}) expected = dedent(""" def target_function(): @@ -122,7 +122,7 @@ class ClassC: def process(self): return "C" """ - result = get_read_writable_code(dedent(code), {"ClassA.process", "ClassC.process"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}) expected = dedent(""" class ClassA: @@ -147,7 +147,7 @@ class ErrorClass: def handle_error(self): print("error") """ - result = get_read_writable_code(dedent(code), {"TargetClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}) expected = dedent(""" try: @@ -174,7 +174,7 @@ def other_method(self): def target_method(self): return f"Value: {self.x}" """ - result = get_read_writable_code(dedent(code), {"MyClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"}) expected = dedent(""" class MyClass: @@ -198,7 +198,7 @@ def other_method(self): def target_method(self): return f"Value: {self.x}" """ - result = get_read_writable_code(dedent(code), {"MyClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"}) expected = dedent(""" class MyClass: @@ -219,7 +219,7 @@ def target(self): pass """ with pytest.raises(ValueError, match="No target functions found in the provided code"): - get_read_writable_code(dedent(code), {"MyClass.Inner.target"}) + parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}) def test_module_var() -> None: @@ -243,7 +243,7 @@ def target_function(self) -> None: var2 = "test" """ - output = get_read_writable_code(dedent(code), {"target_function"}) + output = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"}) assert dedent(expected).strip() == output.strip() diff --git a/tests/test_get_testgen_code.py b/tests/test_get_testgen_code.py new file mode 100644 index 000000000..da399a243 --- /dev/null +++ b/tests/test_get_testgen_code.py @@ -0,0 +1,745 @@ +from textwrap import dedent + +import pytest + +from codeflash.models.models import CodeContextType +from codeflash.context.code_context_extractor import parse_code_and_prune_cst + +def test_simple_function() -> None: + code = """ + def target_function(): + x = 1 + y = 2 + return x + y + """ + result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()) + + expected = """ + def target_function(): + x = 1 + y = 2 + return x + y + """ + assert dedent(expected).strip() == result.strip() + +def test_basic_class() -> None: + code = """ + class TestClass: + class_var = "value" + + def target_method(self): + print("This should be included") + + def other_method(self): + print("This too") + """ + + expected = """ + class TestClass: + class_var = "value" + + def target_method(self): + print("This should be included") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_dunder_methods() -> None: + code = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + expected = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_dunder_methods_remove_docstring() -> None: + code = """ + class TestClass: + def __init__(self): + \"\"\"Constructor for TestClass.\"\"\" + self.x = 42 + + def __str__(self): + \"\"\"String representation of TestClass.\"\"\" + return f"Value: {self.x}" + + def target_method(self): + \"\"\"Target method docstring.\"\"\" + print("include me") + """ + + expected = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True) + assert dedent(expected).strip() == output.strip() + + +def test_class_remove_docstring() -> None: + code = """ + class TestClass: + \"\"\"Class docstring.\"\"\" + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + expected = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True) + assert dedent(expected).strip() == output.strip() + + +def test_target_in_nested_class() -> None: + """Test that attempting to find a target in a nested class raises an error.""" + code = """ + class Outer: + outer_var = 1 + + class Inner: + inner_var = 2 + + def target_method(self): + print("include this") + """ + + with pytest.raises(ValueError, match="No target functions found in the provided code"): + parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"Outer.Inner.target_method"}, set()) + +def test_method_signatures() -> None: + code = """ + class TestClass: + @property + def target_method(self) -> str: + \"\"\"Property docstring.\"\"\" + return "value" + + @classmethod + def class_method(cls, param: int = 42) -> None: + print("class method") + """ + + expected = """ + class TestClass: + @property + def target_method(self) -> str: + \"\"\"Property docstring.\"\"\" + return "value" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() +def test_multiple_top_level_targets() -> None: + code = """ + class TestClass: + def target1(self): + print("include 1") + + def target2(self): + print("include 2") + + def __init__(self): + self.x = 42 + + def other_method(self): + print("include other") + """ + + expected = """ + class TestClass: + def target1(self): + print("include 1") + + def target2(self): + print("include 2") + + def __init__(self): + self.x = 42 + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_class_annotations() -> None: + code = """ + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + expected = """ + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_class_annotations_if() -> None: + code = """ + if True: + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + expected = """ + if True: + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_conditional_class_definitions() -> None: + code = """ + if PLATFORM == "linux": + class PlatformClass: + platform = "linux" + def target_method(self): + print("linux") + elif PLATFORM == "windows": + class PlatformClass: + platform = "windows" + def target_method(self): + print("windows") + else: + class PlatformClass: + platform = "other" + def target_method(self): + print("other") + """ + + expected = """ + if PLATFORM == "linux": + class PlatformClass: + platform = "linux" + def target_method(self): + print("linux") + elif PLATFORM == "windows": + class PlatformClass: + platform = "windows" + def target_method(self): + print("windows") + else: + class PlatformClass: + platform = "other" + def target_method(self): + print("other") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_try_except_structure() -> None: + code = """ + try: + class TargetClass: + attr = "value" + def target_method(self): + return 42 + except ValueError: + class ErrorClass: + def handle_error(self): + print("error") + """ + + expected = """ + try: + class TargetClass: + attr = "value" + def target_method(self): + return 42 + except ValueError: + class ErrorClass: + def handle_error(self): + print("error") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_module_var() -> None: + code = """ + def target_function(self) -> None: + self.var2 = "test" + + x = 5 + + def some_function(): + print("wow") + """ + + expected = """ + def target_function(self) -> None: + self.var2 = "test" + + x = 5 + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_module_var_if() -> None: + code = """ + def target_function(self) -> None: + var2 = "test" + + if y: + x = 5 + else: + z = 10 + def some_function(): + print("wow") + + def some_function(): + print("wow") + """ + + expected = """ + def target_function(self) -> None: + var2 = "test" + + if y: + x = 5 + else: + z = 10 + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_multiple_classes() -> None: + code = """ + class ClassA: + def process(self): + return "A" + + class ClassB: + def process(self): + return "B" + + class ClassC: + def process(self): + return "C" + """ + + expected = """ + class ClassA: + def process(self): + return "A" + + class ClassC: + def process(self): + return "C" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_with_statement_and_loops() -> None: + code = """ + with context_manager() as ctx: + while attempt_count < max_attempts: + try: + for item in items: + if item.ready: + class TestClass: + context = "ready" + def target_method(self): + print("ready") + else: + class TestClass: + context = "not_ready" + def target_method(self): + print("not ready") + except ConnectionError: + class TestClass: + context = "connection_error" + def target_method(self): + print("connection error") + continue + finally: + class TestClass: + context = "cleanup" + def target_method(self): + print("cleanup") + """ + + expected = """ + with context_manager() as ctx: + while attempt_count < max_attempts: + try: + for item in items: + if item.ready: + class TestClass: + context = "ready" + def target_method(self): + print("ready") + else: + class TestClass: + context = "not_ready" + def target_method(self): + print("not ready") + except ConnectionError: + class TestClass: + context = "connection_error" + def target_method(self): + print("connection error") + continue + finally: + class TestClass: + context = "cleanup" + def target_method(self): + print("cleanup") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_async_with_try_except() -> None: + code = """ + async with async_context() as ctx: + try: + async for item in items: + if await item.is_valid(): + class TestClass: + status = "valid" + async def target_method(self): + await self.process() + elif await item.can_retry(): + continue + else: + break + except AsyncIOError: + class TestClass: + status = "io_error" + async def target_method(self): + await self.handle_error() + except CancelledError: + class TestClass: + status = "cancelled" + async def target_method(self): + await self.cleanup() + """ + + expected = """ + async with async_context() as ctx: + try: + async for item in items: + if await item.is_valid(): + class TestClass: + status = "valid" + async def target_method(self): + await self.process() + elif await item.can_retry(): + continue + else: + break + except AsyncIOError: + class TestClass: + status = "io_error" + async def target_method(self): + await self.handle_error() + except CancelledError: + class TestClass: + status = "cancelled" + async def target_method(self): + await self.cleanup() + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_simplified_complete_implementation() -> None: + code = """ + class DataProcessor: + \"\"\"A simple data processing class.\"\"\" + + def __init__(self, data: Dict[str, Any]) -> None: + self.data = data + self._processed = False + self.result = None + + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Process and retrieve a specific key from the data.\"\"\" + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + def _process_data(self) -> None: + \"\"\"Internal method to process the data.\"\"\" + processed = {} + for key, value in self.data.items(): + if isinstance(value, (int, float)): + processed[key] = value * 2 + elif isinstance(value, str): + processed[key] = value.upper() + else: + processed[key] = value + self.result = processed + self._processed = True + + def to_json(self) -> str: + \"\"\"Convert the processed data to JSON string.\"\"\" + if not self._processed: + self._process_data() + return json.dumps(self.result) + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + def __init__(self, processor: DataProcessor): + self.processor = processor + self.cache = {} + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Retrieve and cache results for a key.\"\"\" + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + def clear_cache(self) -> None: + \"\"\"Clear the internal cache.\"\"\" + self.cache.clear() + + def get_stats(self) -> Dict[str, int]: + \"\"\"Get cache statistics.\"\"\" + return { + "cache_size": len(self.cache), + "hits": sum(1 for v in self.cache.values() if v is not None) + } + + except Exception as e: + class ResultHandler: + def __init__(self): + self.error = str(e) + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + expected = """ + class DataProcessor: + \"\"\"A simple data processing class.\"\"\" + + def __init__(self, data: Dict[str, Any]) -> None: + self.data = data + self._processed = False + self.result = None + + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Process and retrieve a specific key from the data.\"\"\" + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + def __init__(self, processor: DataProcessor): + self.processor = processor + self.cache = {} + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Retrieve and cache results for a key.\"\"\" + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + except Exception as e: + class ResultHandler: + def __init__(self): + self.error = str(e) + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_simplified_complete_implementation_no_docstring() -> None: + code = """ + class DataProcessor: + \"\"\"A simple data processing class.\"\"\" + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Process and retrieve a specific key from the data.\"\"\" + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + def _process_data(self) -> None: + \"\"\"Internal method to process the data.\"\"\" + processed = {} + for key, value in self.data.items(): + if isinstance(value, (int, float)): + processed[key] = value * 2 + elif isinstance(value, str): + processed[key] = value.upper() + else: + processed[key] = value + self.result = processed + self._processed = True + + def to_json(self) -> str: + \"\"\"Convert the processed data to JSON string.\"\"\" + if not self._processed: + self._process_data() + return json.dumps(self.result) + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Retrieve and cache results for a key.\"\"\" + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + def clear_cache(self) -> None: + \"\"\"Clear the internal cache.\"\"\" + self.cache.clear() + + def get_stats(self) -> Dict[str, int]: + \"\"\"Get cache statistics.\"\"\" + return { + "cache_size": len(self.cache), + "hits": sum(1 for v in self.cache.values() if v is not None) + } + + except Exception as e: + class ResultHandler: + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + expected = """ + class DataProcessor: + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + except Exception as e: + class ResultHandler: + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + output = parse_code_and_prune_cst( + dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True + ) + assert dedent(expected).strip() == output.strip() diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index 79f4bc5dd..e64fea1cf 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -2693,13 +2693,13 @@ def test_code_replacement10() -> None: project_root=str(file_path.parent), original_source_code=original_code, ).unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output code_context = opt.get_code_optimization_context( function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code, ) - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output """ expected = """import gc @@ -2758,9 +2758,9 @@ def test_code_replacement10() -> None: with open(file_path) as f: original_code = f.read() code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code).unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output codeflash_con.close() """ diff --git a/tests/test_type_annotation_context.py b/tests/test_type_annotation_context.py deleted file mode 100644 index b10a8ed42..000000000 --- a/tests/test_type_annotation_context.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import pathlib -from dataclasses import dataclass, field -from typing import List - -from codeflash.code_utils.code_extractor import get_code -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions - - -class CustomType: - def __init__(self) -> None: - self.name = None - self.data: List[int] = [] - - -@dataclass -class CustomDataClass: - name: str = "" - data: List[int] = field(default_factory=list) - - -def function_to_optimize(data: CustomType) -> CustomType: - name = data.name - data.data.sort() - return data - - -def function_to_optimize2(data: CustomDataClass) -> CustomType: - name = data.name - data.data.sort() - return data - - -def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None: - name = data.name - data.data.sort() - return data - - -def test_function_context_includes_type_annotation() -> None: - file_path = pathlib.Path(__file__).resolve() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("function_to_optimize", str(file_path), []), - str(file_path.parent.resolve()), - """def function_to_optimize(data: CustomType): - name = data.name - data.data.sort() - return data""", - 1000, - ) - - assert len(helper_functions) == 1 - assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType" - - -def test_function_context_includes_type_annotation_dataclass() -> None: - file_path = pathlib.Path(__file__).resolve() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("function_to_optimize2", str(file_path), []), - str(file_path.parent.resolve()), - """def function_to_optimize2(data: CustomDataClass) -> CustomType: - name = data.name - data.data.sort() - return data""", - 1000, - ) - - assert len(helper_functions) == 2 - assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" - assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" - - -def test_function_context_works_for_composite_types() -> None: - file_path = pathlib.Path(__file__).resolve() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("function_to_optimize3", str(file_path), []), - str(file_path.parent.resolve()), - """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]: - name = data.name - data.data.sort() - return data""", - 1000, - ) - - assert len(helper_functions) == 2 - assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" - assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" - - -def test_function_context_custom_datatype() -> None: - project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" - file_path = project_path / "math_utils.py" - code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])]) - assert code is not None - assert contextual_dunder_methods == set() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000 - ) - - assert len(helper_functions) == 1 - assert helper_functions[0].fully_qualified_name == "math_utils.Matrix"