From 6db214d320879aca639e43c3cf2765f21e824600 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 20 May 2025 15:31:06 -0400 Subject: [PATCH] adding type hints so that tests generated are sane --- codeflash/cli_cmds/cmd_init.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 7e6d2cd57..8c6196d8b 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -239,7 +239,7 @@ def collect_setup_info() -> SetupInfo: else: apologize_and_exit() else: - tests_root = Path(curdir) / Path(cast(str, tests_root_answer)) + tests_root = Path(curdir) / Path(cast("str", tests_root_answer)) tests_root = tests_root.relative_to(curdir) ph("cli-tests-root-provided") @@ -302,7 +302,7 @@ def collect_setup_info() -> SetupInfo: elif benchmarks_answer == no_benchmarks_option: benchmarks_root = None else: - benchmarks_root = tests_root / Path(cast(str, benchmarks_answer)) + benchmarks_root = tests_root / Path(cast("str", benchmarks_answer)) # TODO: Implement other benchmark framework options # if benchmarks_root: @@ -354,9 +354,9 @@ def collect_setup_info() -> SetupInfo: module_root=str(module_root), tests_root=str(tests_root), benchmarks_root=str(benchmarks_root) if benchmarks_root else None, - test_framework=cast(str, test_framework), + test_framework=cast("str", test_framework), ignore_paths=ignore_paths, - formatter=cast(str, formatter), + formatter=cast("str", formatter), git_remote=str(git_remote), ) @@ -466,7 +466,7 @@ def check_for_toml_or_setup_file() -> str | None: click.echo("⏩️ Skipping pyproject.toml creation.") apologize_and_exit() click.echo() - return cast(str, project_name) + return cast("str", project_name) def install_github_actions(override_formatter_check: bool = False) -> None: @@ -852,7 +852,7 @@ def enter_api_key_and_save_to_rc() -> None: def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]: - bubble_sort_content = """def sorter(arr): + bubble_sort_content = """def sorter(arr: list[int] | list[float]) -> list[int] | list[float]: for i in range(len(arr)): for j in range(len(arr) - 1): if arr[j] > arr[j + 1]: