diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index d814f12d0..e9ff9a735 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -52,8 +52,8 @@ def parse_config_file( assert isinstance(config, dict) # default values: - path_keys = ["module-root", "tests-root"] - path_list_keys = ["ignore-paths"] + path_keys = {"module-root", "tests-root"} + path_list_keys = {"ignore-paths", } str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False} list_str_keys = {"formatter-cmds": ["black $file"]} @@ -83,7 +83,7 @@ def parse_config_file( else: # Default to empty list config[key] = [] - assert config["test-framework"] in ["pytest", "unittest"], ( + assert config["test-framework"] in {"pytest", "unittest"}, ( "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." ) if len(config["formatter-cmds"]) > 0: diff --git a/codeflash/code_utils/tabulate.py b/codeflash/code_utils/tabulate.py index c278bfeae..c75dcd03e 100644 --- a/codeflash/code_utils/tabulate.py +++ b/codeflash/code_utils/tabulate.py @@ -70,7 +70,7 @@ def _pipe_segment_with_colons(align, colwidth): """Return a segment of a horizontal line with optional colons which indicate column's alignment (as in `pipe` output format).""" w = colwidth - if align in ["right", "decimal"]: + if align in {"right", "decimal"}: return ("-" * (w - 1)) + ":" elif align == "center": return ":" + ("-" * (w - 2)) + ":" @@ -176,7 +176,7 @@ def _isconvertible(conv, string): def _isnumber(string): return ( # fast path - type(string) in (float, int) + type(string) in {float, int} # covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type # convertible to int/float. or ( @@ -188,7 +188,7 @@ def _isnumber(string): # just an over/underflow or ( not (math.isinf(float(string)) or math.isnan(float(string))) - or string.lower() in ["inf", "-inf", "nan"] + or string.lower() in {"inf", "-inf", "nan"} ) ) ) @@ -210,7 +210,7 @@ def _isint(string, inttype=int): def _isbool(string): return type(string) is bool or ( - isinstance(string, (bytes, str)) and string in ("True", "False") + isinstance(string, (bytes, str)) and string in {"True", "False"} ) @@ -570,7 +570,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) keys = list(tabular_data) if ( - showindex in ["default", "always", True] + showindex in {"default", "always", True} and tabular_data.index.name is not None ): if isinstance(tabular_data.index.name, list): @@ -686,7 +686,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows)) # add or remove an index column - showindex_is_a_str = type(showindex) in [str, bytes] + showindex_is_a_str = type(showindex) in {str, bytes} if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str): pass @@ -820,7 +820,7 @@ def tabulate( if colglobalalign is not None: # if global alignment provided aligns = [colglobalalign] * len(cols) else: # default - aligns = [numalign if ct in [int, float] else stralign for ct in coltypes] + aligns = [numalign if ct in {int, float} else stralign for ct in coltypes] # then specific alignments if colalign is not None: assert isinstance(colalign, Iterable) @@ -1044,4 +1044,4 @@ def _format_table( output = "\n".join(lines) return output else: # a completely empty table - return "" \ No newline at end of file + return "" diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index ce97183b4..d62c750cf 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str: units = re.split(r",|\s", runtime_human)[1] - if units in ("microseconds", "microsecond"): + if units in {"microseconds", "microsecond"}: runtime_human = f"{time_micro:.3g}" - elif units in ("milliseconds", "millisecond"): + elif units in {"milliseconds", "millisecond"}: runtime_human = "%.3g" % (time_micro / 1000) - elif units in ("seconds", "second"): + elif units in {"seconds", "second"}: runtime_human = "%.3g" % (time_micro / (1000**2)) - elif units in ("minutes", "minute"): + elif units in {"minutes", "minute"}: runtime_human = "%.3g" % (time_micro / (60 * 1000**2)) else: # hours runtime_human = "%.3g" % (time_micro / (3600 * 1000**2)) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 93def83c0..7fa8805c9 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -793,7 +793,7 @@ def establish_original_code_baseline( line_profile_results = {"timings": {}, "unit": 0, "str_out": ""} # For the original function - run the tests and get the runtime, plus coverage with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"): - assert (test_framework := self.args.test_framework) in ["pytest", "unittest"] + assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} success = True test_env = os.environ.copy() @@ -941,7 +941,7 @@ def run_optimized_candidate( original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], ) -> Result[OptimizedCandidateResult, str]: - assert (test_framework := self.args.test_framework) in ["pytest", "unittest"] + assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} with progress_bar("Testing optimization candidate"): test_env = os.environ.copy() @@ -1118,7 +1118,7 @@ def run_and_parse_tests( f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n" ) - if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]: + if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}: results, coverage_results = parse_test_results( test_xml_path=result_file_path, test_files=test_files, diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index 50b8dae2e..6cf82f3e3 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -10,7 +10,7 @@ class ProfileStats(pstats.Stats): def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None: assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist" - assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}" + assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}" self.trace_file_path = trace_file_path self.time_unit = time_unit logger.debug(hasattr(self, "create_stats")) @@ -59,10 +59,10 @@ def print_stats(self, *amount): time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit] print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream) print(file=self.stream) - width, list = self.get_print_list(amount) - if list: + width, list_ = self.get_print_list(amount) + if list_: self.print_title() - for func in list: + for func in list_: self.print_line(func) print(file=self.stream) print(file=self.stream) diff --git a/codeflash/tracing/replay_test.py b/codeflash/tracing/replay_test.py index eca1e50ef..d3439c81b 100644 --- a/codeflash/tracing/replay_test.py +++ b/codeflash/tracing/replay_test.py @@ -42,7 +42,7 @@ def get_function_alias(module: str, function_name: str) -> str: def create_trace_replay_test( trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100 ) -> str: - assert test_framework in ["pytest", "unittest"] + assert test_framework in {"pytest", "unittest"} imports = f"""import dill as pickle {"import unittest" if test_framework == "unittest" else ""} diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 60372fcb4..f047d5b3c 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -233,7 +233,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"} return comparator(orig_keys, new_keys, superset_obj) - if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]: + if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}: return new == orig if str(type(orig)) == "": return True diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 6703298c0..67b9de439 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -40,7 +40,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR superset_obj = False if original_test_result.verification_type and ( original_test_result.verification_type - in (VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO) + in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): @@ -70,7 +70,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR are_equal = False break - if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and ( + if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and ( cdd_test_result.did_pass != original_test_result.did_pass ): are_equal = False diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py index 0536d4825..5e753b932 100644 --- a/codeflash/verification/parse_line_profile_test_output.py +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -16,12 +16,12 @@ def show_func(filename, start_lineno, func_name, timings, unit): return '' scalar = 1 if os.path.exists(filename): - out_table+=f'## Function: {func_name}\n' + out_table += f'## Function: {func_name}\n' # Clear the cache to ensure that we get up-to-date results. linecache.clearcache() all_lines = linecache.getlines(filename) sublines = inspect.getblock(all_lines[start_lineno - 1:]) - out_table+='## Total time: %g s\n' % (total_time * unit) + out_table += '## Total time: %g s\n' % (total_time * unit) # Define minimum column sizes so text fits and usually looks consistent default_column_sizes = { 'hits': 9, @@ -57,7 +57,7 @@ def show_func(filename, start_lineno, func_name, timings, unit): if 'def' in line_ or nhits!='': table_rows.append((nhits, time, per_hit, percent, line_)) pass - out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True) + out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True) out_table+='\n' return out_table @@ -65,12 +65,12 @@ def show_text(stats: dict) -> str: """ Show text for the given timings. """ out_table = "" - out_table+='# Timer unit: %g s\n' % stats['unit'] + out_table += '# Timer unit: %g s\n' % stats['unit'] stats_order = sorted(stats['timings'].items()) # Show detailed per-line information for each function. for (fn, lineno, name), timings in stats_order: - table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit']) - out_table+=table_md + table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit']) + out_table += table_md return out_table def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict: @@ -83,6 +83,6 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic stats = pickle.load(f) stats_dict['timings'] = stats.timings stats_dict['unit'] = stats.unit - str_out=show_text(stats_dict) - stats_dict['str_out']=str_out - return stats_dict, None \ No newline at end of file + str_out = show_text(stats_dict) + stats_dict['str_out'] = str_out + return stats_dict, None diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 924e2876a..2228559f9 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -127,7 +127,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes iteration_id = val[5] runtime = val[6] verification_type = val[8] - if verification_type in (VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER): + if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}: test_type = TestType.INIT_STATE_TEST else: # TODO : this is because sqlite writes original file module path. Should make it consistent diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index d8c58eb9a..483695b1a 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -164,7 +164,7 @@ def run_line_profile_tests( ) test_files: list[str] = [] for file in test_paths.test_files: - if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file: + if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file: test_files.extend( [ str(file.benchmarking_file_path) @@ -224,7 +224,7 @@ def run_benchmarking_tests( ) test_files: list[str] = [] for file in test_paths.test_files: - if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file: + if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file: test_files.extend( [ str(file.benchmarking_file_path) diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 79f1b9656..91ed31757 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -8,7 +8,7 @@ def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path: - assert test_type in ["unit", "inspired", "replay", "perf"] + assert test_type in {"unit", "inspired", "replay", "perf"} function_name = function_name.replace(".", "_") path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py" if path.exists():