From 17a42a218c8b687b2e050940dc78b8c030e5abf6 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Wed, 5 Mar 2025 16:40:23 -0800 Subject: [PATCH 1/7] Implemented testgen context retrieval. Context retrieved is the union of read-writable code and read-only code. Did some refactors to remove code_to_optimize_with_helpers, and updated tests. --- code_to_optimize/bubble_sort2.py | 5 +- code_to_optimize/bubble_sort_deps.py | 139 ---- .../final_test_set/bubble_sort.py | 5 +- .../use_cosine_similarity_from_other_file.py | 108 +-- codeflash/code_utils/coverage_utils.py | 2 +- codeflash/context/code_context_extractor.py | 316 ++++++-- codeflash/models/models.py | 9 +- codeflash/optimization/function_context.py | 474 +++++------ codeflash/optimization/function_optimizer.py | 113 +-- tests/test_code_context_extractor.py | 53 +- tests/test_code_replacement.py | 16 +- tests/test_function_dependencies.py | 170 +--- tests/test_get_helper_code.py | 75 +- tests/test_get_read_only_code.py | 53 +- tests/test_get_read_writable_code.py | 25 +- tests/test_get_testgen_code.py | 745 ++++++++++++++++++ tests/test_instrument_tests.py | 8 +- tests/test_type_annotation_context.py | 206 ++--- 18 files changed, 1573 insertions(+), 949 deletions(-) create mode 100644 tests/test_get_testgen_code.py diff --git a/code_to_optimize/bubble_sort2.py b/code_to_optimize/bubble_sort2.py index fce9e7d77..aa88d2ae4 100644 --- a/code_to_optimize/bubble_sort2.py +++ b/code_to_optimize/bubble_sort2.py @@ -1,6 +1,3 @@ def sorter(arr): arr.sort() - return arr - - -CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()" + return arr \ No newline at end of file diff --git a/code_to_optimize/bubble_sort_deps.py b/code_to_optimize/bubble_sort_deps.py index 7928f19ac..55d7959fa 100644 --- a/code_to_optimize/bubble_sort_deps.py +++ b/code_to_optimize/bubble_sort_deps.py @@ -9,142 +9,3 @@ def sorter_deps(arr): dep2_swap(arr, j) return arr - -CACHED_TESTS = """import dill as pickle -import os -def _log__test__values(values, duration, test_name): - iteration = os.environ["CODEFLASH_TEST_ITERATION"] - with open(os.path.join( - '/var/folders/ms/1tz2l1q55w5b7pp4wpdkbjq80000gn/T/codeflash_jk4pzz3w/', - f'test_return_values_{iteration}.bin'), 'ab') as f: - return_bytes = pickle.dumps(values) - _test_name = f"{test_name}".encode("ascii") - f.write(len(_test_name).to_bytes(4, byteorder='big')) - f.write(_test_name) - f.write(duration.to_bytes(8, byteorder='big')) - f.write(len(return_bytes).to_bytes(4, byteorder='big')) - f.write(return_bytes) -import time -import gc -from code_to_optimize.bubble_sort_deps import sorter_deps -import timeout_decorator -import unittest - -def dep1_comparer(arr, j: int) -> bool: - return arr[j] > arr[j + 1] - -def dep2_swap(arr, j): - temp = arr[j] - arr[j] = arr[j + 1] - arr[j + 1] = temp - -class TestSorterDeps(unittest.TestCase): - - @timeout_decorator.timeout(15, use_signals=True) - def test_integers(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([5, 3, 2, 4, 1]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values( - return_value, duration, - 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_integers:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([10, -3, 0, 2, 7]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values( - return_value, duration, - ('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:' - 'TestSorterDeps.test_integers:sorter_deps:1')) - - @timeout_decorator.timeout(15, use_signals=True) - def test_floats(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([3.2, 1.5, 2.7, 4.1, 1.0]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([-1.1, 0.0, 3.14, 2.71, -0.5]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_floats:sorter_deps:1') - - @timeout_decorator.timeout(15, use_signals=True) - def test_identical_elements(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([1, 1, 1, 1, 1]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - ('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:' - 'TestSorterDeps.test_identical_elements:sorter_deps:0')) - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([3.14, 3.14, 3.14]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, - ('code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:' - 'TestSorterDeps.test_identical_elements:sorter_deps:1')) - - @timeout_decorator.timeout(15, use_signals=True) - def test_single_element(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([5]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([-3.2]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_single_element:sorter_deps:1') - - @timeout_decorator.timeout(15, use_signals=True) - def test_empty_array(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([]) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_empty_array:sorter_deps:0') - - @timeout_decorator.timeout(15, use_signals=True) - def test_strings(self): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps(['apple', 'banana', 'cherry', 'date']) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:0') - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps(['dog', 'cat', 'elephant', 'ant']) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_strings:sorter_deps:1') - - @timeout_decorator.timeout(15, use_signals=True) - def test_mixed_types(self): - with self.assertRaises(TypeError): - gc.disable() - counter = time.perf_counter_ns() - return_value = sorter_deps([1, 'two', 3.0, 'four']) - duration = time.perf_counter_ns() - counter - gc.enable() - _log__test__values(return_value, duration, 'code_to_optimize.tests.unittest.test_sorter_deps__unit_test_0:TestSorterDeps.test_mixed_types:sorter_deps:0_0') -if __name__ == '__main__': - unittest.main() - -""" diff --git a/code_to_optimize/final_test_set/bubble_sort.py b/code_to_optimize/final_test_set/bubble_sort.py index 4a5132ec3..b18994494 100644 --- a/code_to_optimize/final_test_set/bubble_sort.py +++ b/code_to_optimize/final_test_set/bubble_sort.py @@ -5,7 +5,4 @@ def sorter(arr): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp - return arr - - -CACHED_TESTS = "import unittest\ndef sorter(arr):\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n return arr\nclass SorterTestCase(unittest.TestCase):\n def test_empty_list(self):\n self.assertEqual(sorter([]), [])\n def test_single_element_list(self):\n self.assertEqual(sorter([5]), [5])\n def test_ascending_order_list(self):\n self.assertEqual(sorter([1, 2, 3, 4, 5]), [1, 2, 3, 4, 5])\n def test_descending_order_list(self):\n self.assertEqual(sorter([5, 4, 3, 2, 1]), [1, 2, 3, 4, 5])\n def test_random_order_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 5]), [1, 2, 3, 4, 5])\n def test_duplicate_elements_list(self):\n self.assertEqual(sorter([3, 1, 4, 2, 2, 5, 1]), [1, 1, 2, 2, 3, 4, 5])\n def test_negative_numbers_list(self):\n self.assertEqual(sorter([-5, -2, -8, -1, -3]), [-8, -5, -3, -2, -1])\n def test_mixed_data_types_list(self):\n self.assertEqual(sorter(['apple', 2, 'banana', 1, 'cherry']), [1, 2, 'apple', 'banana', 'cherry'])\n def test_large_input_list(self):\n self.assertEqual(sorter(list(range(1000, 0, -1))), list(range(1, 1001)))\n def test_list_with_none_values(self):\n self.assertEqual(sorter([None, 2, None, 1, None]), [None, None, None, 1, 2])\n def test_list_with_nan_values(self):\n self.assertEqual(sorter([float('nan'), 2, float('nan'), 1, float('nan')]), [1, 2, float('nan'), float('nan'), float('nan')])\n def test_list_with_complex_numbers(self):\n self.assertEqual(sorter([3 + 2j, 1 + 1j, 4 + 3j, 2 + 1j, 5 + 4j]), [1 + 1j, 2 + 1j, 3 + 2j, 4 + 3j, 5 + 4j])\n def test_list_with_custom_class_objects(self):\n class Person:\n def __init__(self, name, age):\n self.name = name\n self.age = age\n def __repr__(self):\n return f\"Person('{self.name}', {self.age})\"\n input_list = [Person('Alice', 25), Person('Bob', 30), Person('Charlie', 20)]\n expected_output = [Person('Charlie', 20), Person('Alice', 25), Person('Bob', 30)]\n self.assertEqual(sorter(input_list), expected_output)\n def test_list_with_uncomparable_elements(self):\n with self.assertRaises(TypeError):\n sorter([5, 'apple', 3, [1, 2, 3], 2])\n def test_list_with_custom_comparison_function(self):\n input_list = [5, 4, 3, 2, 1]\n expected_output = [5, 4, 3, 2, 1]\n self.assertEqual(sorter(input_list, reverse=True), expected_output)\nif __name__ == '__main__':\n unittest.main()" + return arr \ No newline at end of file diff --git a/code_to_optimize/use_cosine_similarity_from_other_file.py b/code_to_optimize/use_cosine_similarity_from_other_file.py index a3ffa60ab..4c41b9ee6 100644 --- a/code_to_optimize/use_cosine_similarity_from_other_file.py +++ b/code_to_optimize/use_cosine_similarity_from_other_file.py @@ -9,110 +9,4 @@ def use_cosine_similarity( top_k: Optional[int] = 5, score_threshold: Optional[float] = None, ) -> Tuple[List[Tuple[int, int]], List[float]]: - return cosine_similarity_top_k(X, Y, top_k, score_threshold) - - -CACHED_TESTS = """import unittest -import numpy as np -from sklearn.metrics.pairwise import cosine_similarity -from typing import List, Optional, Tuple, Union -Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] -def cosine_similarity_top_k(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]: - \"\"\"Row-wise cosine similarity with optional top-k and score threshold filtering. - Args: - X: Matrix. - Y: Matrix, same width as X. - top_k: Max number of results to return. - score_threshold: Minimum cosine similarity of results. - Returns: - Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx), - second contains corresponding cosine similarities. - \"\"\" - if len(X) == 0 or len(Y) == 0: - return ([], []) - score_array = cosine_similarity(X, Y) - sorted_idxs = score_array.flatten().argsort()[::-1] - top_k = top_k or len(sorted_idxs) - top_idxs = sorted_idxs[:top_k] - score_threshold = score_threshold or -1.0 - top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold] - ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs] - scores = score_array.flatten()[top_idxs].tolist() - return (ret_idxs, scores) -def use_cosine_similarity(X: Matrix, Y: Matrix, top_k: Optional[int]=5, score_threshold: Optional[float]=None) -> Tuple[List[Tuple[int, int]], List[float]]: - return cosine_similarity_top_k(X, Y, top_k, score_threshold) -class TestUseCosineSimilarity(unittest.TestCase): - def test_normal_scenario(self): - X = [[1, 2, 3], [4, 5, 6]] - Y = [[7, 8, 9], [10, 11, 12]] - result = use_cosine_similarity(X, Y, top_k=1, score_threshold=0.5) - self.assertEqual(result, ([(0, 1)], [0.9746318461970762])) - def test_edge_case_empty_matrices(self): - X = [] - Y = [] - result = use_cosine_similarity(X, Y) - self.assertEqual(result, ([], [])) - def test_edge_case_different_widths(self): - X = [[1, 2, 3]] - Y = [[4, 5]] - with self.assertRaises(ValueError): - use_cosine_similarity(X, Y) - def test_edge_case_negative_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - with self.assertRaises(IndexError): - use_cosine_similarity(X, Y, top_k=-1) - def test_edge_case_zero_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, top_k=0) - self.assertEqual(result, ([], [])) - def test_edge_case_negative_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, score_threshold=-1.0) - self.assertEqual(result, ([(0, 0)], [0.9746318461970762])) - def test_edge_case_large_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, score_threshold=2.0) - self.assertEqual(result, ([], [])) - def test_exceptional_case_non_matrix_X(self): - X = [1, 2, 3] - Y = [[4, 5, 6]] - with self.assertRaises(ValueError): - use_cosine_similarity(X, Y) - def test_exceptional_case_non_integer_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - with self.assertRaises(TypeError): - use_cosine_similarity(X, Y, top_k='5') - def test_exceptional_case_non_float_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - with self.assertRaises(TypeError): - use_cosine_similarity(X, Y, score_threshold='0.5') - def test_special_values_nan_in_matrices(self): - X = [[1, 2, np.nan]] - Y = [[4, 5, 6]] - with self.assertRaises(ValueError): - use_cosine_similarity(X, Y) - def test_special_values_none_top_k(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, top_k=None) - self.assertEqual(result, ([(0, 0)], [0.9746318461970762])) - def test_special_values_none_score_threshold(self): - X = [[1, 2, 3]] - Y = [[4, 5, 6]] - result = use_cosine_similarity(X, Y, score_threshold=None) - self.assertEqual(result, ([(0, 0)], [0.9746318461970762])) - def test_large_inputs(self): - X = np.random.rand(1000, 1000) - Y = np.random.rand(1000, 1000) - result = use_cosine_similarity(X, Y, top_k=10, score_threshold=0.5) - self.assertEqual(len(result[0]), 10) - self.assertEqual(len(result[1]), 10) - self.assertTrue(all((score > 0.5 for score in result[1]))) -if __name__ == '__main__': - unittest.main()""" + return cosine_similarity_top_k(X, Y, top_k, score_threshold) \ No newline at end of file diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 34c5475ec..21aa06ad9 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -12,7 +12,7 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: """Extract the single dependent function from the code context excluding the main function.""" - ast_tree = ast.parse(code_context.code_to_optimize_with_helpers) + ast_tree = ast.parse(code_context.testgen_context_code) dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)} diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 0a7caab45..f2ebe5655 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -15,12 +15,13 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource +from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource, \ + CodeContextType from codeflash.optimization.function_context import belongs_to_function_qualified def get_code_optimization_context( - function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000 + function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000 ) -> CodeOptimizationContext: # Get qualified names and fully qualified names(fqn) of helpers helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict( @@ -37,21 +38,22 @@ def get_code_optimization_context( function_to_optimize.qualified_name_with_modules_from_root(project_root_path) ) - # Extract code - final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code - read_only_code_markdown = get_all_read_only_code_context( + # Extract code context for optimization + final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code + read_only_code_markdown = extract_code_markdown_context_from_files( helpers_of_fto, helpers_of_fto_fqn, helpers_of_helpers, helpers_of_helpers_fqn, project_root_path, remove_docstrings=False, + code_context_type=CodeContextType.READ_ONLY, ) # Handle token limits tokenizer = tiktoken.encoding_for_model("gpt-4o") final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code)) - if final_read_writable_tokens > token_limit: + if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer TODO: should remove duplicates @@ -61,53 +63,82 @@ def get_code_optimization_context( *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) ) - read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown)) - total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens - if total_tokens <= token_limit: - return CodeOptimizationContext( - code_to_optimize_with_helpers="", - read_writable_code=CodeString(code=final_read_writable_code).code, - read_only_context_code=read_only_code_markdown.markdown, - helper_functions=helpers_of_fto_obj_list, - preexisting_objects=preexisting_objects, + read_only_context_code = read_only_code_markdown.markdown + read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code)) + total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens + if total_tokens > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") + # Extract read only code without docstrings + read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files( + helpers_of_fto, + helpers_of_fto_fqn, + helpers_of_helpers, + helpers_of_helpers_fqn, + project_root_path, + remove_docstrings=True, ) - - logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") - - # Extract read only code without docstrings - read_only_code_no_docstring_markdown = get_all_read_only_code_context( + read_only_context_code = read_only_code_no_docstring_markdown.markdown + read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code)) + total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens + if total_tokens > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing read-only code") + read_only_context_code = "" + # Extract code context for testgen + testgen_code_markdown = extract_code_markdown_context_from_files( helpers_of_fto, helpers_of_fto_fqn, helpers_of_helpers, helpers_of_helpers_fqn, project_root_path, - remove_docstrings=True, + remove_docstrings=False, + code_context_type=CodeContextType.TESTGEN, ) - read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown)) - total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens - if total_tokens <= token_limit: - return CodeOptimizationContext( - code_to_optimize_with_helpers="", - read_writable_code=CodeString(code=final_read_writable_code).code, - read_only_context_code=read_only_code_no_docstring_markdown.markdown, - helper_functions=helpers_of_fto_obj_list, - preexisting_objects=preexisting_objects, + testgen_context_code = testgen_code_markdown.markdown + testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) + if testgen_context_code_tokens > testgen_token_limit: + testgen_code_markdown = extract_code_markdown_context_from_files( + helpers_of_fto, + helpers_of_fto_fqn, + helpers_of_helpers, + helpers_of_helpers_fqn, + project_root_path, + remove_docstrings=True, + code_context_type=CodeContextType.TESTGEN, ) + testgen_context_code = testgen_code_markdown.markdown + testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) + if testgen_context_code_tokens > testgen_token_limit: + raise ValueError("Testgen code context has exceeded token limit, cannot proceed") - logger.debug("Code context has exceeded token limit, removing read-only code") return CodeOptimizationContext( - code_to_optimize_with_helpers="", + testgen_context_code = testgen_context_code, read_writable_code=CodeString(code=final_read_writable_code).code, - read_only_context_code="", + read_only_context_code=read_only_context_code, helper_functions=helpers_of_fto_obj_list, preexisting_objects=preexisting_objects, ) -def get_all_read_writable_code( +def extract_code_string_context_from_files( helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path ) -> CodeString: + """Extract read-writable code context from files containing target functions and their helpers. + + This function iterates through each file path that contains functions to optimize (fto) or + their first-degree helpers, reads the original code, extracts relevant parts using CST parsing, + and adds necessary imports from the original modules. + + Args: + helpers_of_fto: Dictionary mapping file paths to sets of qualified function names + helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions + project_root_path: Root path of the project for resolving relative imports + + Returns: + CodeString object containing the consolidated read-writable code with all necessary + imports for the target functions and their helpers + + """ final_read_writable_code = "" # Extract code from file paths that contain fto and first degree helpers for file_path, qualified_function_names in helpers_of_fto.items(): @@ -117,7 +148,7 @@ def get_all_read_writable_code( logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_writable_code = get_read_writable_code(original_code, qualified_function_names) + read_writable_code = parse_code_and_prune_cst(original_code, CodeContextType.READ_WRITABLE, qualified_function_names) except ValueError as e: logger.debug(f"Error while getting read-writable code: {e}") continue @@ -133,16 +164,38 @@ def get_all_read_writable_code( helper_functions_fqn=helpers_of_fto_fqn[file_path], ) return CodeString(code=final_read_writable_code) - - -def get_all_read_only_code_context( +def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], helpers_of_helpers: dict[Path, set[str]], helpers_of_helpers_fqn: dict[Path, set[str]], project_root_path: Path, remove_docstrings: bool = False, + code_context_type: CodeContextType = CodeContextType.READ_ONLY, ) -> CodeStringsMarkdown: + """Extract code context from files containing target functions and their helpers, formatting them as markdown. + + This function processes two sets of files: + 1. Files containing the function to optimize (fto) and their first-degree helpers + 2. Files containing only helpers of helpers (with no overlap with the first set) + + For each file, it extracts relevant code based on the specified context type, adds necessary + imports, and combines them into a structured markdown format. + + Args: + helpers_of_fto: Dictionary mapping file paths to sets of function names to be optimized + helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions to be optimized + helpers_of_helpers: Dictionary mapping file paths to sets of helper function names + helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helper functions + project_root_path: Root path of the project + remove_docstrings: Whether to remove docstrings from the extracted code + code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) + + Returns: + CodeStringsMarkdown containing the extracted code context with necessary imports, + formatted for inclusion in markdown + + """ # Rearrange to remove overlaps, so we only access each file path once helpers_of_helpers_no_overlap = defaultdict(set) helpers_of_helpers_no_overlap_fqn = defaultdict(set) @@ -155,7 +208,7 @@ def get_all_read_only_code_context( helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path] helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path] - read_only_code_markdown = CodeStringsMarkdown() + code_context_markdown = CodeStringsMarkdown() # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files for file_path, qualified_function_names in helpers_of_fto.items(): try: @@ -164,17 +217,18 @@ def get_all_read_only_code_context( logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_only_code = get_read_only_code( - original_code, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings + code_context = parse_code_and_prune_cst( + original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings ) + except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue - if read_only_code.strip(): - read_only_code_with_imports = CodeString( + if code_context.strip(): + code_context_with_imports = CodeString( code=add_needed_imports_from_module( src_module_code=original_code, - dst_module_code=read_only_code, + dst_module_code=code_context, src_path=file_path, dst_path=file_path, project_root=project_root_path, @@ -182,7 +236,7 @@ def get_all_read_only_code_context( ), file_path=file_path.relative_to(project_root_path), ) - read_only_code_markdown.code_strings.append(read_only_code_with_imports) + code_context_markdown.code_strings.append(code_context_with_imports) # Extract code from file paths containing helpers of helpers for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items(): @@ -192,18 +246,18 @@ def get_all_read_only_code_context( logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_only_code = get_read_only_code( - original_code, set(), qualified_helper_function_names, remove_docstrings + code_context = parse_code_and_prune_cst( + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue - if read_only_code.strip(): - read_only_code_with_imports = CodeString( + if code_context.strip(): + code_context_with_imports = CodeString( code=add_needed_imports_from_module( src_module_code=original_code, - dst_module_code=read_only_code, + dst_module_code=code_context, src_path=file_path, dst_path=file_path, project_root=project_root_path, @@ -211,8 +265,8 @@ def get_all_read_only_code_context( ), file_path=file_path.relative_to(project_root_path), ) - read_only_code_markdown.code_strings.append(read_only_code_with_imports) - return read_only_code_markdown + code_context_markdown.code_strings.append(code_context_with_imports) + return code_context_markdown def get_file_path_to_helper_functions_dict( @@ -221,11 +275,11 @@ def get_file_path_to_helper_functions_dict( file_path_to_helper_function_qualified_names = defaultdict(set) file_path_to_helper_function_fqn = defaultdict(set) function_source_list: list[FunctionSource] = [] - for file_path in file_path_to_qualified_function_names: + for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) file_refs = script.get_names(all_scopes=True, definitions=False, references=True) - for qualified_function_name in file_path_to_qualified_function_names[file_path]: + for qualified_function_name in qualified_function_names: names = [ ref for ref in file_refs @@ -291,6 +345,29 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode return indented_block.with_changes(body=indented_block.body[1:]) return indented_block +def parse_code_and_prune_cst( + code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = {}, remove_docstrings: bool = False +) -> str: + """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables. """ + module = cst.parse_module(code) + if code_context_type == CodeContextType.READ_WRITABLE: + filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions) + elif code_context_type == CodeContextType.READ_ONLY: + filtered_node, found_target = prune_cst_for_read_only_code( + module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + ) + elif code_context_type == CodeContextType.TESTGEN: + filtered_node, found_target = prune_cst_for_testgen_code( + module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings + ) + else: + raise ValueError(f"Unknown code_context_type: {code_context_type}") + + if not found_target: + raise ValueError("No target functions found in the provided code") + if filtered_node and isinstance(filtered_node, cst.Module): + return str(filtered_node.code) + return "" def prune_cst_for_read_writable_code( node: cst.CSTNode, target_functions: set[str], prefix: str = "" @@ -371,20 +448,6 @@ def prune_cst_for_read_writable_code( return (node.with_changes(**updates) if updates else node), True - -def get_read_writable_code(code: str, target_functions: set[str]) -> str: - """Creates a read-writable code string by parsing and filtering the code to keep only - target functions and the minimal surrounding structure. - """ - module = cst.parse_module(code) - filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions) - if not found_target: - raise ValueError("No target functions found in the provided code") - if filtered_node and isinstance(filtered_node, cst.Module): - return str(filtered_node.code) - return "" - - def prune_cst_for_read_only_code( node: cst.CSTNode, target_functions: set[str], @@ -489,18 +552,107 @@ def prune_cst_for_read_only_code( return None, False -def get_read_only_code( - code: str, target_functions: set[str], helpers_of_helper_functions: set[str], remove_docstrings: bool = False -) -> str: - """Creates a read-only version of the code by parsing and filtering the code to keep only - class contextual information, and other module scoped variables. + +def prune_cst_for_testgen_code( + node: cst.CSTNode, + target_functions: set[str], + helpers_of_helper_functions: set[str], + prefix: str = "", + remove_docstrings: bool = False, +) -> tuple[cst.CSTNode | None, bool]: + """Recursively filter the node for testgen context: + + Returns: + (filtered_node, found_target): + filtered_node: The modified CST node or None if it should be removed. + found_target: True if a target function was found in this node's subtree. + """ - module = cst.parse_module(code) - filtered_node, found_target = prune_cst_for_read_only_code( - module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings - ) - if not found_target: - raise ValueError("No target functions found in the provided code") - if filtered_node and isinstance(filtered_node, cst.Module): - return str(filtered_node.code) - return "" + if isinstance(node, (cst.Import, cst.ImportFrom)): + return None, False + + if isinstance(node, cst.FunctionDef): + qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value + # If it's a target function, remove it but mark found_target = True + if qualified_name in helpers_of_helper_functions or qualified_name in target_functions: + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + new_body = remove_docstring_from_body(node.body) + return node.with_changes(body=new_body), True + return node, True + # Keep all dunder methods + if is_dunder_method(node.name.value): + if remove_docstrings and isinstance(node.body, cst.IndentedBlock): + new_body = remove_docstring_from_body(node.body) + return node.with_changes(body=new_body), False + return node, False + return None, False + + if isinstance(node, cst.ClassDef): + # Do not recurse into nested classes + if prefix: + return None, False + # Assuming always an IndentedBlock + if not isinstance(node.body, cst.IndentedBlock): + raise ValueError("ClassDef body is not an IndentedBlock") + + class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value + + # First pass: detect if there is a target function in the class + found_in_class = False + new_class_body: list[CSTNode] = [] + for stmt in node.body.body: + filtered, found_target = prune_cst_for_testgen_code( + stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings + ) + found_in_class |= found_target + if filtered: + new_class_body.append(filtered) + + if not found_in_class: + return None, False + + if remove_docstrings: + return node.with_changes( + body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) + ) if new_class_body else None, True + return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True + + # For other nodes, keep the node and recursively filter children + section_names = get_section_names(node) + if not section_names: + return node, False + + updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} + found_any_target = False + + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + new_children = [] + section_found_target = False + for child in original_content: + filtered, found_target = prune_cst_for_testgen_code( + child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings + ) + if filtered: + new_children.append(filtered) + section_found_target |= found_target + + if section_found_target or new_children: + found_any_target |= section_found_target + updates[section] = new_children + elif original_content is not None: + filtered, found_target = prune_cst_for_testgen_code( + original_content, + target_functions, + helpers_of_helper_functions, + prefix, + remove_docstrings=remove_docstrings, + ) + found_any_target |= found_target + if filtered: + updates[section] = filtered + if updates: + return (node.with_changes(**updates), found_any_target) + + return None, False diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 27f36ca67..08c62129c 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -72,6 +72,7 @@ class CodeStringsMarkdown(BaseModel): @property def markdown(self) -> str: + """Returns the markdown representation of the code, including the file path where possible.""" return "\n".join( [ f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```" @@ -81,12 +82,18 @@ def markdown(self) -> str: class CodeOptimizationContext(BaseModel): - code_to_optimize_with_helpers: str + # code_to_optimize_with_helpers: str + testgen_context_code: str = "" read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" helper_functions: list[FunctionSource] preexisting_objects: list[tuple[str, list[FunctionParent]]] +class CodeContextType(str, Enum): + READ_WRITABLE = "READ_WRITABLE" + READ_ONLY = "READ_ONLY" + TESTGEN = "TESTGEN" + class OptimizedCandidateResult(BaseModel): max_loop_count: int diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index 193327d39..7840660c3 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -60,240 +60,240 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b except ValueError: return False - -def get_type_annotation_context( - function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path -) -> tuple[list[FunctionSource], set[tuple[str, str]]]: - function_name: str = function.function_name - file_path: Path = function.file_path - file_contents: str = file_path.read_text(encoding="utf8") - try: - module: ast.Module = ast.parse(file_contents) - except SyntaxError as e: - logger.exception(f"get_type_annotation_context - Syntax error in code: {e}") - return [], set() - sources: list[FunctionSource] = [] - ast_parents: list[FunctionParent] = [] - contextual_dunder_methods = set() - - def get_annotation_source( - j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str - ) -> None: - try: - definition: list[Name] = j_script.goto( - line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False - ) - except Exception as ex: - if hasattr(name, "full_name"): - logger.exception(f"Error while getting definition for {name.full_name}: {ex}") - else: - logger.exception(f"Error while getting definition: {ex}") - definition = [] - if definition: # TODO can be multiple definitions - definition_path = definition[0].module_path - - # The definition is part of this project and not defined within the original function - if ( - str(definition_path).startswith(str(project_root_path) + os.sep) - and definition[0].full_name - and not path_belongs_to_site_packages(definition_path) - and not belongs_to_function(definition[0], function_name) - ): - source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])]) - if source_code[0]: - sources.append( - FunctionSource( - fully_qualified_name=definition[0].full_name, - jedi_definition=definition[0], - source_code=source_code[0], - file_path=definition_path, - qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."), - only_function_name=definition[0].name, - ) - ) - contextual_dunder_methods.update(source_code[1]) - - def visit_children( - node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent] - ) -> None: - child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module - for child in ast.iter_child_nodes(node): - visit(child, node_parents) - - def visit_all_annotation_children( - node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent] - ) -> None: - if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): - visit_all_annotation_children(node.left, node_parents) - visit_all_annotation_children(node.right, node_parents) - if isinstance(node, ast.Name) and hasattr(node, "id"): - name: str = node.id - line_no: int = node.lineno - col_no: int = node.col_offset - get_annotation_source(jedi_script, name, node_parents, line_no, col_no) - if isinstance(node, ast.Subscript): - if hasattr(node, "slice"): - if isinstance(node.slice, ast.Subscript): - visit_all_annotation_children(node.slice, node_parents) - elif isinstance(node.slice, ast.Tuple): - for elt in node.slice.elts: - if isinstance(elt, (ast.Name, ast.Subscript)): - visit_all_annotation_children(elt, node_parents) - elif isinstance(node.slice, ast.Name): - visit_all_annotation_children(node.slice, node_parents) - if hasattr(node, "value"): - visit_all_annotation_children(node.value, node_parents) - - def visit( - node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, - node_parents: list[FunctionParent], - ) -> None: - if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - if node.name == function_name and node_parents == function.parents: - arg: ast.arg - for arg in node.args.args: - if arg.annotation: - visit_all_annotation_children(arg.annotation, node_parents) - if node.returns: - visit_all_annotation_children(node.returns, node_parents) - - if not isinstance(node, ast.Module): - node_parents.append(FunctionParent(node.name, type(node).__name__)) - visit_children(node, node_parents) - if not isinstance(node, ast.Module): - node_parents.pop() - - visit(module, ast_parents) - - return sources, contextual_dunder_methods - - -def get_function_variables_definitions( - function_to_optimize: FunctionToOptimize, project_root_path: Path -) -> tuple[list[FunctionSource], set[tuple[str, str]]]: - function_name = function_to_optimize.function_name - file_path = function_to_optimize.file_path - script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) - sources: list[FunctionSource] = [] - contextual_dunder_methods = set() - # TODO: The function name condition can be stricter so that it does not clash with other class names etc. - # TODO: The function could have been imported as some other name, - # we should be checking for the translation as well. Also check for the original function name. - names = [] - for ref in script.get_names(all_scopes=True, definitions=False, references=True): - if ref.full_name: - if function_to_optimize.parents: - # Check if the reference belongs to the specified class when FunctionParent is provided - if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name): - names.append(ref) - elif belongs_to_function(ref, function_name): - names.append(ref) - - for name in names: - try: - definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) - except Exception as e: - try: - logger.exception(f"Error while getting definition for {name.full_name}: {e}") - except Exception as e: - # name.full_name can also throw exceptions sometimes - logger.exception(f"Error while getting definition: {e}") - definitions = [] - if definitions: - # TODO: there can be multiple definitions, see how to handle such cases - definition = definitions[0] - definition_path = definition.module_path - - # The definition is part of this project and not defined within the original function - if ( - str(definition_path).startswith(str(project_root_path) + os.sep) - and not path_belongs_to_site_packages(definition_path) - and definition.full_name - and not belongs_to_function(definition, function_name) - ): - module_name = module_name_from_file_path(definition_path, project_root_path) - m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name) - parents = [] - if m: - parents = [FunctionParent(m.group(1), "ClassDef")] - - source_code = get_code( - [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)] - ) - if source_code[0]: - sources.append( - FunctionSource( - fully_qualified_name=definition.full_name, - jedi_definition=definition, - source_code=source_code[0], - file_path=definition_path, - qualified_name=definition.full_name.removeprefix(definition.module_name + "."), - only_function_name=definition.name, - ) - ) - contextual_dunder_methods.update(source_code[1]) - annotation_sources, annotation_dunder_methods = get_type_annotation_context( - function_to_optimize, script, project_root_path - ) - sources[:0] = annotation_sources # prepend the annotation sources - contextual_dunder_methods.update(annotation_dunder_methods) - existing_fully_qualified_names = set() - no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set)) - parent_sources = set() - for source in sources: - if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names: - if not source.qualified_name.count("."): - no_parent_sources[source.file_path][source.qualified_name].add(source) - else: - parent_sources.add(source) - existing_fully_qualified_names.add(fully_qualified_name) - deduped_parent_sources = [ - source - for source in parent_sources - if source.file_path not in no_parent_sources - or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path] - ] - deduped_no_parent_sources = [ - source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2] - ] - return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods - - -MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k - - -def get_constrained_function_context_and_helper_functions( - function_to_optimize: FunctionToOptimize, - project_root_path: Path, - code_to_optimize: str, - max_tokens: int = MAX_PROMPT_TOKENS, -) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]: - helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path) - tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") - code_to_optimize_tokens = tokenizer.encode(code_to_optimize) - - if not function_to_optimize.parents: - helper_functions_sources = [function.source_code for function in helper_functions] - else: - helper_functions_sources = [ - function.source_code - for function in helper_functions - if not function.qualified_name.count(".") - or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name - ] - helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources] - - context_list = [] - context_len = len(code_to_optimize_tokens) - logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}") - logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}") - for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens): - if context_len + source_len <= max_tokens: - context_list.append(function_source) - context_len += source_len - else: - break - logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}") - helper_code: str = "\n".join(context_list) - return helper_code, helper_functions, dunder_methods +# +# def get_type_annotation_context( +# function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path +# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: +# function_name: str = function.function_name +# file_path: Path = function.file_path +# file_contents: str = file_path.read_text(encoding="utf8") +# try: +# module: ast.Module = ast.parse(file_contents) +# except SyntaxError as e: +# logger.exception(f"get_type_annotation_context - Syntax error in code: {e}") +# return [], set() +# sources: list[FunctionSource] = [] +# ast_parents: list[FunctionParent] = [] +# contextual_dunder_methods = set() +# +# def get_annotation_source( +# j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str +# ) -> None: +# try: +# definition: list[Name] = j_script.goto( +# line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False +# ) +# except Exception as ex: +# if hasattr(name, "full_name"): +# logger.exception(f"Error while getting definition for {name.full_name}: {ex}") +# else: +# logger.exception(f"Error while getting definition: {ex}") +# definition = [] +# if definition: # TODO can be multiple definitions +# definition_path = definition[0].module_path +# +# # The definition is part of this project and not defined within the original function +# if ( +# str(definition_path).startswith(str(project_root_path) + os.sep) +# and definition[0].full_name +# and not path_belongs_to_site_packages(definition_path) +# and not belongs_to_function(definition[0], function_name) +# ): +# source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])]) +# if source_code[0]: +# sources.append( +# FunctionSource( +# fully_qualified_name=definition[0].full_name, +# jedi_definition=definition[0], +# source_code=source_code[0], +# file_path=definition_path, +# qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."), +# only_function_name=definition[0].name, +# ) +# ) +# contextual_dunder_methods.update(source_code[1]) +# +# def visit_children( +# node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent] +# ) -> None: +# child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module +# for child in ast.iter_child_nodes(node): +# visit(child, node_parents) +# +# def visit_all_annotation_children( +# node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent] +# ) -> None: +# if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): +# visit_all_annotation_children(node.left, node_parents) +# visit_all_annotation_children(node.right, node_parents) +# if isinstance(node, ast.Name) and hasattr(node, "id"): +# name: str = node.id +# line_no: int = node.lineno +# col_no: int = node.col_offset +# get_annotation_source(jedi_script, name, node_parents, line_no, col_no) +# if isinstance(node, ast.Subscript): +# if hasattr(node, "slice"): +# if isinstance(node.slice, ast.Subscript): +# visit_all_annotation_children(node.slice, node_parents) +# elif isinstance(node.slice, ast.Tuple): +# for elt in node.slice.elts: +# if isinstance(elt, (ast.Name, ast.Subscript)): +# visit_all_annotation_children(elt, node_parents) +# elif isinstance(node.slice, ast.Name): +# visit_all_annotation_children(node.slice, node_parents) +# if hasattr(node, "value"): +# visit_all_annotation_children(node.value, node_parents) +# +# def visit( +# node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, +# node_parents: list[FunctionParent], +# ) -> None: +# if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): +# if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): +# if node.name == function_name and node_parents == function.parents: +# arg: ast.arg +# for arg in node.args.args: +# if arg.annotation: +# visit_all_annotation_children(arg.annotation, node_parents) +# if node.returns: +# visit_all_annotation_children(node.returns, node_parents) +# +# if not isinstance(node, ast.Module): +# node_parents.append(FunctionParent(node.name, type(node).__name__)) +# visit_children(node, node_parents) +# if not isinstance(node, ast.Module): +# node_parents.pop() +# +# visit(module, ast_parents) +# +# return sources, contextual_dunder_methods + + +# def get_function_variables_definitions( +# function_to_optimize: FunctionToOptimize, project_root_path: Path +# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: +# function_name = function_to_optimize.function_name +# file_path = function_to_optimize.file_path +# script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) +# sources: list[FunctionSource] = [] +# contextual_dunder_methods = set() +# # TODO: The function name condition can be stricter so that it does not clash with other class names etc. +# # TODO: The function could have been imported as some other name, +# # we should be checking for the translation as well. Also check for the original function name. +# names = [] +# for ref in script.get_names(all_scopes=True, definitions=False, references=True): +# if ref.full_name: +# if function_to_optimize.parents: +# # Check if the reference belongs to the specified class when FunctionParent is provided +# if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name): +# names.append(ref) +# elif belongs_to_function(ref, function_name): +# names.append(ref) +# +# for name in names: +# try: +# definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) +# except Exception as e: +# try: +# logger.exception(f"Error while getting definition for {name.full_name}: {e}") +# except Exception as e: +# # name.full_name can also throw exceptions sometimes +# logger.exception(f"Error while getting definition: {e}") +# definitions = [] +# if definitions: +# # TODO: there can be multiple definitions, see how to handle such cases +# definition = definitions[0] +# definition_path = definition.module_path +# +# # The definition is part of this project and not defined within the original function +# if ( +# str(definition_path).startswith(str(project_root_path) + os.sep) +# and not path_belongs_to_site_packages(definition_path) +# and definition.full_name +# and not belongs_to_function(definition, function_name) +# ): +# module_name = module_name_from_file_path(definition_path, project_root_path) +# m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name) +# parents = [] +# if m: +# parents = [FunctionParent(m.group(1), "ClassDef")] +# +# source_code = get_code( +# [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)] +# ) +# if source_code[0]: +# sources.append( +# FunctionSource( +# fully_qualified_name=definition.full_name, +# jedi_definition=definition, +# source_code=source_code[0], +# file_path=definition_path, +# qualified_name=definition.full_name.removeprefix(definition.module_name + "."), +# only_function_name=definition.name, +# ) +# ) +# contextual_dunder_methods.update(source_code[1]) +# annotation_sources, annotation_dunder_methods = get_type_annotation_context( +# function_to_optimize, script, project_root_path +# ) +# sources[:0] = annotation_sources # prepend the annotation sources +# contextual_dunder_methods.update(annotation_dunder_methods) +# existing_fully_qualified_names = set() +# no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set)) +# parent_sources = set() +# for source in sources: +# if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names: +# if not source.qualified_name.count("."): +# no_parent_sources[source.file_path][source.qualified_name].add(source) +# else: +# parent_sources.add(source) +# existing_fully_qualified_names.add(fully_qualified_name) +# deduped_parent_sources = [ +# source +# for source in parent_sources +# if source.file_path not in no_parent_sources +# or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path] +# ] +# deduped_no_parent_sources = [ +# source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2] +# ] +# return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods +# +# +# MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k +# +# +# def get_constrained_function_context_and_helper_functions( +# function_to_optimize: FunctionToOptimize, +# project_root_path: Path, +# code_to_optimize: str, +# max_tokens: int = MAX_PROMPT_TOKENS, +# ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]: +# helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path) +# tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") +# code_to_optimize_tokens = tokenizer.encode(code_to_optimize) +# +# if not function_to_optimize.parents: +# helper_functions_sources = [function.source_code for function in helper_functions] +# else: +# helper_functions_sources = [ +# function.source_code +# for function in helper_functions +# if not function.qualified_name.count(".") +# or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name +# ] +# helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources] +# +# context_list = [] +# context_len = len(code_to_optimize_tokens) +# logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}") +# logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}") +# for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens): +# if context_len + source_len <= max_tokens: +# context_list.append(function_source) +# context_len += source_len +# else: +# break +# logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}") +# helper_code: str = "\n".join(context_list) +# return helper_code, helper_functions, dunder_methods diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7b067a094..23cff0c17 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -58,7 +58,7 @@ TestFiles, TestingMode, ) -from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions +# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions from codeflash.result.create_pr import check_create_pr, existing_tests_source_for from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic from codeflash.result.explanation import Explanation @@ -140,14 +140,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: logger.info("Code to be optimized:") code_print(code_context.read_writable_code) - for module_abspath, helper_code_source in original_helper_code.items(): - code_context.code_to_optimize_with_helpers = add_needed_imports_from_module( - helper_code_source, - code_context.code_to_optimize_with_helpers, - module_abspath, - self.function_to_optimize.file_path, - self.args.project_root, - ) + # for module_abspath, helper_code_source in original_helper_code.items(): + # code_context.code_to_optimize_with_helpers = add_needed_imports_from_module( + # helper_code_source, + # code_context.code_to_optimize_with_helpers, + # module_abspath, + # self.function_to_optimize.file_path, + # self.args.project_root, + # ) generated_test_paths = [ get_test_file_path( @@ -167,7 +167,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: transient=True, ): generated_results = self.generate_tests_and_optimizations( - code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers, + testgen_context_code=code_context.testgen_context_code, read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, helper_functions=code_context.helper_functions, @@ -556,49 +556,49 @@ def replace_function_and_helpers_with_optimized_code( return did_update def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: - code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize]) - if code_to_optimize is None: - return Failure("Could not find function to optimize.") - (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions( - self.function_to_optimize, self.project_root, code_to_optimize - ) - if self.function_to_optimize.parents: - function_class = self.function_to_optimize.parents[0].name - same_class_helper_methods = [ - df - for df in helper_functions - if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class - ] - optimizable_methods = [ - FunctionToOptimize( - df.qualified_name.split(".")[-1], - df.file_path, - [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")], - None, - None, - ) - for df in same_class_helper_methods - ] + [self.function_to_optimize] - dedup_optimizable_methods = [] - added_methods = set() - for method in reversed(optimizable_methods): - if f"{method.file_path}.{method.qualified_name}" not in added_methods: - dedup_optimizable_methods.append(method) - added_methods.add(f"{method.file_path}.{method.qualified_name}") - if len(dedup_optimizable_methods) > 1: - code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods))) - if code_to_optimize is None: - return Failure("Could not find function to optimize.") - code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize - - code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module( - self.function_to_optimize_source_code, - code_to_optimize_with_helpers, - self.function_to_optimize.file_path, - self.function_to_optimize.file_path, - self.project_root, - helper_functions, - ) + # code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize]) + # if code_to_optimize is None: + # return Failure("Could not find function to optimize.") + # (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions( + # self.function_to_optimize, self.project_root, code_to_optimize + # ) + # if self.function_to_optimize.parents: + # function_class = self.function_to_optimize.parents[0].name + # same_class_helper_methods = [ + # df + # for df in helper_functions + # if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class + # ] + # optimizable_methods = [ + # FunctionToOptimize( + # df.qualified_name.split(".")[-1], + # df.file_path, + # [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")], + # None, + # None, + # ) + # for df in same_class_helper_methods + # ] + [self.function_to_optimize] + # dedup_optimizable_methods = [] + # added_methods = set() + # for method in reversed(optimizable_methods): + # if f"{method.file_path}.{method.qualified_name}" not in added_methods: + # dedup_optimizable_methods.append(method) + # added_methods.add(f"{method.file_path}.{method.qualified_name}") + # if len(dedup_optimizable_methods) > 1: + # code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods))) + # if code_to_optimize is None: + # return Failure("Could not find function to optimize.") + # code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize + # + # code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module( + # self.function_to_optimize_source_code, + # code_to_optimize_with_helpers, + # self.function_to_optimize.file_path, + # self.function_to_optimize.file_path, + # self.project_root, + # helper_functions, + # ) try: new_code_ctx = code_context_extractor.get_code_optimization_context( @@ -609,7 +609,8 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: return Success( CodeOptimizationContext( - code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports, + # code_to_optimize_with_helpers=new_code_ctx.testgen_context_code, # Outdated, fix this! + testgen_context_code=new_code_ctx.testgen_context_code, read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, helper_functions=new_code_ctx.helper_functions, # only functions that are read writable @@ -711,7 +712,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi def generate_tests_and_optimizations( self, - code_to_optimize_with_helpers: str, + testgen_context_code: str, read_writable_code: str, read_only_context_code: str, helper_functions: list[FunctionSource], @@ -726,7 +727,7 @@ def generate_tests_and_optimizations( # Submit the test generation task as future future_tests = self.generate_and_instrument_tests( executor, - code_to_optimize_with_helpers, + testgen_context_code, [definition.fully_qualified_name for definition in helper_functions], generated_test_paths, generated_perf_test_paths, diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 7f4a94845..434437894 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -740,7 +740,7 @@ def helper_method(self): ending_line=None, ) - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. expected_read_write_context = """ @@ -813,6 +813,57 @@ def helper_method(self): with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) +def test_example_class_token_limit_4() -> None: + string_filler = " ".join( + ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] + ) + code = f""" +class MyClass: + \"\"\"A class with a helper method. \"\"\" + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() +x = '{string_filler}' + +class HelperClass: + \"\"\"A helper class for MyClass.\"\"\" + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def __repr__(self): + \"\"\"Return a string representation of the HelperClass.\"\"\" + return "HelperClass" + str(self.x) + def helper_method(self): + return self.x +""" + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + # In this scenario, the testgen code context is too long, so we abort. + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) def test_repo_helper() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 501a0583b..d8855e87f 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -747,24 +747,28 @@ def main_method(self): def test_code_replacement10() -> None: - get_code_output = """from __future__ import annotations + get_code_output = """```python:test_code_replacement.py +from __future__ import annotations +import os + +os.environ["CODEFLASH_API_KEY"] = "cf-test-key" + class HelperClass: def __init__(self, name): self.name = name - def innocent_bystander(self): - pass - def helper_method(self): return self.name + class MainClass: def __init__(self, name): self.name = name + def main_method(self): return HelperClass(self.name).helper_method() -""" +```""" file_path = Path(__file__).resolve() func_top_optimize = FunctionToOptimize( function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")] @@ -778,7 +782,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output def test_code_replacement11() -> None: diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index faa754af9..390ea37fc 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -5,7 +5,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import is_successful from codeflash.models.models import FunctionParent -from codeflash.optimization.function_context import get_function_variables_definitions +# from codeflash.optimization.function_context import get_function_variables_definitions from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -18,15 +18,6 @@ def simple_function_with_one_dep(data): return calculate_something(data) -def test_simple_dependencies() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("simple_function_with_one_dep", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something" - - def global_dependency_1(num): return num + 1 @@ -93,63 +84,12 @@ def recursive(self, num): return self.recursive(num) + num_1 -def test_multiple_classes_dependencies() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve()) - ) - - assert len(helper_functions) == 2 - assert list(map(lambda x: x.fully_qualified_name, helper_functions[0])) == [ - "test_function_dependencies.global_dependency_3", - "test_function_dependencies.C.calculate_something_3", - ] - - def recursive_dependency_1(num): if num == 0: return 0 num_1 = calculate_something(num) return recursive_dependency_1(num) + num_1 - -def test_recursive_dependency() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("recursive_dependency_1", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something" - assert helper_functions[0].fully_qualified_name == "test_function_dependencies.calculate_something" - - -@dataclass -class MyData: - MyInt: int - - -def calculate_something_ann(data): - return data + 1 - - -def simple_function_with_one_dep_ann(data: MyData): - return calculate_something_ann(data) - - -def list_comprehension_dependency(data: MyData): - return [calculate_something(data) for x in range(10)] - - -def test_simple_dependencies_ann() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 2 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData" - assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann" - - from collections import defaultdict @@ -220,13 +160,15 @@ def test_class_method_dependencies() -> None: ) assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( - code_context.code_to_optimize_with_helpers - == """from collections import defaultdict + code_context.testgen_context_code + == """```python:test_function_dependencies.py +from collections import defaultdict class Graph: def __init__(self, vertices): self.graph = defaultdict(list) self.V = vertices # No. of vertices + def topologicalSortUtil(self, v, visited, stack): visited[v] = True @@ -235,6 +177,7 @@ def topologicalSortUtil(self, v, visited, stack): self.topologicalSortUtil(i, visited, stack) stack.insert(0, v) + def topologicalSort(self): visited = [False] * self.V stack = [] @@ -245,39 +188,9 @@ def topologicalSort(self): # Print contents of stack return stack -""" +```""" ) - -def calculate_something_else(data): - return data + 1 - - -def imalittledecorator(func): - def wrapper(data): - return func(data) - - return wrapper - - -@imalittledecorator -def simple_function_with_decorator_dep(data): - return calculate_something_else(data) - - -@pytest.mark.skip(reason="no decorator dependency support") -def test_decorator_dependencies() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 2 - assert {helper_functions[0][0].definition.full_name, helper_functions[1][0].definition.full_name} == { - "test_function_dependencies.calculate_something", - "test_function_dependencies.imalittledecorator", - } - - def test_recursive_function_context() -> None: file_path = pathlib.Path(__file__).resolve() @@ -309,73 +222,16 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( - code_context.code_to_optimize_with_helpers - == """class C: + code_context.testgen_context_code + == """```python:test_function_dependencies.py +class C: def calculate_something_3(self, num): return num + 1 + def recursive(self, num): if num == 0: return 0 num_1 = self.calculate_something_3(num) return self.recursive(num) + num_1 -""" - ) - - -def test_list_comprehension_dependency() -> None: - file_path = pathlib.Path(__file__).resolve() - helper_functions = get_function_variables_definitions( - FunctionToOptimize("list_comprehension_dependency", str(file_path), []), str(file_path.parent.resolve()) - )[0] - assert len(helper_functions) == 2 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData" - assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something" - - -def test_function_in_method_list_comprehension() -> None: - file_path = pathlib.Path(__file__).resolve() - function_to_optimize = FunctionToOptimize( - function_name="function_in_list_comprehension", - file_path=str(file_path), - parents=[FunctionParent(name="A", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0] - - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.global_dependency_3" - - -def test_method_in_method_list_comprehension() -> None: - file_path = pathlib.Path(__file__).resolve() - function_to_optimize = FunctionToOptimize( - function_name="method_in_list_comprehension", - file_path=str(file_path), - parents=[FunctionParent(name="A", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0] - - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two" - - -def test_nested_method() -> None: - file_path = pathlib.Path(__file__).resolve() - function_to_optimize = FunctionToOptimize( - function_name="nested_function", - file_path=str(file_path), - parents=[FunctionParent(name="A", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - helper_functions = get_function_variables_definitions(function_to_optimize, str(file_path.parent.resolve()))[0] - - # The nested function should be included in the helper functions - assert len(helper_functions) == 1 - assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two" +```""" + ) \ No newline at end of file diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index b7dde84a4..553243a8c 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -239,13 +239,37 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: pytest.fail() code_context = ctx_result.unwrap() assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" - assert ( - code_context.code_to_optimize_with_helpers - == '''_R = TypeVar("_R") - + code_context.testgen_context_code + == f'''```python:{file_path.name} +_P = ParamSpec("_P") +_KEY_T = TypeVar("_KEY_T") +_STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): + """Interface for cache backends used by the persistent cache decorator.""" + def __init__(self) -> None: ... + + def hash_key( + self, + *, + func: Callable[_P, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[str, _KEY_T]: ... + + def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401 + ... + + def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401 + ... + + def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ... + + def delete(self, *, key: tuple[str, _KEY_T]) -> None: ... + + def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ... + def get_cache_or_call( self, *, @@ -300,7 +324,33 @@ def get_cache_or_call( # If encoding fails, we should still return the result. return result +_P = ParamSpec("_P") +_R = TypeVar("_R") +_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend) + + class _PersistentCache(Generic[_P, _R, _CacheBackendT]): + """ + A decorator class that provides persistent caching functionality for a function. + + Args: + ---- + func (Callable[_P, _R]): The function to be decorated. + duration (datetime.timedelta): The duration for which the cached results should be considered valid. + backend (_backend): The backend storage for the cached results. + + Attributes: + ---------- + __wrapped__ (Callable[_P, _R]): The wrapped function. + __duration__ (datetime.timedelta): The duration for which the cached results should be considered valid. + __backend__ (_backend): The backend storage for the cached results. + + """ # noqa: E501 + + __wrapped__: Callable[_P, _R] + __duration__: datetime.timedelta + __backend__: _CacheBackendT + def __init__( self, func: Callable[_P, _R], @@ -310,6 +360,7 @@ def __init__( self.__duration__ = duration self.__backend__ = AbstractCacheBackend() functools.update_wrapper(self, func) + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """ Calls the wrapped function, either using the cache or bypassing it based on environment variables. @@ -333,7 +384,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: kwargs=kwargs, lifespan=self.__duration__, ) -''' +```''' ) @@ -358,14 +409,20 @@ def test_bubble_sort_deps() -> None: pytest.fail() code_context = ctx_result.unwrap() assert ( - code_context.code_to_optimize_with_helpers - == """def dep1_comparer(arr, j: int) -> bool: + code_context.testgen_context_code + == """```python:code_to_optimize/bubble_sort_dep1_helper.py +def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] - +``` +```python:code_to_optimize/bubble_sort_dep2_swap.py def dep2_swap(arr, j): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp +``` +```python:code_to_optimize/bubble_sort_deps.py +from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap def sorter_deps(arr): for i in range(len(arr)): @@ -373,7 +430,7 @@ def sorter_deps(arr): if dep1_comparer(arr, j): dep2_swap(arr, j) return arr -""" +```""" ) assert len(code_context.helper_functions) == 2 assert ( diff --git a/tests/test_get_read_only_code.py b/tests/test_get_read_only_code.py index 0c71d9b6c..f6e975d5f 100644 --- a/tests/test_get_read_only_code.py +++ b/tests/test_get_read_only_code.py @@ -2,7 +2,8 @@ import pytest -from codeflash.context.code_context_extractor import get_read_only_code +from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.models.models import CodeContextType def test_basic_class() -> None: @@ -22,7 +23,7 @@ class TestClass: class_var = "value" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -46,7 +47,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -72,7 +73,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True) assert dedent(expected).strip() == output.strip() @@ -97,7 +98,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True) assert dedent(expected).strip() == output.strip() @@ -124,7 +125,7 @@ def __str__(self): return f"Value: {self.x}" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True) assert dedent(expected).strip() == output.strip() @@ -142,7 +143,7 @@ def target_method(self): """ with pytest.raises(ValueError, match="No target functions found in the provided code"): - get_read_only_code(dedent(code), {"Outer.Inner.target_method"}, set()) + parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"Outer.Inner.target_method"}, set()) def test_docstrings() -> None: @@ -164,7 +165,7 @@ class TestClass: \"\"\"Class docstring.\"\"\" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -183,7 +184,7 @@ def class_method(cls, param: int = 42) -> None: expected = """""" - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -203,7 +204,7 @@ def __init__(self): expected = """ """ - output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()) assert dedent(expected).strip() == output.strip() @@ -223,7 +224,7 @@ class TestClass: var2: str """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -245,7 +246,7 @@ class TestClass: var2: str """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -271,7 +272,7 @@ class TestClass: continue """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -307,7 +308,7 @@ class TestClass: var2: str """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -322,7 +323,7 @@ def some_function(): expected = """""" - output = get_read_only_code(dedent(code), {"target_function"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()) assert dedent(expected).strip() == output.strip() @@ -341,7 +342,7 @@ def some_function(): x = 5 """ - output = get_read_only_code(dedent(code), {"target_function"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()) assert dedent(expected).strip() == output.strip() @@ -352,7 +353,7 @@ def target_function(self) -> None: if y: x = 5 - else: + else: z = 10 def some_function(): print("wow") @@ -364,11 +365,11 @@ def some_function(): expected = """ if y: x = 5 - else: + else: z = 10 """ - output = get_read_only_code(dedent(code), {"target_function"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()) assert dedent(expected).strip() == output.strip() @@ -403,7 +404,7 @@ class PlatformClass: platform = "other" """ - output = get_read_only_code(dedent(code), {"PlatformClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -462,7 +463,7 @@ class TestClass: error_type = "cleanup" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -515,7 +516,7 @@ class TestClass: context = "cleanup" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -564,7 +565,7 @@ class TestClass: status = "cancelled" """ - output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -664,7 +665,7 @@ def __str__(self) -> str: pass """ - output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set()) + output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()) assert dedent(expected).strip() == output.strip() @@ -751,7 +752,7 @@ def __str__(self) -> str: pass """ - output = get_read_only_code( - dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True + output = parse_code_and_prune_cst( + dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True ) assert dedent(expected).strip() == output.strip() diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index 1680f2403..d1eeb6e99 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -1,7 +1,8 @@ from textwrap import dedent import pytest -from codeflash.context.code_context_extractor import get_read_writable_code +from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.models.models import CodeContextType def test_simple_function() -> None: @@ -11,7 +12,7 @@ def target_function(): y = 2 return x + y """ - result = get_read_writable_code(dedent(code), {"target_function"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"}) expected = dedent(""" def target_function(): @@ -30,7 +31,7 @@ def target_function(self): y = 2 return x + y """ - result = get_read_writable_code(dedent(code), {"MyClass.target_function"}) + result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}) expected = dedent(""" class MyClass: @@ -54,7 +55,7 @@ def target_method(self): def other_method(self): print("this should be excluded") """ - result = get_read_writable_code(dedent(code), {"MyClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"}) expected = dedent(""" class MyClass: @@ -78,7 +79,7 @@ class Inner: def not_findable(self): return 42 """ - result = get_read_writable_code(dedent(code), {"Outer.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"Outer.target_method"}) expected = dedent(""" class Outer: @@ -98,7 +99,7 @@ def method1(self): def target_function(): return 42 """ - result = get_read_writable_code(dedent(code), {"target_function"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"}) expected = dedent(""" def target_function(): @@ -121,7 +122,7 @@ class ClassC: def process(self): return "C" """ - result = get_read_writable_code(dedent(code), {"ClassA.process", "ClassC.process"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}) expected = dedent(""" class ClassA: @@ -146,7 +147,7 @@ class ErrorClass: def handle_error(self): print("error") """ - result = get_read_writable_code(dedent(code), {"TargetClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}) expected = dedent(""" try: @@ -173,7 +174,7 @@ def other_method(self): def target_method(self): return f"Value: {self.x}" """ - result = get_read_writable_code(dedent(code), {"MyClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"}) expected = dedent(""" class MyClass: @@ -197,7 +198,7 @@ def other_method(self): def target_method(self): return f"Value: {self.x}" """ - result = get_read_writable_code(dedent(code), {"MyClass.target_method"}) + result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.target_method"}) expected = dedent(""" class MyClass: @@ -218,7 +219,7 @@ def target(self): pass """ with pytest.raises(ValueError, match="No target functions found in the provided code"): - get_read_writable_code(dedent(code), {"MyClass.Inner.target"}) + parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}) def test_module_var() -> None: @@ -242,7 +243,7 @@ def target_function(self) -> None: var2 = "test" """ - output = get_read_writable_code(dedent(code), {"target_function"}) + output = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"target_function"}) assert dedent(expected).strip() == output.strip() diff --git a/tests/test_get_testgen_code.py b/tests/test_get_testgen_code.py new file mode 100644 index 000000000..da399a243 --- /dev/null +++ b/tests/test_get_testgen_code.py @@ -0,0 +1,745 @@ +from textwrap import dedent + +import pytest + +from codeflash.models.models import CodeContextType +from codeflash.context.code_context_extractor import parse_code_and_prune_cst + +def test_simple_function() -> None: + code = """ + def target_function(): + x = 1 + y = 2 + return x + y + """ + result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()) + + expected = """ + def target_function(): + x = 1 + y = 2 + return x + y + """ + assert dedent(expected).strip() == result.strip() + +def test_basic_class() -> None: + code = """ + class TestClass: + class_var = "value" + + def target_method(self): + print("This should be included") + + def other_method(self): + print("This too") + """ + + expected = """ + class TestClass: + class_var = "value" + + def target_method(self): + print("This should be included") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_dunder_methods() -> None: + code = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + expected = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_dunder_methods_remove_docstring() -> None: + code = """ + class TestClass: + def __init__(self): + \"\"\"Constructor for TestClass.\"\"\" + self.x = 42 + + def __str__(self): + \"\"\"String representation of TestClass.\"\"\" + return f"Value: {self.x}" + + def target_method(self): + \"\"\"Target method docstring.\"\"\" + print("include me") + """ + + expected = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True) + assert dedent(expected).strip() == output.strip() + + +def test_class_remove_docstring() -> None: + code = """ + class TestClass: + \"\"\"Class docstring.\"\"\" + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + expected = """ + class TestClass: + def __init__(self): + self.x = 42 + + def __str__(self): + return f"Value: {self.x}" + + def target_method(self): + print("include me") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True) + assert dedent(expected).strip() == output.strip() + + +def test_target_in_nested_class() -> None: + """Test that attempting to find a target in a nested class raises an error.""" + code = """ + class Outer: + outer_var = 1 + + class Inner: + inner_var = 2 + + def target_method(self): + print("include this") + """ + + with pytest.raises(ValueError, match="No target functions found in the provided code"): + parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"Outer.Inner.target_method"}, set()) + +def test_method_signatures() -> None: + code = """ + class TestClass: + @property + def target_method(self) -> str: + \"\"\"Property docstring.\"\"\" + return "value" + + @classmethod + def class_method(cls, param: int = 42) -> None: + print("class method") + """ + + expected = """ + class TestClass: + @property + def target_method(self) -> str: + \"\"\"Property docstring.\"\"\" + return "value" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() +def test_multiple_top_level_targets() -> None: + code = """ + class TestClass: + def target1(self): + print("include 1") + + def target2(self): + print("include 2") + + def __init__(self): + self.x = 42 + + def other_method(self): + print("include other") + """ + + expected = """ + class TestClass: + def target1(self): + print("include 1") + + def target2(self): + print("include 2") + + def __init__(self): + self.x = 42 + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_class_annotations() -> None: + code = """ + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + expected = """ + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_class_annotations_if() -> None: + code = """ + if True: + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + expected = """ + if True: + class TestClass: + var1: int = 42 + var2: str + + def target_method(self) -> None: + self.var2 = "test" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_conditional_class_definitions() -> None: + code = """ + if PLATFORM == "linux": + class PlatformClass: + platform = "linux" + def target_method(self): + print("linux") + elif PLATFORM == "windows": + class PlatformClass: + platform = "windows" + def target_method(self): + print("windows") + else: + class PlatformClass: + platform = "other" + def target_method(self): + print("other") + """ + + expected = """ + if PLATFORM == "linux": + class PlatformClass: + platform = "linux" + def target_method(self): + print("linux") + elif PLATFORM == "windows": + class PlatformClass: + platform = "windows" + def target_method(self): + print("windows") + else: + class PlatformClass: + platform = "other" + def target_method(self): + print("other") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_try_except_structure() -> None: + code = """ + try: + class TargetClass: + attr = "value" + def target_method(self): + return 42 + except ValueError: + class ErrorClass: + def handle_error(self): + print("error") + """ + + expected = """ + try: + class TargetClass: + attr = "value" + def target_method(self): + return 42 + except ValueError: + class ErrorClass: + def handle_error(self): + print("error") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_module_var() -> None: + code = """ + def target_function(self) -> None: + self.var2 = "test" + + x = 5 + + def some_function(): + print("wow") + """ + + expected = """ + def target_function(self) -> None: + self.var2 = "test" + + x = 5 + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_module_var_if() -> None: + code = """ + def target_function(self) -> None: + var2 = "test" + + if y: + x = 5 + else: + z = 10 + def some_function(): + print("wow") + + def some_function(): + print("wow") + """ + + expected = """ + def target_function(self) -> None: + var2 = "test" + + if y: + x = 5 + else: + z = 10 + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_multiple_classes() -> None: + code = """ + class ClassA: + def process(self): + return "A" + + class ClassB: + def process(self): + return "B" + + class ClassC: + def process(self): + return "C" + """ + + expected = """ + class ClassA: + def process(self): + return "A" + + class ClassC: + def process(self): + return "C" + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_with_statement_and_loops() -> None: + code = """ + with context_manager() as ctx: + while attempt_count < max_attempts: + try: + for item in items: + if item.ready: + class TestClass: + context = "ready" + def target_method(self): + print("ready") + else: + class TestClass: + context = "not_ready" + def target_method(self): + print("not ready") + except ConnectionError: + class TestClass: + context = "connection_error" + def target_method(self): + print("connection error") + continue + finally: + class TestClass: + context = "cleanup" + def target_method(self): + print("cleanup") + """ + + expected = """ + with context_manager() as ctx: + while attempt_count < max_attempts: + try: + for item in items: + if item.ready: + class TestClass: + context = "ready" + def target_method(self): + print("ready") + else: + class TestClass: + context = "not_ready" + def target_method(self): + print("not ready") + except ConnectionError: + class TestClass: + context = "connection_error" + def target_method(self): + print("connection error") + continue + finally: + class TestClass: + context = "cleanup" + def target_method(self): + print("cleanup") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_async_with_try_except() -> None: + code = """ + async with async_context() as ctx: + try: + async for item in items: + if await item.is_valid(): + class TestClass: + status = "valid" + async def target_method(self): + await self.process() + elif await item.can_retry(): + continue + else: + break + except AsyncIOError: + class TestClass: + status = "io_error" + async def target_method(self): + await self.handle_error() + except CancelledError: + class TestClass: + status = "cancelled" + async def target_method(self): + await self.cleanup() + """ + + expected = """ + async with async_context() as ctx: + try: + async for item in items: + if await item.is_valid(): + class TestClass: + status = "valid" + async def target_method(self): + await self.process() + elif await item.can_retry(): + continue + else: + break + except AsyncIOError: + class TestClass: + status = "io_error" + async def target_method(self): + await self.handle_error() + except CancelledError: + class TestClass: + status = "cancelled" + async def target_method(self): + await self.cleanup() + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + +def test_simplified_complete_implementation() -> None: + code = """ + class DataProcessor: + \"\"\"A simple data processing class.\"\"\" + + def __init__(self, data: Dict[str, Any]) -> None: + self.data = data + self._processed = False + self.result = None + + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Process and retrieve a specific key from the data.\"\"\" + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + def _process_data(self) -> None: + \"\"\"Internal method to process the data.\"\"\" + processed = {} + for key, value in self.data.items(): + if isinstance(value, (int, float)): + processed[key] = value * 2 + elif isinstance(value, str): + processed[key] = value.upper() + else: + processed[key] = value + self.result = processed + self._processed = True + + def to_json(self) -> str: + \"\"\"Convert the processed data to JSON string.\"\"\" + if not self._processed: + self._process_data() + return json.dumps(self.result) + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + def __init__(self, processor: DataProcessor): + self.processor = processor + self.cache = {} + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Retrieve and cache results for a key.\"\"\" + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + def clear_cache(self) -> None: + \"\"\"Clear the internal cache.\"\"\" + self.cache.clear() + + def get_stats(self) -> Dict[str, int]: + \"\"\"Get cache statistics.\"\"\" + return { + "cache_size": len(self.cache), + "hits": sum(1 for v in self.cache.values() if v is not None) + } + + except Exception as e: + class ResultHandler: + def __init__(self): + self.error = str(e) + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + expected = """ + class DataProcessor: + \"\"\"A simple data processing class.\"\"\" + + def __init__(self, data: Dict[str, Any]) -> None: + self.data = data + self._processed = False + self.result = None + + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Process and retrieve a specific key from the data.\"\"\" + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + def __init__(self, processor: DataProcessor): + self.processor = processor + self.cache = {} + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Retrieve and cache results for a key.\"\"\" + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + except Exception as e: + class ResultHandler: + def __init__(self): + self.error = str(e) + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()) + assert dedent(expected).strip() == output.strip() + + +def test_simplified_complete_implementation_no_docstring() -> None: + code = """ + class DataProcessor: + \"\"\"A simple data processing class.\"\"\" + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Process and retrieve a specific key from the data.\"\"\" + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + def _process_data(self) -> None: + \"\"\"Internal method to process the data.\"\"\" + processed = {} + for key, value in self.data.items(): + if isinstance(value, (int, float)): + processed[key] = value * 2 + elif isinstance(value, str): + processed[key] = value.upper() + else: + processed[key] = value + self.result = processed + self._processed = True + + def to_json(self) -> str: + \"\"\"Convert the processed data to JSON string.\"\"\" + if not self._processed: + self._process_data() + return json.dumps(self.result) + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + \"\"\"Retrieve and cache results for a key.\"\"\" + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + def clear_cache(self) -> None: + \"\"\"Clear the internal cache.\"\"\" + self.cache.clear() + + def get_stats(self) -> Dict[str, int]: + \"\"\"Get cache statistics.\"\"\" + return { + "cache_size": len(self.cache), + "hits": sum(1 for v in self.cache.values() if v is not None) + } + + except Exception as e: + class ResultHandler: + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + expected = """ + class DataProcessor: + def __repr__(self) -> str: + return f"DataProcessor(processed={self._processed})" + + def target_method(self, key: str) -> Optional[Any]: + if not self._processed: + self._process_data() + return self.result.get(key) if self.result else None + + try: + sample_data = {"number": 42, "text": "hello"} + processor = DataProcessor(sample_data) + + class ResultHandler: + + def __str__(self) -> str: + return f"ResultHandler(cache_size={len(self.cache)})" + + def target_method(self, key: str) -> Optional[Any]: + if key not in self.cache: + self.cache[key] = self.processor.target_method(key) + return self.cache[key] + + except Exception as e: + class ResultHandler: + + def target_method(self, key: str) -> None: + raise RuntimeError(f"Failed to initialize: {self.error}") + """ + + output = parse_code_and_prune_cst( + dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True + ) + assert dedent(expected).strip() == output.strip() diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index bf7373522..ca23b1d23 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -2674,13 +2674,13 @@ def test_code_replacement10() -> None: project_root=str(file_path.parent), original_source_code=original_code, ).unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output code_context = opt.get_code_optimization_context( function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code, ) - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output """ expected = """import gc @@ -2739,9 +2739,9 @@ def test_code_replacement10() -> None: with open(file_path) as f: original_code = f.read() code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code).unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) - assert code_context.code_to_optimize_with_helpers == get_code_output + assert code_context.testgen_context_code == get_code_output codeflash_con.close() """ diff --git a/tests/test_type_annotation_context.py b/tests/test_type_annotation_context.py index b10a8ed42..297e41a48 100644 --- a/tests/test_type_annotation_context.py +++ b/tests/test_type_annotation_context.py @@ -1,103 +1,103 @@ -from __future__ import annotations - -import pathlib -from dataclasses import dataclass, field -from typing import List - -from codeflash.code_utils.code_extractor import get_code -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions - - -class CustomType: - def __init__(self) -> None: - self.name = None - self.data: List[int] = [] - - -@dataclass -class CustomDataClass: - name: str = "" - data: List[int] = field(default_factory=list) - - -def function_to_optimize(data: CustomType) -> CustomType: - name = data.name - data.data.sort() - return data - - -def function_to_optimize2(data: CustomDataClass) -> CustomType: - name = data.name - data.data.sort() - return data - - -def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None: - name = data.name - data.data.sort() - return data - - -def test_function_context_includes_type_annotation() -> None: - file_path = pathlib.Path(__file__).resolve() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("function_to_optimize", str(file_path), []), - str(file_path.parent.resolve()), - """def function_to_optimize(data: CustomType): - name = data.name - data.data.sort() - return data""", - 1000, - ) - - assert len(helper_functions) == 1 - assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType" - - -def test_function_context_includes_type_annotation_dataclass() -> None: - file_path = pathlib.Path(__file__).resolve() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("function_to_optimize2", str(file_path), []), - str(file_path.parent.resolve()), - """def function_to_optimize2(data: CustomDataClass) -> CustomType: - name = data.name - data.data.sort() - return data""", - 1000, - ) - - assert len(helper_functions) == 2 - assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" - assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" - - -def test_function_context_works_for_composite_types() -> None: - file_path = pathlib.Path(__file__).resolve() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("function_to_optimize3", str(file_path), []), - str(file_path.parent.resolve()), - """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]: - name = data.name - data.data.sort() - return data""", - 1000, - ) - - assert len(helper_functions) == 2 - assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" - assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" - - -def test_function_context_custom_datatype() -> None: - project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" - file_path = project_path / "math_utils.py" - code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])]) - assert code is not None - assert contextual_dunder_methods == set() - a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( - FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000 - ) - - assert len(helper_functions) == 1 - assert helper_functions[0].fully_qualified_name == "math_utils.Matrix" +# from __future__ import annotations +# +# import pathlib +# from dataclasses import dataclass, field +# from typing import List +# +# from codeflash.code_utils.code_extractor import get_code +# from codeflash.discovery.functions_to_optimize import FunctionToOptimize +# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions +# +# +# class CustomType: +# def __init__(self) -> None: +# self.name = None +# self.data: List[int] = [] +# +# +# @dataclass +# class CustomDataClass: +# name: str = "" +# data: List[int] = field(default_factory=list) +# +# +# def function_to_optimize(data: CustomType) -> CustomType: +# name = data.name +# data.data.sort() +# return data +# +# +# def function_to_optimize2(data: CustomDataClass) -> CustomType: +# name = data.name +# data.data.sort() +# return data +# +# +# def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None: +# name = data.name +# data.data.sort() +# return data +# +# +# def test_function_context_includes_type_annotation() -> None: +# file_path = pathlib.Path(__file__).resolve() +# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( +# FunctionToOptimize("function_to_optimize", str(file_path), []), +# str(file_path.parent.resolve()), +# """def function_to_optimize(data: CustomType): +# name = data.name +# data.data.sort() +# return data""", +# 1000, +# ) +# +# assert len(helper_functions) == 1 +# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType" +# +# +# def test_function_context_includes_type_annotation_dataclass() -> None: +# file_path = pathlib.Path(__file__).resolve() +# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( +# FunctionToOptimize("function_to_optimize2", str(file_path), []), +# str(file_path.parent.resolve()), +# """def function_to_optimize2(data: CustomDataClass) -> CustomType: +# name = data.name +# data.data.sort() +# return data""", +# 1000, +# ) +# +# assert len(helper_functions) == 2 +# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" +# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" +# +# +# def test_function_context_works_for_composite_types() -> None: +# file_path = pathlib.Path(__file__).resolve() +# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( +# FunctionToOptimize("function_to_optimize3", str(file_path), []), +# str(file_path.parent.resolve()), +# """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]: +# name = data.name +# data.data.sort() +# return data""", +# 1000, +# ) +# +# assert len(helper_functions) == 2 +# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" +# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" +# +# +# def test_function_context_custom_datatype() -> None: +# project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" +# file_path = project_path / "math_utils.py" +# code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])]) +# assert code is not None +# assert contextual_dunder_methods == set() +# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( +# FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000 +# ) +# +# assert len(helper_functions) == 1 +# assert helper_functions[0].fully_qualified_name == "math_utils.Matrix" From 8bf9f5c14bf00b1e32b2ffb401546b932d7ff28a Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 6 Mar 2025 14:02:50 -0800 Subject: [PATCH 2/7] slight fixes --- codeflash/optimization/function_optimizer.py | 2 +- tests/test_get_helper_code.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 310d24aae..26bae6751 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -134,7 +134,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: with helper_function_path.open(encoding="utf8") as f: helper_code = f.read() original_helper_code[helper_function_path] = helper_code - if has_any_async_functions(code_context.code_to_optimize_with_helpers): + if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") code_print(code_context.read_writable_code) diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index 553243a8c..a0789218d 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -217,6 +217,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: f.write(code) f.flush() file_path = Path(f.name).resolve() + project_root_path = file_path.parent.resolve() function_to_optimize = FunctionToOptimize( function_name="__call__", file_path=file_path, @@ -227,7 +228,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: test_config = TestConfig( tests_root="tests", tests_project_rootdir=Path.cwd(), - project_root_path=file_path.parent.resolve(), + project_root_path=project_root_path, test_framework="pytest", pytest_cmd="pytest", ) @@ -241,7 +242,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert ( code_context.testgen_context_code - == f'''```python:{file_path.name} + == f'''```python:{file_path.relative_to(project_root_path)} _P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") _STORE_T = TypeVar("_STORE_T") From 37769f95919dd0daa5df4d1b42ca335ce358216d Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Mon, 10 Mar 2025 14:53:19 -0700 Subject: [PATCH 3/7] fixes --- codeflash/optimization/function_optimizer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index f7188b53a..438845117 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -138,15 +138,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure("Codeflash does not support async functions in the code to optimize.") code_print(code_context.read_writable_code) - # for module_abspath, helper_code_source in original_helper_code.items(): - # code_context.code_to_optimize_with_helpers = add_needed_imports_from_module( - # helper_code_source, - # code_context.code_to_optimize_with_helpers, - # module_abspath, - # self.function_to_optimize.file_path, - # self.args.project_root, - # ) - generated_test_paths = [ get_test_file_path( self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" From 302e9199417d1c9a872666959f9ad7ebab4f3519 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Mon, 10 Mar 2025 16:15:34 -0700 Subject: [PATCH 4/7] Modified testgen context to be a codestring instead of markdown --- codeflash/context/code_context_extractor.py | 115 ++++++++++++++------ tests/test_code_replacement.py | 5 +- tests/test_function_dependencies.py | 12 +- tests/test_get_helper_code.py | 21 ++-- 4 files changed, 100 insertions(+), 53 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index f2ebe5655..7c40a1650 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -39,7 +39,7 @@ def get_code_optimization_context( ) # Extract code context for optimization - final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code + final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, {}, {}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code read_only_code_markdown = extract_code_markdown_context_from_files( helpers_of_fto, helpers_of_fto_fqn, @@ -85,7 +85,7 @@ def get_code_optimization_context( logger.debug("Code context has exceeded token limit, removing read-only code") read_only_context_code = "" # Extract code context for testgen - testgen_code_markdown = extract_code_markdown_context_from_files( + testgen_code_markdown = extract_code_string_context_from_files( helpers_of_fto, helpers_of_fto_fqn, helpers_of_helpers, @@ -94,10 +94,10 @@ def get_code_optimization_context( remove_docstrings=False, code_context_type=CodeContextType.TESTGEN, ) - testgen_context_code = testgen_code_markdown.markdown + testgen_context_code = testgen_code_markdown.code testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) if testgen_context_code_tokens > testgen_token_limit: - testgen_code_markdown = extract_code_markdown_context_from_files( + testgen_code_markdown = extract_code_string_context_from_files( helpers_of_fto, helpers_of_fto_fqn, helpers_of_helpers, @@ -106,41 +106,64 @@ def get_code_optimization_context( remove_docstrings=True, code_context_type=CodeContextType.TESTGEN, ) - testgen_context_code = testgen_code_markdown.markdown + testgen_context_code = testgen_code_markdown.code testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) if testgen_context_code_tokens > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") return CodeOptimizationContext( testgen_context_code = testgen_context_code, - read_writable_code=CodeString(code=final_read_writable_code).code, + read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, helper_functions=helpers_of_fto_obj_list, preexisting_objects=preexisting_objects, ) - def extract_code_string_context_from_files( - helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path + helpers_of_fto: dict[Path, set[str]], + helpers_of_fto_fqn: dict[Path, set[str]], + helpers_of_helpers: dict[Path, set[str]], + helpers_of_helpers_fqn: dict[Path, set[str]], + project_root_path: Path, + remove_docstrings: bool = False, + code_context_type: CodeContextType = CodeContextType.READ_ONLY, ) -> CodeString: - """Extract read-writable code context from files containing target functions and their helpers. + """Extract code context from files containing target functions and their helpers, formatting them as markdown. - This function iterates through each file path that contains functions to optimize (fto) or - their first-degree helpers, reads the original code, extracts relevant parts using CST parsing, - and adds necessary imports from the original modules. + This function processes two sets of files: + 1. Files containing the function to optimize (fto) and their first-degree helpers + 2. Files containing only helpers of helpers (with no overlap with the first set) + + For each file, it extracts relevant code based on the specified context type, adds necessary + imports, and combines them Args: - helpers_of_fto: Dictionary mapping file paths to sets of qualified function names - helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions - project_root_path: Root path of the project for resolving relative imports + helpers_of_fto: Dictionary mapping file paths to sets of function names to be optimized + helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions to be optimized + helpers_of_helpers: Dictionary mapping file paths to sets of helper function names + helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helper functions + project_root_path: Root path of the project + remove_docstrings: Whether to remove docstrings from the extracted code + code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) Returns: - CodeString object containing the consolidated read-writable code with all necessary - imports for the target functions and their helpers + CodeString containing the extracted code context with necessary imports """ - final_read_writable_code = "" - # Extract code from file paths that contain fto and first degree helpers + # Rearrange to remove overlaps, so we only access each file path once + helpers_of_helpers_no_overlap = defaultdict(set) + helpers_of_helpers_no_overlap_fqn = defaultdict(set) + for file_path in helpers_of_helpers: + if file_path in helpers_of_fto: + # Remove duplicates, in case a helper of helper is also a helper of fto + helpers_of_helpers[file_path] -= helpers_of_fto[file_path] + helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path] + else: + helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path] + helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path] + + final_code_string_context = "" + # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files for file_path, qualified_function_names in helpers_of_fto.items(): try: original_code = file_path.read_text("utf8") @@ -148,22 +171,52 @@ def extract_code_string_context_from_files( logger.exception(f"Error while parsing {file_path}: {e}") continue try: - read_writable_code = parse_code_and_prune_cst(original_code, CodeContextType.READ_WRITABLE, qualified_function_names) + code_context = parse_code_and_prune_cst( + original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings + ) + except ValueError as e: - logger.debug(f"Error while getting read-writable code: {e}") + logger.debug(f"Error while getting read-only code: {e}") + continue + if code_context.strip(): + final_code_string_context += f"\n{code_context}" + final_code_string_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=final_code_string_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions_fqn=helpers_of_fto_fqn.get(file_path, set()) | helpers_of_helpers_fqn.get(file_path, set()), + ) + if code_context_type == CodeContextType.READ_WRITABLE: + return CodeString(code=final_code_string_context) + # Extract code from file paths containing helpers of helpers + for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items(): + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + continue + try: + code_context = parse_code_and_prune_cst( + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + ) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") continue - if read_writable_code: - final_read_writable_code += f"\n{read_writable_code}" - final_read_writable_code = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_read_writable_code, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions_fqn=helpers_of_fto_fqn[file_path], + if code_context.strip(): + final_code_string_context += f"\n{code_context}" + final_code_string_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=final_code_string_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions_fqn=helpers_of_helpers_no_overlap_fqn.get(file_path, set()), ) - return CodeString(code=final_read_writable_code) + return CodeString(code=final_code_string_context) + def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index d8855e87f..59bdbcc23 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -747,8 +747,7 @@ def main_method(self): def test_code_replacement10() -> None: - get_code_output = """```python:test_code_replacement.py -from __future__ import annotations + get_code_output = """from __future__ import annotations import os os.environ["CODEFLASH_API_KEY"] = "cf-test-key" @@ -768,7 +767,7 @@ def __init__(self, name): def main_method(self): return HelperClass(self.name).helper_method() -```""" +""" file_path = Path(__file__).resolve() func_top_optimize = FunctionToOptimize( function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")] diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index d921ebb0b..fa4a2ab28 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -162,8 +162,7 @@ def test_class_method_dependencies() -> None: assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( code_context.testgen_context_code - == """```python:test_function_dependencies.py -from collections import defaultdict + == """from collections import defaultdict class Graph: def __init__(self, vertices): @@ -188,8 +187,7 @@ def topologicalSort(self): self.topologicalSortUtil(i, visited, stack) # Print contents of stack - return stack -```""" + return stack""" ) def test_recursive_function_context() -> None: @@ -224,8 +222,7 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( code_context.testgen_context_code - == """```python:test_function_dependencies.py -class C: + == """class C: def calculate_something_3(self, num): return num + 1 @@ -233,6 +230,5 @@ def recursive(self, num): if num == 0: return 0 num_1 = self.calculate_something_3(num) - return self.recursive(num) + num_1 -```""" + return self.recursive(num) + num_1""" ) \ No newline at end of file diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index a0789218d..36359d3e3 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -242,8 +242,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert ( code_context.testgen_context_code - == f'''```python:{file_path.relative_to(project_root_path)} -_P = ParamSpec("_P") + == f'''_P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") _STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -385,7 +384,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: kwargs=kwargs, lifespan=self.__duration__, ) -```''' +''' ) @@ -411,19 +410,18 @@ def test_bubble_sort_deps() -> None: code_context = ctx_result.unwrap() assert ( code_context.testgen_context_code - == """```python:code_to_optimize/bubble_sort_dep1_helper.py + == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap + def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] -``` -```python:code_to_optimize/bubble_sort_dep2_swap.py + def dep2_swap(arr, j): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp -``` -```python:code_to_optimize/bubble_sort_deps.py -from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer -from code_to_optimize.bubble_sort_dep2_swap import dep2_swap + + def sorter_deps(arr): for i in range(len(arr)): @@ -431,7 +429,8 @@ def sorter_deps(arr): if dep1_comparer(arr, j): dep2_swap(arr, j) return arr -```""" + +""" ) assert len(code_context.helper_functions) == 2 assert ( From c3b92433a3b19f059ee29382db35b5b1759a4a6f Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Mon, 10 Mar 2025 16:19:24 -0700 Subject: [PATCH 5/7] mypy fix --- codeflash/context/code_context_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 7c40a1650..226869030 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -399,7 +399,7 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode return indented_block def parse_code_and_prune_cst( - code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = {}, remove_docstrings: bool = False + code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), remove_docstrings: bool = False ) -> str: """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables. """ module = cst.parse_module(code) From d0e477c9277438c0af3f221b28e932a5f375e9da Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Tue, 11 Mar 2025 12:03:17 -0700 Subject: [PATCH 6/7] cleaned up code context extractor, use FunctionSources now --- codeflash/context/code_context_extractor.py | 194 ++++++++++--------- codeflash/models/models.py | 13 +- codeflash/optimization/function_optimizer.py | 46 ----- 3 files changed, 115 insertions(+), 138 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 226869030..5d5487b98 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -15,36 +15,38 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource, \ - CodeContextType +from codeflash.models.models import ( + CodeContextType, + CodeOptimizationContext, + CodeString, + CodeStringsMarkdown, + FunctionSource, +) from codeflash.optimization.function_context import belongs_to_function_qualified def get_code_optimization_context( function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000 ) -> CodeOptimizationContext: - # Get qualified names and fully qualified names(fqn) of helpers - helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict( - {function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path - ) + # Get FunctionSource representation of helpers of FTO + helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path) + helpers_of_fto_qualified_names_dict = { + file_path: {source.qualified_name for source in sources} + for file_path, sources in helpers_of_fto_dict.items() + } - helpers_of_helpers, helpers_of_helpers_fqn, _ = get_file_path_to_helper_functions_dict( - helpers_of_fto, project_root_path - ) + # Get FunctionSource representation of helpers of helpers of FTO + helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path) - # Add function to optimize - helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name) - helpers_of_fto_fqn[function_to_optimize.file_path].add( - function_to_optimize.qualified_name_with_modules_from_root(project_root_path) - ) + # Add function to optimize into helpers of FTO dict, as they'll be processed together + fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path) + helpers_of_fto_dict[function_to_optimize.file_path].add(fto_as_function_source) # Extract code context for optimization - final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto, helpers_of_fto_fqn, {}, {}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code + final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code read_only_code_markdown = extract_code_markdown_context_from_files( - helpers_of_fto, - helpers_of_fto_fqn, - helpers_of_helpers, - helpers_of_helpers_fqn, + helpers_of_fto_dict, + helpers_of_helpers_dict, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_ONLY, @@ -71,10 +73,8 @@ def get_code_optimization_context( logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") # Extract read only code without docstrings read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files( - helpers_of_fto, - helpers_of_fto_fqn, - helpers_of_helpers, - helpers_of_helpers_fqn, + helpers_of_fto_dict, + helpers_of_helpers_dict, project_root_path, remove_docstrings=True, ) @@ -84,12 +84,11 @@ def get_code_optimization_context( if total_tokens > optim_token_limit: logger.debug("Code context has exceeded token limit, removing read-only code") read_only_context_code = "" + # Extract code context for testgen testgen_code_markdown = extract_code_string_context_from_files( - helpers_of_fto, - helpers_of_fto_fqn, - helpers_of_helpers, - helpers_of_helpers_fqn, + helpers_of_fto_dict, + helpers_of_helpers_dict, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.TESTGEN, @@ -98,10 +97,8 @@ def get_code_optimization_context( testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code)) if testgen_context_code_tokens > testgen_token_limit: testgen_code_markdown = extract_code_string_context_from_files( - helpers_of_fto, - helpers_of_fto_fqn, - helpers_of_helpers, - helpers_of_helpers_fqn, + helpers_of_fto_dict, + helpers_of_helpers_dict, project_root_path, remove_docstrings=True, code_context_type=CodeContextType.TESTGEN, @@ -115,33 +112,28 @@ def get_code_optimization_context( testgen_context_code = testgen_context_code, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, - helper_functions=helpers_of_fto_obj_list, + helper_functions=helpers_of_fto_list, preexisting_objects=preexisting_objects, ) def extract_code_string_context_from_files( - helpers_of_fto: dict[Path, set[str]], - helpers_of_fto_fqn: dict[Path, set[str]], - helpers_of_helpers: dict[Path, set[str]], - helpers_of_helpers_fqn: dict[Path, set[str]], + helpers_of_fto: dict[Path, set[FunctionSource]], + helpers_of_helpers: dict[Path, set[FunctionSource]], project_root_path: Path, remove_docstrings: bool = False, code_context_type: CodeContextType = CodeContextType.READ_ONLY, ) -> CodeString: - """Extract code context from files containing target functions and their helpers, formatting them as markdown. - + """Extract code context from files containing target functions and their helpers. This function processes two sets of files: 1. Files containing the function to optimize (fto) and their first-degree helpers 2. Files containing only helpers of helpers (with no overlap with the first set) For each file, it extracts relevant code based on the specified context type, adds necessary - imports, and combines them + imports, and combines them. Args: - helpers_of_fto: Dictionary mapping file paths to sets of function names to be optimized - helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions to be optimized - helpers_of_helpers: Dictionary mapping file paths to sets of helper function names - helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helper functions + helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers + helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions project_root_path: Root path of the project remove_docstrings: Whether to remove docstrings from the extracted code code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) @@ -152,27 +144,27 @@ def extract_code_string_context_from_files( """ # Rearrange to remove overlaps, so we only access each file path once helpers_of_helpers_no_overlap = defaultdict(set) - helpers_of_helpers_no_overlap_fqn = defaultdict(set) for file_path in helpers_of_helpers: if file_path in helpers_of_fto: - # Remove duplicates, in case a helper of helper is also a helper of fto + # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto helpers_of_helpers[file_path] -= helpers_of_fto[file_path] - helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path] else: helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path] - helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path] final_code_string_context = "" + # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files - for file_path, qualified_function_names in helpers_of_fto.items(): + for file_path, function_sources in helpers_of_fto.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: + qualified_function_names = {func.qualified_name for func in function_sources} + helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} code_context = parse_code_and_prune_cst( - original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings + original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings ) except ValueError as e: @@ -186,18 +178,19 @@ def extract_code_string_context_from_files( src_path=file_path, dst_path=file_path, project_root=project_root_path, - helper_functions_fqn=helpers_of_fto_fqn.get(file_path, set()) | helpers_of_helpers_fqn.get(file_path, set()), + helper_functions= list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) ) if code_context_type == CodeContextType.READ_WRITABLE: return CodeString(code=final_code_string_context) # Extract code from file paths containing helpers of helpers - for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items(): + for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: + qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} code_context = parse_code_and_prune_cst( original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) @@ -213,15 +206,13 @@ def extract_code_string_context_from_files( src_path=file_path, dst_path=file_path, project_root=project_root_path, - helper_functions_fqn=helpers_of_helpers_no_overlap_fqn.get(file_path, set()), + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), ) return CodeString(code=final_code_string_context) def extract_code_markdown_context_from_files( - helpers_of_fto: dict[Path, set[str]], - helpers_of_fto_fqn: dict[Path, set[str]], - helpers_of_helpers: dict[Path, set[str]], - helpers_of_helpers_fqn: dict[Path, set[str]], + helpers_of_fto: dict[Path, set[FunctionSource]], + helpers_of_helpers: dict[Path, set[FunctionSource]], project_root_path: Path, remove_docstrings: bool = False, code_context_type: CodeContextType = CodeContextType.READ_ONLY, @@ -236,10 +227,8 @@ def extract_code_markdown_context_from_files( imports, and combines them into a structured markdown format. Args: - helpers_of_fto: Dictionary mapping file paths to sets of function names to be optimized - helpers_of_fto_fqn: Dictionary mapping file paths to sets of fully qualified names of functions to be optimized - helpers_of_helpers: Dictionary mapping file paths to sets of helper function names - helpers_of_helpers_fqn: Dictionary mapping file paths to sets of fully qualified names of helper functions + helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers + helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions project_root_path: Root path of the project remove_docstrings: Whether to remove docstrings from the extracted code code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) @@ -251,27 +240,25 @@ def extract_code_markdown_context_from_files( """ # Rearrange to remove overlaps, so we only access each file path once helpers_of_helpers_no_overlap = defaultdict(set) - helpers_of_helpers_no_overlap_fqn = defaultdict(set) for file_path in helpers_of_helpers: if file_path in helpers_of_fto: - # Remove duplicates, in case a helper of helper is also a helper of fto + # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto helpers_of_helpers[file_path] -= helpers_of_fto[file_path] - helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path] else: helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path] - helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path] - code_context_markdown = CodeStringsMarkdown() # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files - for file_path, qualified_function_names in helpers_of_fto.items(): + for file_path, function_sources in helpers_of_fto.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: + qualified_function_names = {func.qualified_name for func in function_sources} + helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} code_context = parse_code_and_prune_cst( - original_code, code_context_type, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings + original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings ) except ValueError as e: @@ -285,22 +272,23 @@ def extract_code_markdown_context_from_files( src_path=file_path, dst_path=file_path, project_root=project_root_path, - helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path], + helper_functions=list( + helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) ), file_path=file_path.relative_to(project_root_path), ) code_context_markdown.code_strings.append(code_context_with_imports) - # Extract code from file paths containing helpers of helpers - for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items(): + for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): try: original_code = file_path.read_text("utf8") except Exception as e: logger.exception(f"Error while parsing {file_path}: {e}") continue try: + qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings, ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -314,7 +302,7 @@ def extract_code_markdown_context_from_files( src_path=file_path, dst_path=file_path, project_root=project_root_path, - helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path], + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), ), file_path=file_path.relative_to(project_root_path), ) @@ -322,11 +310,39 @@ def extract_code_markdown_context_from_files( return code_context_markdown -def get_file_path_to_helper_functions_dict( +def get_function_to_optimize_as_function_source(function_to_optimize: FunctionToOptimize, + project_root_path: Path) -> FunctionSource: + # Use jedi to find function to optimize + script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path)) + + # Get all names in the file + names = script.get_names(all_scopes=True, definitions=True, references=False) + + # Find the name that matches our function + for name in names: + if (name.type == "function" and + name.full_name and + name.name == function_to_optimize.function_name and + get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name): + + function_source = FunctionSource( + file_path=function_to_optimize.file_path, + qualified_name=function_to_optimize.qualified_name, + fully_qualified_name=name.full_name, + only_function_name=name.name, + source_code=name.get_line_code(), + jedi_definition=name, + ) + return function_source + + raise ValueError( + f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}") + + +def get_function_sources_from_jedi( file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path -) -> tuple[dict[Path, set[str]], dict[Path, set[str]], list[FunctionSource]]: - file_path_to_helper_function_qualified_names = defaultdict(set) - file_path_to_helper_function_fqn = defaultdict(set) +) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + file_path_to_function_source = defaultdict(set) function_source_list: list[FunctionSource] = [] for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) @@ -361,22 +377,18 @@ def get_file_path_to_helper_functions_dict( and definition.type == "function" and not belongs_to_function_qualified(definition, qualified_function_name) ): - file_path_to_helper_function_qualified_names[definition_path].add( - get_qualified_name(definition.module_name, definition.full_name) - ) - file_path_to_helper_function_fqn[definition_path].add(definition.full_name) - function_source_list.append( - FunctionSource( - file_path=definition_path, - qualified_name=get_qualified_name(definition.module_name, definition.full_name), - fully_qualified_name=definition.full_name, - only_function_name=definition.name, - source_code=definition.get_line_code(), - jedi_definition=definition, - ) + function_source = FunctionSource( + file_path=definition_path, + qualified_name=get_qualified_name(definition.module_name, definition.full_name), + fully_qualified_name=definition.full_name, + only_function_name=definition.name, + source_code=definition.get_line_code(), + jedi_definition=definition, ) + file_path_to_function_source[definition_path].add(function_source) + function_source_list.append(function_source) - return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn, function_source_list + return file_path_to_function_source, function_source_list def is_dunder_method(name: str) -> bool: @@ -401,7 +413,7 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode def parse_code_and_prune_cst( code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), remove_docstrings: bool = False ) -> str: - """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables. """ + """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables.""" module = cst.parse_module(code) if code_context_type == CodeContextType.READ_WRITABLE: filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 699cbbb1d..9cf51cf31 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -54,6 +54,18 @@ class FunctionSource: source_code: str jedi_definition: Name + def __eq__(self, other: object) -> bool: + if not isinstance(other, FunctionSource): + return False + return (self.file_path == other.file_path and + self.qualified_name == other.qualified_name and + self.fully_qualified_name == other.fully_qualified_name and + self.only_function_name == other.only_function_name and + self.source_code == other.source_code) + + def __hash__(self) -> int: + return hash((self.file_path, self.qualified_name, self.fully_qualified_name, + self.only_function_name, self.source_code)) class BestOptimization(BaseModel): candidate: OptimizedCandidate @@ -83,7 +95,6 @@ def markdown(self) -> str: class CodeOptimizationContext(BaseModel): - # code_to_optimize_with_helpers: str testgen_context_code: str = "" read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 438845117..e8cf39d4a 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -21,7 +21,6 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( cleanup_paths, @@ -545,50 +544,6 @@ def replace_function_and_helpers_with_optimized_code( return did_update def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: - # code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize]) - # if code_to_optimize is None: - # return Failure("Could not find function to optimize.") - # (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions( - # self.function_to_optimize, self.project_root, code_to_optimize - # ) - # if self.function_to_optimize.parents: - # function_class = self.function_to_optimize.parents[0].name - # same_class_helper_methods = [ - # df - # for df in helper_functions - # if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class - # ] - # optimizable_methods = [ - # FunctionToOptimize( - # df.qualified_name.split(".")[-1], - # df.file_path, - # [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")], - # None, - # None, - # ) - # for df in same_class_helper_methods - # ] + [self.function_to_optimize] - # dedup_optimizable_methods = [] - # added_methods = set() - # for method in reversed(optimizable_methods): - # if f"{method.file_path}.{method.qualified_name}" not in added_methods: - # dedup_optimizable_methods.append(method) - # added_methods.add(f"{method.file_path}.{method.qualified_name}") - # if len(dedup_optimizable_methods) > 1: - # code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods))) - # if code_to_optimize is None: - # return Failure("Could not find function to optimize.") - # code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize - # - # code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module( - # self.function_to_optimize_source_code, - # code_to_optimize_with_helpers, - # self.function_to_optimize.file_path, - # self.function_to_optimize.file_path, - # self.project_root, - # helper_functions, - # ) - try: new_code_ctx = code_context_extractor.get_code_optimization_context( self.function_to_optimize, self.project_root @@ -598,7 +553,6 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: return Success( CodeOptimizationContext( - # code_to_optimize_with_helpers=new_code_ctx.testgen_context_code, # Outdated, fix this! testgen_context_code=new_code_ctx.testgen_context_code, read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, From e9d6bd12a87e2c8b8006f097f8f7711878ed6c3b Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Tue, 11 Mar 2025 16:15:07 -0700 Subject: [PATCH 7/7] cleaned up comments and removed unused old file --- codeflash/optimization/function_context.py | 260 +------------------ codeflash/optimization/function_optimizer.py | 2 - tests/test_function_dependencies.py | 1 - tests/test_type_annotation_context.py | 103 -------- 4 files changed, 2 insertions(+), 364 deletions(-) delete mode 100644 tests/test_type_annotation_context.py diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index 7840660c3..4f1c892bc 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -1,28 +1,10 @@ from __future__ import annotations -import ast -import os -import re -from collections import defaultdict -from typing import TYPE_CHECKING - -import jedi -import tiktoken from jedi.api.classes import Name - -from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import get_code from codeflash.code_utils.code_utils import ( get_qualified_name, - module_name_from_file_path, - path_belongs_to_site_packages, -) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, FunctionSource - -if TYPE_CHECKING: - from pathlib import Path +) def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool: """Check if the given name belongs to the specified method.""" @@ -58,242 +40,4 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b return get_qualified_name(name.module_name, name.full_name) == qualified_function_name return False except ValueError: - return False - -# -# def get_type_annotation_context( -# function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path -# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: -# function_name: str = function.function_name -# file_path: Path = function.file_path -# file_contents: str = file_path.read_text(encoding="utf8") -# try: -# module: ast.Module = ast.parse(file_contents) -# except SyntaxError as e: -# logger.exception(f"get_type_annotation_context - Syntax error in code: {e}") -# return [], set() -# sources: list[FunctionSource] = [] -# ast_parents: list[FunctionParent] = [] -# contextual_dunder_methods = set() -# -# def get_annotation_source( -# j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str -# ) -> None: -# try: -# definition: list[Name] = j_script.goto( -# line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False -# ) -# except Exception as ex: -# if hasattr(name, "full_name"): -# logger.exception(f"Error while getting definition for {name.full_name}: {ex}") -# else: -# logger.exception(f"Error while getting definition: {ex}") -# definition = [] -# if definition: # TODO can be multiple definitions -# definition_path = definition[0].module_path -# -# # The definition is part of this project and not defined within the original function -# if ( -# str(definition_path).startswith(str(project_root_path) + os.sep) -# and definition[0].full_name -# and not path_belongs_to_site_packages(definition_path) -# and not belongs_to_function(definition[0], function_name) -# ): -# source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])]) -# if source_code[0]: -# sources.append( -# FunctionSource( -# fully_qualified_name=definition[0].full_name, -# jedi_definition=definition[0], -# source_code=source_code[0], -# file_path=definition_path, -# qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."), -# only_function_name=definition[0].name, -# ) -# ) -# contextual_dunder_methods.update(source_code[1]) -# -# def visit_children( -# node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent] -# ) -> None: -# child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module -# for child in ast.iter_child_nodes(node): -# visit(child, node_parents) -# -# def visit_all_annotation_children( -# node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent] -# ) -> None: -# if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): -# visit_all_annotation_children(node.left, node_parents) -# visit_all_annotation_children(node.right, node_parents) -# if isinstance(node, ast.Name) and hasattr(node, "id"): -# name: str = node.id -# line_no: int = node.lineno -# col_no: int = node.col_offset -# get_annotation_source(jedi_script, name, node_parents, line_no, col_no) -# if isinstance(node, ast.Subscript): -# if hasattr(node, "slice"): -# if isinstance(node.slice, ast.Subscript): -# visit_all_annotation_children(node.slice, node_parents) -# elif isinstance(node.slice, ast.Tuple): -# for elt in node.slice.elts: -# if isinstance(elt, (ast.Name, ast.Subscript)): -# visit_all_annotation_children(elt, node_parents) -# elif isinstance(node.slice, ast.Name): -# visit_all_annotation_children(node.slice, node_parents) -# if hasattr(node, "value"): -# visit_all_annotation_children(node.value, node_parents) -# -# def visit( -# node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, -# node_parents: list[FunctionParent], -# ) -> None: -# if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): -# if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): -# if node.name == function_name and node_parents == function.parents: -# arg: ast.arg -# for arg in node.args.args: -# if arg.annotation: -# visit_all_annotation_children(arg.annotation, node_parents) -# if node.returns: -# visit_all_annotation_children(node.returns, node_parents) -# -# if not isinstance(node, ast.Module): -# node_parents.append(FunctionParent(node.name, type(node).__name__)) -# visit_children(node, node_parents) -# if not isinstance(node, ast.Module): -# node_parents.pop() -# -# visit(module, ast_parents) -# -# return sources, contextual_dunder_methods - - -# def get_function_variables_definitions( -# function_to_optimize: FunctionToOptimize, project_root_path: Path -# ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: -# function_name = function_to_optimize.function_name -# file_path = function_to_optimize.file_path -# script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) -# sources: list[FunctionSource] = [] -# contextual_dunder_methods = set() -# # TODO: The function name condition can be stricter so that it does not clash with other class names etc. -# # TODO: The function could have been imported as some other name, -# # we should be checking for the translation as well. Also check for the original function name. -# names = [] -# for ref in script.get_names(all_scopes=True, definitions=False, references=True): -# if ref.full_name: -# if function_to_optimize.parents: -# # Check if the reference belongs to the specified class when FunctionParent is provided -# if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name): -# names.append(ref) -# elif belongs_to_function(ref, function_name): -# names.append(ref) -# -# for name in names: -# try: -# definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) -# except Exception as e: -# try: -# logger.exception(f"Error while getting definition for {name.full_name}: {e}") -# except Exception as e: -# # name.full_name can also throw exceptions sometimes -# logger.exception(f"Error while getting definition: {e}") -# definitions = [] -# if definitions: -# # TODO: there can be multiple definitions, see how to handle such cases -# definition = definitions[0] -# definition_path = definition.module_path -# -# # The definition is part of this project and not defined within the original function -# if ( -# str(definition_path).startswith(str(project_root_path) + os.sep) -# and not path_belongs_to_site_packages(definition_path) -# and definition.full_name -# and not belongs_to_function(definition, function_name) -# ): -# module_name = module_name_from_file_path(definition_path, project_root_path) -# m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name) -# parents = [] -# if m: -# parents = [FunctionParent(m.group(1), "ClassDef")] -# -# source_code = get_code( -# [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)] -# ) -# if source_code[0]: -# sources.append( -# FunctionSource( -# fully_qualified_name=definition.full_name, -# jedi_definition=definition, -# source_code=source_code[0], -# file_path=definition_path, -# qualified_name=definition.full_name.removeprefix(definition.module_name + "."), -# only_function_name=definition.name, -# ) -# ) -# contextual_dunder_methods.update(source_code[1]) -# annotation_sources, annotation_dunder_methods = get_type_annotation_context( -# function_to_optimize, script, project_root_path -# ) -# sources[:0] = annotation_sources # prepend the annotation sources -# contextual_dunder_methods.update(annotation_dunder_methods) -# existing_fully_qualified_names = set() -# no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set)) -# parent_sources = set() -# for source in sources: -# if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names: -# if not source.qualified_name.count("."): -# no_parent_sources[source.file_path][source.qualified_name].add(source) -# else: -# parent_sources.add(source) -# existing_fully_qualified_names.add(fully_qualified_name) -# deduped_parent_sources = [ -# source -# for source in parent_sources -# if source.file_path not in no_parent_sources -# or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path] -# ] -# deduped_no_parent_sources = [ -# source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2] -# ] -# return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods -# -# -# MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k -# -# -# def get_constrained_function_context_and_helper_functions( -# function_to_optimize: FunctionToOptimize, -# project_root_path: Path, -# code_to_optimize: str, -# max_tokens: int = MAX_PROMPT_TOKENS, -# ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]: -# helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path) -# tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") -# code_to_optimize_tokens = tokenizer.encode(code_to_optimize) -# -# if not function_to_optimize.parents: -# helper_functions_sources = [function.source_code for function in helper_functions] -# else: -# helper_functions_sources = [ -# function.source_code -# for function in helper_functions -# if not function.qualified_name.count(".") -# or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name -# ] -# helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources] -# -# context_list = [] -# context_len = len(code_to_optimize_tokens) -# logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}") -# logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}") -# for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens): -# if context_len + source_len <= max_tokens: -# context_list.append(function_source) -# context_len += source_len -# else: -# break -# logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}") -# helper_code: str = "\n".join(context_list) -# return helper_code, helper_functions, dunder_methods + return False \ No newline at end of file diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e8cf39d4a..d5c2651b7 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -48,7 +48,6 @@ BestOptimization, CodeOptimizationContext, FunctionCalledInTest, - FunctionParent, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -58,7 +57,6 @@ TestFiles, TestingMode, ) -# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions from codeflash.result.create_pr import check_create_pr, existing_tests_source_for from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic from codeflash.result.explanation import Explanation diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index fa4a2ab28..019cc4261 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -6,7 +6,6 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import is_successful from codeflash.models.models import FunctionParent -# from codeflash.optimization.function_context import get_function_variables_definitions from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig diff --git a/tests/test_type_annotation_context.py b/tests/test_type_annotation_context.py deleted file mode 100644 index 297e41a48..000000000 --- a/tests/test_type_annotation_context.py +++ /dev/null @@ -1,103 +0,0 @@ -# from __future__ import annotations -# -# import pathlib -# from dataclasses import dataclass, field -# from typing import List -# -# from codeflash.code_utils.code_extractor import get_code -# from codeflash.discovery.functions_to_optimize import FunctionToOptimize -# from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions -# -# -# class CustomType: -# def __init__(self) -> None: -# self.name = None -# self.data: List[int] = [] -# -# -# @dataclass -# class CustomDataClass: -# name: str = "" -# data: List[int] = field(default_factory=list) -# -# -# def function_to_optimize(data: CustomType) -> CustomType: -# name = data.name -# data.data.sort() -# return data -# -# -# def function_to_optimize2(data: CustomDataClass) -> CustomType: -# name = data.name -# data.data.sort() -# return data -# -# -# def function_to_optimize3(data: dict[CustomDataClass, list[CustomDataClass]]) -> list[CustomType] | None: -# name = data.name -# data.data.sort() -# return data -# -# -# def test_function_context_includes_type_annotation() -> None: -# file_path = pathlib.Path(__file__).resolve() -# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( -# FunctionToOptimize("function_to_optimize", str(file_path), []), -# str(file_path.parent.resolve()), -# """def function_to_optimize(data: CustomType): -# name = data.name -# data.data.sort() -# return data""", -# 1000, -# ) -# -# assert len(helper_functions) == 1 -# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomType" -# -# -# def test_function_context_includes_type_annotation_dataclass() -> None: -# file_path = pathlib.Path(__file__).resolve() -# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( -# FunctionToOptimize("function_to_optimize2", str(file_path), []), -# str(file_path.parent.resolve()), -# """def function_to_optimize2(data: CustomDataClass) -> CustomType: -# name = data.name -# data.data.sort() -# return data""", -# 1000, -# ) -# -# assert len(helper_functions) == 2 -# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" -# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" -# -# -# def test_function_context_works_for_composite_types() -> None: -# file_path = pathlib.Path(__file__).resolve() -# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( -# FunctionToOptimize("function_to_optimize3", str(file_path), []), -# str(file_path.parent.resolve()), -# """def function_to_optimize3(data: set[CustomDataClass[CustomDataClass, int]]) -> list[CustomType]: -# name = data.name -# data.data.sort() -# return data""", -# 1000, -# ) -# -# assert len(helper_functions) == 2 -# assert helper_functions[0].fully_qualified_name == "test_type_annotation_context.CustomDataClass" -# assert helper_functions[1].fully_qualified_name == "test_type_annotation_context.CustomType" -# -# -# def test_function_context_custom_datatype() -> None: -# project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" -# file_path = project_path / "math_utils.py" -# code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])]) -# assert code is not None -# assert contextual_dunder_methods == set() -# a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions( -# FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000 -# ) -# -# assert len(helper_functions) == 1 -# assert helper_functions[0].fully_qualified_name == "math_utils.Matrix"