
Commit ec4afc4

Merge branch 'main' into add-info-to-codeflash-all-docs
2 parents: 7112c5e + 2ba89e6

6 files changed (+34 / -27 lines)

codeflash/api/aiservice.py

Lines changed: 7 additions & 8 deletions
@@ -1,10 +1,9 @@
 from __future__ import annotations

-import time
-
 import json
 import os
 import platform
+import time
 from typing import TYPE_CHECKING, Any

 import requests
@@ -122,7 +121,7 @@ def optimize_python_code(
         logger.info(f"Generated {len(optimizations_json)} candidates.")
         console.rule()
         end_time = time.perf_counter()
-        logger.debug(f"Optimization took {end_time - start_time:.2f} seconds.")
+        logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
         return [
             OptimizedCandidate(
                 source_code=opt["source_code"],
@@ -177,7 +176,7 @@ def optimize_python_code_line_profiler(

         logger.info("Generating optimized candidates…")
         console.rule()
-        if line_profiler_results=="":
+        if line_profiler_results == "":
             logger.info("No LineProfiler results were provided, Skipping optimization.")
             console.rule()
             return []
@@ -209,7 +208,6 @@ def optimize_python_code_line_profiler(
             console.rule()
             return []

-
     def log_results(
         self,
         function_trace_id: str,
@@ -272,9 +270,10 @@ def generate_regression_tests(
         - Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.

         """
-        assert test_framework in ["pytest", "unittest"], (
-            f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
-        )
+        assert test_framework in [
+            "pytest",
+            "unittest",
+        ], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
         payload = {
             "source_code_being_tested": source_code_being_tested,
             "function_to_optimize": function_to_optimize,

codeflash/cli_cmds/cmd_init.py

Lines changed: 7 additions & 6 deletions
@@ -239,7 +239,7 @@ def collect_setup_info() -> SetupInfo:
         else:
             apologize_and_exit()
     else:
-        tests_root = Path(curdir) / Path(cast(str, tests_root_answer))
+        tests_root = Path(curdir) / Path(cast("str", tests_root_answer))
         tests_root = tests_root.relative_to(curdir)
         ph("cli-tests-root-provided")

@@ -302,7 +302,7 @@ def collect_setup_info() -> SetupInfo:
     elif benchmarks_answer == no_benchmarks_option:
         benchmarks_root = None
     else:
-        benchmarks_root = tests_root / Path(cast(str, benchmarks_answer))
+        benchmarks_root = tests_root / Path(cast("str", benchmarks_answer))

     # TODO: Implement other benchmark framework options
     # if benchmarks_root:
@@ -354,9 +354,9 @@ def collect_setup_info() -> SetupInfo:
         module_root=str(module_root),
         tests_root=str(tests_root),
         benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
-        test_framework=cast(str, test_framework),
+        test_framework=cast("str", test_framework),
         ignore_paths=ignore_paths,
-        formatter=cast(str, formatter),
+        formatter=cast("str", formatter),
         git_remote=str(git_remote),
     )

@@ -466,7 +466,7 @@ def check_for_toml_or_setup_file() -> str | None:
         click.echo("⏩️ Skipping pyproject.toml creation.")
         apologize_and_exit()
     click.echo()
-    return cast(str, project_name)
+    return cast("str", project_name)


 def install_github_actions(override_formatter_check: bool = False) -> None:
@@ -852,7 +852,8 @@ def enter_api_key_and_save_to_rc() -> None:


 def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
-    bubble_sort_content = """def sorter(arr):
+    bubble_sort_content = """from typing import Union, List
+def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
     for i in range(len(arr)):
         for j in range(len(arr) - 1):
             if arr[j] > arr[j + 1]:
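A note on the cast changes above: typing.cast ignores its first argument at runtime and simply returns the value, so cast("str", x) behaves exactly like cast(str, x) while keeping the type expression out of runtime evaluation (the quoted form is what lint rules such as ruff's TC006 ask for). A minimal check, with made-up variable names:

from typing import cast

value: object = "tests"
a = cast(str, value)    # type object is passed in, then ignored at runtime
b = cast("str", value)  # quoted form: same runtime result; type checkers read it as a forward reference
assert a == b == "tests"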

codeflash/code_utils/code_utils.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@

 from codeflash.cli_cmds.console import logger

+def encoded_tokens_len(s: str) -> int:
+    '''Function for returning the approximate length of the encoded tokens
+    It's an approximation of BPE encoding (https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)'''
+    return int(len(s)*0.25)

 def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
     if not full_qualified_name:
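The new helper trades tiktoken's exact count for a cheap heuristic: roughly one BPE token per four characters of text, which is all the token-limit guards downstream need. A small usage sketch (the snippet below is illustrative):

from codeflash.code_utils.code_utils import encoded_tokens_len

snippet = "def add(a, b):\n    return a + b\n"
print(encoded_tokens_len(snippet))  # int(32 * 0.25) == 8 for this 32-character string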

codeflash/context/code_context_extractor.py

Lines changed: 6 additions & 8 deletions
@@ -7,13 +7,12 @@

 import jedi
 import libcst as cst
-import tiktoken
 from jedi.api.classes import Name
 from libcst import CSTNode

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
-from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
+from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages, encoded_tokens_len
 from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.models import (
@@ -73,8 +72,7 @@ def get_code_optimization_context(
     )

     # Handle token limits
-    tokenizer = tiktoken.encoding_for_model("gpt-4o")
-    final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
+    final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
     if final_read_writable_tokens > optim_token_limit:
         raise ValueError("Read-writable code has exceeded token limit, cannot proceed")

@@ -87,7 +85,7 @@ def get_code_optimization_context(
     )
     read_only_context_code = read_only_code_markdown.markdown

-    read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))
+    read_only_code_markdown_tokens = encoded_tokens_len(read_only_context_code)
     total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
     if total_tokens > optim_token_limit:
         logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
@@ -96,7 +94,7 @@ def get_code_optimization_context(
             helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
         )
         read_only_context_code = read_only_code_no_docstring_markdown.markdown
-        read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code))
+        read_only_code_no_docstring_markdown_tokens = encoded_tokens_len(read_only_context_code)
         total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
         if total_tokens > optim_token_limit:
             logger.debug("Code context has exceeded token limit, removing read-only code")
@@ -111,7 +109,7 @@ def get_code_optimization_context(
         code_context_type=CodeContextType.TESTGEN,
     )
     testgen_context_code = testgen_code_markdown.code
-    testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
+    testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
     if testgen_context_code_tokens > testgen_token_limit:
         testgen_code_markdown = extract_code_string_context_from_files(
             helpers_of_fto_dict,
@@ -121,7 +119,7 @@ def get_code_optimization_context(
             code_context_type=CodeContextType.TESTGEN,
         )
         testgen_context_code = testgen_code_markdown.code
-        testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
+        testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
         if testgen_context_code_tokens > testgen_token_limit:
             raise ValueError("Testgen code context has exceeded token limit, cannot proceed")

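Together these hunks keep the existing token-budget cascade and only swap the counter from tiktoken to the approximation. A hedged, self-contained paraphrase of that cascade follows; the function name and parameters are illustrative, and the real code re-extracts the read-only context rather than taking pre-built strings:

from codeflash.code_utils.code_utils import encoded_tokens_len

def fit_context_to_budget(read_writable_code: str, read_only_with_docs: str,
                          read_only_no_docs: str, optim_token_limit: int) -> str:
    # Mirrors the trimming order in get_code_optimization_context with the approximate counter.
    read_writable_tokens = encoded_tokens_len(read_writable_code)
    if read_writable_tokens > optim_token_limit:
        raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
    # Try with docstrings, then without, then drop the read-only context entirely.
    for read_only in (read_only_with_docs, read_only_no_docs, ""):
        if read_writable_tokens + encoded_tokens_len(read_only) <= optim_token_limit:
            return read_only
    return ""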

codeflash/optimization/function_optimizer.py

Lines changed: 10 additions & 4 deletions
@@ -3,7 +3,6 @@
 import ast
 import concurrent.futures
 import os
-import shutil
 import subprocess
 import time
 import uuid
@@ -393,7 +392,7 @@ def determine_best_candidate(
         try:
             candidate_index = 0
             original_len = len(candidates)
-            while candidates:
+            while True:
                 done = True if future_line_profile_results is None else future_line_profile_results.done()
                 if done and (future_line_profile_results is not None):
                     line_profile_results = future_line_profile_results.result()
@@ -403,8 +402,14 @@ def determine_best_candidate(
                         f"Added results from line profiler to candidates, total candidates now: {original_len}"
                     )
                     future_line_profile_results = None
+                try:
+                    candidate = candidates.popleft()
+                except IndexError:
+                    if done:
+                        break
+                    time.sleep(0.1)
+                    continue
                 candidate_index += 1
-                candidate = candidates.popleft()
                 get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
                 get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
                 logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@@ -512,7 +517,8 @@ def determine_best_candidate(
                     self.write_code_and_helpers(
                         self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
                     )
-
+                if done and not candidates:
+                    break
         except KeyboardInterrupt as e:
             self.write_code_and_helpers(
                 self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
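The loop rewrite above changes determine_best_candidate from stopping as soon as the deque is momentarily empty (while candidates:) to polling until the line-profiler future has finished and the queue is drained: an empty popleft now sleeps and retries while the future is still running, and the loop breaks only once done is true and nothing is left. A self-contained sketch of that consume-while-producing pattern; the names here (drain_with_late_arrivals, evaluate) are illustrative, not Codeflash APIs:

from __future__ import annotations

import time
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor

def evaluate(candidate: str) -> None:
    print(f"evaluating {candidate}")

def drain_with_late_arrivals(candidates: deque[str], future: Future | None) -> None:
    """Consume candidates until the producer future is finished and the queue is empty."""
    while True:
        done = True if future is None else future.done()
        if done and future is not None:
            candidates.extend(future.result())  # late arrivals, e.g. line-profiler based candidates
            future = None
        try:
            candidate = candidates.popleft()
        except IndexError:
            if done:
                break             # queue drained and nothing more is coming
            time.sleep(0.1)       # producer still running: wait, then poll again
            continue
        evaluate(candidate)
        if done and not candidates:
            break

if __name__ == "__main__":
    with ThreadPoolExecutor() as pool:
        late = pool.submit(lambda: ["lp_candidate_1", "lp_candidate_2"])
        drain_with_late_arrivals(deque(["candidate_1"]), late)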

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -73,7 +73,6 @@ pytest = ">=7.0.0,!=8.3.4"
 gitpython = ">=3.1.31"
 libcst = ">=1.0.1"
 jedi = ">=0.19.1"
-tiktoken = ">=0.7.0"
 timeout-decorator = ">=0.5.0"
 pytest-timeout = ">=2.1.0"
 tomlkit = ">=0.11.7"
