Skip to content

Commit 6f97004

Browse files
authored
Merge branch 'main' into skip-formatting-for-large-diffs
2 parents a1510a3 + 62a4575 commit 6f97004

28 files changed

+2507
-247
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ https://github.com/user-attachments/assets/38f44f4e-be1c-4f84-8db9-63d5ee3e61e5
6565

6666
Join our community for support and discussions. If you have any questions, feel free to reach out to us using one of the following methods:
6767

68+
- [Free live Installation Support](https://calendly.com/codeflash-saurabh/codeflash-setup)
6869
- [Join our Discord](https://www.codeflash.ai/discord)
6970
- [Follow us on Twitter](https://x.com/codeflashAI)
7071
- [Follow us on Linkedin](https://www.linkedin.com/in/saurabh-misra/)
71-
- [Email founders](mailto:[email protected])
7272

7373
## License
7474

codeflash/LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Business Source License 1.1
33
Parameters
44

55
Licensor: CodeFlash Inc.
6-
Licensed Work: Codeflash Client version 0.13.x
6+
Licensed Work: Codeflash Client version 0.14.x
77
The Licensed Work is (c) 2024 CodeFlash Inc.
88

99
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
1313
Platform. Please visit codeflash.ai for further
1414
information.
1515

16-
Change Date: 2029-06-03
16+
Change Date: 2029-06-09
1717

1818
Change License: MIT
1919

codeflash/api/aiservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def optimize_python_code( # noqa: D417
118118

119119
if response.status_code == 200:
120120
optimizations_json = response.json()["optimizations"]
121-
logger.info(f"Generated {len(optimizations_json)} candidates.")
121+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
122122
console.rule()
123123
end_time = time.perf_counter()
124124
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
@@ -189,7 +189,7 @@ def optimize_python_code_line_profiler( # noqa: D417
189189

190190
if response.status_code == 200:
191191
optimizations_json = response.json()["optimizations"]
192-
logger.info(f"Generated {len(optimizations_json)} candidates.")
192+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
193193
console.rule()
194194
return [
195195
OptimizedCandidate(

codeflash/api/cfapi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Optional
99

10+
import git
1011
import requests
1112
import sentry_sdk
1213
from pydantic.json import pydantic_encoder
@@ -191,3 +192,35 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
191192
return {}
192193

193194
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
195+
196+
197+
def is_function_being_optimized_again(
198+
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
199+
) -> Any: # noqa: ANN401
200+
"""Check if the function being optimized is being optimized again."""
201+
response = make_cfapi_request(
202+
"/is-already-optimized",
203+
"POST",
204+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts},
205+
)
206+
response.raise_for_status()
207+
return response.json()
208+
209+
210+
def add_code_context_hash(code_context_hash: str) -> None:
211+
"""Add code context to the DB cache."""
212+
pr_number = get_pr_number()
213+
if pr_number is None:
214+
return
215+
try:
216+
owner, repo = get_repo_owner_and_name()
217+
pr_number = get_pr_number()
218+
except git.exc.InvalidGitRepositoryError:
219+
return
220+
221+
if owner and repo and pr_number is not None:
222+
make_cfapi_request(
223+
"/add-code-hash",
224+
"POST",
225+
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
226+
)

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
123123
"disable_telemetry",
124124
"disable_imports_sorting",
125125
"git_remote",
126+
"override_fixtures",
126127
]
127128
for key in supported_keys:
128129
if key in pyproject_config and (

codeflash/cli_cmds/console.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,34 @@ def code_print(code_str: str) -> None:
6666

6767

6868
@contextmanager
69-
def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
70-
"""Display a progress bar with a spinner and elapsed time."""
71-
progress = Progress(
72-
SpinnerColumn(next(spinners)),
73-
*Progress.get_default_columns(),
74-
TimeElapsedColumn(),
75-
console=console,
76-
transient=transient,
77-
)
78-
task = progress.add_task(message, total=None)
79-
with progress:
80-
yield task
69+
def progress_bar(
70+
message: str, *, transient: bool = False, revert_to_print: bool = False
71+
) -> Generator[TaskID, None, None]:
72+
"""Display a progress bar with a spinner and elapsed time.
73+
74+
If revert_to_print is True, falls back to printing a single logger.info message
75+
instead of showing a progress bar.
76+
"""
77+
if revert_to_print:
78+
logger.info(message)
79+
80+
# Create a fake task ID since we still need to yield something
81+
class DummyTask:
82+
def __init__(self) -> None:
83+
self.id = 0
84+
85+
yield DummyTask().id
86+
else:
87+
progress = Progress(
88+
SpinnerColumn(next(spinners)),
89+
*Progress.get_default_columns(),
90+
TimeElapsedColumn(),
91+
console=console,
92+
transient=transient,
93+
)
94+
task = progress.add_task(message, total=None)
95+
with progress:
96+
yield task
8197

8298

8399
@contextmanager

codeflash/code_utils/code_replacer.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from functools import lru_cache
66
from typing import TYPE_CHECKING, Optional, TypeVar
77

8+
import isort
89
import libcst as cst
10+
import libcst.matchers as m
911

1012
from codeflash.cli_cmds.console import logger
1113
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
14+
from codeflash.code_utils.config_parser import find_conftest_files
15+
from codeflash.code_utils.line_profile_utils import ImportAdder
1216
from codeflash.models.models import FunctionParent
1317

1418
if TYPE_CHECKING:
@@ -33,6 +37,142 @@ def normalize_code(code: str) -> str:
3337
return ast.unparse(normalize_node(ast.parse(code)))
3438

3539

40+
class PytestMarkAdder(cst.CSTTransformer):
41+
"""Transformer that adds pytest marks to test functions."""
42+
43+
def __init__(self, mark_name: str) -> None:
44+
super().__init__()
45+
self.mark_name = mark_name
46+
self.has_pytest_import = False
47+
48+
def visit_Module(self, node: cst.Module) -> None:
49+
"""Check if pytest is already imported."""
50+
for statement in node.body:
51+
if isinstance(statement, cst.SimpleStatementLine):
52+
for stmt in statement.body:
53+
if isinstance(stmt, cst.Import):
54+
for import_alias in stmt.names:
55+
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
56+
self.has_pytest_import = True
57+
58+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
59+
"""Add pytest import if not present."""
60+
if not self.has_pytest_import:
61+
# Create import statement
62+
import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])])
63+
# Add import at the beginning
64+
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
65+
return updated_node
66+
67+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
68+
"""Add pytest mark to test functions."""
69+
# Check if the mark already exists
70+
for decorator in updated_node.decorators:
71+
if self._is_pytest_mark(decorator.decorator, self.mark_name):
72+
return updated_node
73+
74+
# Create the pytest mark decorator
75+
mark_decorator = self._create_pytest_mark()
76+
77+
# Add the decorator
78+
new_decorators = [*list(updated_node.decorators), mark_decorator]
79+
return updated_node.with_changes(decorators=new_decorators)
80+
81+
def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool:
82+
"""Check if a decorator is a specific pytest mark."""
83+
if isinstance(decorator, cst.Attribute):
84+
if (
85+
isinstance(decorator.value, cst.Attribute)
86+
and isinstance(decorator.value.value, cst.Name)
87+
and decorator.value.value.value == "pytest"
88+
and decorator.value.attr.value == "mark"
89+
and decorator.attr.value == mark_name
90+
):
91+
return True
92+
elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute):
93+
return self._is_pytest_mark(decorator.func, mark_name)
94+
return False
95+
96+
def _create_pytest_mark(self) -> cst.Decorator:
97+
"""Create a pytest mark decorator."""
98+
# Base: pytest.mark.{mark_name}
99+
mark_attr = cst.Attribute(
100+
value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name)
101+
)
102+
decorator = mark_attr
103+
return cst.Decorator(decorator=decorator)
104+
105+
106+
class AutouseFixtureModifier(cst.CSTTransformer):
107+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
108+
# Matcher for '@fixture' or '@pytest.fixture'
109+
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))
110+
111+
for decorator in original_node.decorators:
112+
if m.matches(
113+
decorator,
114+
m.Decorator(
115+
decorator=m.Call(
116+
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
117+
)
118+
),
119+
):
120+
# Found a matching fixture with autouse=True
121+
122+
# 1. The original body of the function will become the 'else' block.
123+
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
124+
else_block = cst.Else(body=updated_node.body)
125+
126+
# 2. Create the new 'if' block that will exit the fixture early.
127+
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
128+
yield_statement = cst.parse_statement("yield")
129+
if_body = cst.IndentedBlock(body=[yield_statement])
130+
131+
# 3. Construct the full if/else statement.
132+
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
133+
134+
# 4. Replace the entire function's body with our new single statement.
135+
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
136+
return updated_node
137+
138+
139+
def disable_autouse(test_path: Path) -> str:
140+
file_content = test_path.read_text(encoding="utf-8")
141+
module = cst.parse_module(file_content)
142+
disable_autouse_fixture = AutouseFixtureModifier()
143+
modified_module = module.visit(disable_autouse_fixture)
144+
test_path.write_text(modified_module.code, encoding="utf-8")
145+
return file_content
146+
147+
148+
def modify_autouse_fixture(test_paths: list[Path]) -> dict[Path, list[str]]:
149+
# find fixutre definition in conftetst.py (the one closest to the test)
150+
# get fixtures present in override-fixtures in pyproject.toml
151+
# add if marker closest return
152+
file_content_map = {}
153+
conftest_files = find_conftest_files(test_paths)
154+
for cf_file in conftest_files:
155+
# iterate over all functions in the file
156+
# if function has autouse fixture, modify function to bypass with custom marker
157+
original_content = disable_autouse(cf_file)
158+
file_content_map[cf_file] = original_content
159+
return file_content_map
160+
161+
162+
# # reuse line profiler utils to add decorator and import to test fns
163+
def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
164+
for test_path in test_paths:
165+
# read file
166+
file_content = test_path.read_text(encoding="utf-8")
167+
module = cst.parse_module(file_content)
168+
importadder = ImportAdder("import pytest")
169+
modified_module = module.visit(importadder)
170+
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
171+
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
172+
modified_module = modified_module.visit(pytest_mark_adder)
173+
test_path.write_text(modified_module.code, encoding="utf-8")
174+
175+
36176
class OptimFunctionCollector(cst.CSTVisitor):
37177
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
38178

