Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def parse_config_file(
assert isinstance(config, dict)

# default values:
path_keys = ["module-root", "tests-root"]
path_list_keys = ["ignore-paths"]
path_keys = {"module-root", "tests-root"}
path_list_keys = {"ignore-paths", }
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
list_str_keys = {"formatter-cmds": ["black $file"]}
Expand Down Expand Up @@ -83,7 +83,7 @@ def parse_config_file(
else: # Default to empty list
config[key] = []

assert config["test-framework"] in ["pytest", "unittest"], (
assert config["test-framework"] in {"pytest", "unittest"}, (
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
)
if len(config["formatter-cmds"]) > 0:
Expand Down
16 changes: 8 additions & 8 deletions codeflash/code_utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _pipe_segment_with_colons(align, colwidth):
"""Return a segment of a horizontal line with optional colons which
indicate column's alignment (as in `pipe` output format)."""
w = colwidth
if align in ["right", "decimal"]:
if align in {"right", "decimal"}:
return ("-" * (w - 1)) + ":"
elif align == "center":
return ":" + ("-" * (w - 2)) + ":"
Expand Down Expand Up @@ -176,7 +176,7 @@ def _isconvertible(conv, string):
def _isnumber(string):
return (
# fast path
type(string) in (float, int)
type(string) in {float, int}
# covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type
# convertible to int/float.
or (
Expand All @@ -188,7 +188,7 @@ def _isnumber(string):
# just an over/underflow
or (
not (math.isinf(float(string)) or math.isnan(float(string)))
or string.lower() in ["inf", "-inf", "nan"]
or string.lower() in {"inf", "-inf", "nan"}
)
)
)
Expand All @@ -210,7 +210,7 @@ def _isint(string, inttype=int):

def _isbool(string):
return type(string) is bool or (
isinstance(string, (bytes, str)) and string in ("True", "False")
isinstance(string, (bytes, str)) and string in {"True", "False"}
)


Expand Down Expand Up @@ -570,7 +570,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
keys = list(tabular_data)
if (
showindex in ["default", "always", True]
showindex in {"default", "always", True}
and tabular_data.index.name is not None
):
if isinstance(tabular_data.index.name, list):
Expand Down Expand Up @@ -686,7 +686,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows))

# add or remove an index column
showindex_is_a_str = type(showindex) in [str, bytes]
showindex_is_a_str = type(showindex) in {str, bytes}
if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str):
pass

Expand Down Expand Up @@ -820,7 +820,7 @@ def tabulate(
if colglobalalign is not None: # if global alignment provided
aligns = [colglobalalign] * len(cols)
else: # default
aligns = [numalign if ct in [int, float] else stralign for ct in coltypes]
aligns = [numalign if ct in {int, float} else stralign for ct in coltypes]
# then specific alignments
if colalign is not None:
assert isinstance(colalign, Iterable)
Expand Down Expand Up @@ -1044,4 +1044,4 @@ def _format_table(
output = "\n".join(lines)
return output
else: # a completely empty table
return ""
return ""
8 changes: 4 additions & 4 deletions codeflash/code_utils/time_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str:

units = re.split(r",|\s", runtime_human)[1]

if units in ("microseconds", "microsecond"):
if units in {"microseconds", "microsecond"}:
runtime_human = f"{time_micro:.3g}"
elif units in ("milliseconds", "millisecond"):
elif units in {"milliseconds", "millisecond"}:
runtime_human = "%.3g" % (time_micro / 1000)
elif units in ("seconds", "second"):
elif units in {"seconds", "second"}:
runtime_human = "%.3g" % (time_micro / (1000**2))
elif units in ("minutes", "minute"):
elif units in {"minutes", "minute"}:
runtime_human = "%.3g" % (time_micro / (60 * 1000**2))
else: # hours
runtime_human = "%.3g" % (time_micro / (3600 * 1000**2))
Expand Down
6 changes: 3 additions & 3 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def establish_original_code_baseline(
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
# For the original function - run the tests and get the runtime, plus coverage
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
success = True

test_env = os.environ.copy()
Expand Down Expand Up @@ -941,7 +941,7 @@ def run_optimized_candidate(
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}

with progress_bar("Testing optimization candidate"):
test_env = os.environ.copy()
Expand Down Expand Up @@ -1118,7 +1118,7 @@ def run_and_parse_tests(
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
test_files=test_files,
Expand Down
8 changes: 4 additions & 4 deletions codeflash/tracing/profile_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class ProfileStats(pstats.Stats):
def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}"
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
self.trace_file_path = trace_file_path
self.time_unit = time_unit
logger.debug(hasattr(self, "create_stats"))
Expand Down Expand Up @@ -59,10 +59,10 @@ def print_stats(self, *amount):
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
print(file=self.stream)
width, list = self.get_print_list(amount)
if list:
width, list_ = self.get_print_list(amount)
if list_:
self.print_title()
for func in list:
for func in list_:
self.print_line(func)
print(file=self.stream)
print(file=self.stream)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/tracing/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_function_alias(module: str, function_name: str) -> str:
def create_trace_replay_test(
trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100
) -> str:
assert test_framework in ["pytest", "unittest"]
assert test_framework in {"pytest", "unittest"}

imports = f"""import dill as pickle
{"import unittest" if test_framework == "unittest" else ""}
Expand Down
2 changes: 1 addition & 1 deletion codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
return comparator(orig_keys, new_keys, superset_obj)

if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
return new == orig
if str(type(orig)) == "<class 'object'>":
return True
Expand Down
4 changes: 2 additions & 2 deletions codeflash/verification/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
superset_obj = False
if original_test_result.verification_type and (
original_test_result.verification_type
in (VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO)
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
):
superset_obj = True
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
Expand Down Expand Up @@ -70,7 +70,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
are_equal = False
break

if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and (
cdd_test_result.did_pass != original_test_result.did_pass
):
are_equal = False
Expand Down
18 changes: 9 additions & 9 deletions codeflash/verification/parse_line_profile_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def show_func(filename, start_lineno, func_name, timings, unit):
return ''
scalar = 1
if os.path.exists(filename):
out_table+=f'## Function: {func_name}\n'
out_table += f'## Function: {func_name}\n'
# Clear the cache to ensure that we get up-to-date results.
linecache.clearcache()
all_lines = linecache.getlines(filename)
sublines = inspect.getblock(all_lines[start_lineno - 1:])
out_table+='## Total time: %g s\n' % (total_time * unit)
out_table += '## Total time: %g s\n' % (total_time * unit)
# Define minimum column sizes so text fits and usually looks consistent
default_column_sizes = {
'hits': 9,
Expand Down Expand Up @@ -57,20 +57,20 @@ def show_func(filename, start_lineno, func_name, timings, unit):
if 'def' in line_ or nhits!='':
table_rows.append((nhits, time, per_hit, percent, line_))
pass
out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table+='\n'
return out_table

def show_text(stats: dict) -> str:
""" Show text for the given timings.
"""
out_table = ""
out_table+='# Timer unit: %g s\n' % stats['unit']
out_table += '# Timer unit: %g s\n' % stats['unit']
stats_order = sorted(stats['timings'].items())
# Show detailed per-line information for each function.
for (fn, lineno, name), timings in stats_order:
table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table+=table_md
table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table += table_md
return out_table

def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
Expand All @@ -83,6 +83,6 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic
stats = pickle.load(f)
stats_dict['timings'] = stats.timings
stats_dict['unit'] = stats.unit
str_out=show_text(stats_dict)
stats_dict['str_out']=str_out
return stats_dict, None
str_out = show_text(stats_dict)
stats_dict['str_out'] = str_out
return stats_dict, None
2 changes: 1 addition & 1 deletion codeflash/verification/parse_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
iteration_id = val[5]
runtime = val[6]
verification_type = val[8]
if verification_type in (VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER):
if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}:
test_type = TestType.INIT_STATE_TEST
else:
# TODO : this is because sqlite writes original file module path. Should make it consistent
Expand Down
4 changes: 2 additions & 2 deletions codeflash/verification/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def run_line_profile_tests(
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
Expand Down Expand Up @@ -224,7 +224,7 @@ def run_benchmarking_tests(
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/verification/verification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in ["unit", "inspired", "replay", "perf"]
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py"
if path.exists():
Expand Down
Loading