codeflash/code_utils/code_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,8 @@ def cleanup_paths(paths: list[Path]) -> None:
208208
shutil.rmtree(path, ignore_errors=True)
209209
else:
210210
path.unlink(missing_ok=True)
211+
212+
213+
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
214+
for path, file_content in path_to_content_map.items():
215+
path.write_text(file_content, encoding="utf8")

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
1010
COVERAGE_THRESHOLD = 60.0
1111
MIN_TESTCASE_PASSED_THRESHOLD = 6
12+
REPEAT_OPTIMIZATION_PROBABILITY = 0.1

codeflash/code_utils/config_parser.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
3131
raise ValueError(msg)
3232

3333

34+
def find_conftest_files(test_paths: list[Path]) -> list[Path]:
35+
list_of_conftest_files = set()
36+
for test_path in test_paths:
37+
# Find the conftest file on the root of the project
38+
dir_path = Path.cwd()
39+
cur_path = test_path
40+
while cur_path != dir_path:
41+
config_file = cur_path / "conftest.py"
42+
if config_file.exists():
43+
list_of_conftest_files.add(config_file)
44+
# Search for conftest.py in the parent directories
45+
cur_path = cur_path.parent
46+
return list(list_of_conftest_files)
47+
48+
3449
def parse_config_file(
3550
config_file_path: Path | None = None,
3651
override_formatter_check: bool = False, # noqa: FBT001, FBT002
@@ -56,7 +71,12 @@ def parse_config_file(
5671
path_keys = ["module-root", "tests-root", "benchmarks-root"]
5772
path_list_keys = ["ignore-paths"]
5873
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
59-
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False}
74+
bool_keys = {
75+
"override-fixtures": False,
76+
"disable-telemetry": False,
77+
"disable-imports-sorting": False,
78+
"benchmark": False,
79+
}
6080
list_str_keys = {"formatter-cmds": ["black $file"]}
6181

6282
for key, default_value in str_keys.items():

0 commit comments

Comments
 (0)