Skip to content

Commit ba6a1c0

Browse files
authored
Merge pull request #245 from codeflash-ai/deferred-imports
optimize codeflash import time
2 parents f80a076 + 9b1be30 commit ba6a1c0

File tree

6 files changed

+64
-38
lines changed

6 files changed

+64
-38
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,12 @@
33
from argparse import SUPPRESS, ArgumentParser, Namespace
44
from pathlib import Path
55

6-
import git
7-
86
from codeflash.cli_cmds import logging_config
97
from codeflash.cli_cmds.cli_common import apologize_and_exit
108
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
119
from codeflash.cli_cmds.console import logger
1210
from codeflash.code_utils import env_utils
1311
from codeflash.code_utils.config_parser import parse_config_file
14-
from codeflash.code_utils.git_utils import (
15-
check_and_push_branch,
16-
check_running_in_git_repo,
17-
confirm_proceeding_with_no_git_repo,
18-
get_repo_owner_and_name,
19-
)
20-
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
2112
from codeflash.version import __version__ as version
2213

2314

@@ -75,6 +66,13 @@ def parse_args() -> Namespace:
7566

7667

7768
def process_and_validate_cmd_args(args: Namespace) -> Namespace:
69+
from codeflash.code_utils.git_utils import (
70+
check_running_in_git_repo,
71+
confirm_proceeding_with_no_git_repo,
72+
get_repo_owner_and_name,
73+
)
74+
from codeflash.code_utils.github_utils import require_github_app_or_exit
75+
7876
is_init: bool = args.command.startswith("init") if args.command else False
7977
if args.verbose:
8078
logging_config.set_level(logging.DEBUG, echo_setting=not is_init)
@@ -144,21 +142,26 @@ def process_pyproject_config(args: Namespace) -> Namespace:
144142
assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), (
145143
f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}"
146144
)
147-
if env_utils.get_pr_number() is not None:
148-
assert env_utils.ensure_codeflash_api_key(), (
149-
"Codeflash API key not found. When running in a Github Actions Context, provide the "
150-
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
151-
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
152-
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
153-
f"Here's a direct link: {get_github_secrets_page_url()}\n"
154-
"Exiting..."
155-
)
145+
if env_utils.get_pr_number() is not None:
146+
import git
147+
148+
from codeflash.code_utils.git_utils import get_repo_owner_and_name
149+
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
150+
151+
assert env_utils.ensure_codeflash_api_key(), (
152+
"Codeflash API key not found. When running in a Github Actions Context, provide the "
153+
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
154+
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
155+
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
156+
f"Here's a direct link: {get_github_secrets_page_url()}\n"
157+
"Exiting..."
158+
)
156159

157-
repo = git.Repo(search_parent_directories=True)
160+
repo = git.Repo(search_parent_directories=True)
158161

159-
owner, repo_name = get_repo_owner_and_name(repo)
162+
owner, repo_name = get_repo_owner_and_name(repo)
160163

161-
require_github_app_or_exit(owner, repo_name)
164+
require_github_app_or_exit(owner, repo_name)
162165

163166
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
164167
normalized_ignore_paths = []
@@ -187,6 +190,11 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
187190

188191
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
189192
if hasattr(args, "all"):
193+
import git
194+
195+
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
196+
from codeflash.code_utils.github_utils import require_github_app_or_exit
197+
190198
# Ensure that the user can actually open PRs on the repo.
191199
try:
192200
git_repo = git.Repo(search_parent_directories=True)

codeflash/code_utils/code_extractor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import TYPE_CHECKING, Optional
66

77
import libcst as cst
8-
import libcst.matchers as m
98
from libcst.codemod import CodemodContext
109
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
1110
from libcst.helpers import calculate_module_and_package
@@ -248,6 +247,8 @@ class FutureAliasedImportTransformer(cst.CSTTransformer):
248247
def leave_ImportFrom(
249248
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
250249
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
250+
import libcst.matchers as m
251+
251252
if (
252253
(updated_node_module := updated_node.module)
253254
and updated_node_module.value == "__future__"

codeflash/context/code_context_extractor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from collections import defaultdict
55
from itertools import chain
66
from pathlib import Path # noqa: TC003
7+
from typing import TYPE_CHECKING
78

8-
import jedi
99
import libcst as cst
10-
from jedi.api.classes import Name # noqa: TC002
1110
from libcst import CSTNode # noqa: TC002
1211

1312
from codeflash.cli_cmds.console import logger
@@ -24,6 +23,9 @@
2423
)
2524
from codeflash.optimization.function_context import belongs_to_function_qualified
2625

26+
if TYPE_CHECKING:
27+
from jedi.api.classes import Name
28+
2729

2830
def get_code_optimization_context(
2931
function_to_optimize: FunctionToOptimize,
@@ -354,6 +356,8 @@ def extract_code_markdown_context_from_files(
354356
def get_function_to_optimize_as_function_source(
355357
function_to_optimize: FunctionToOptimize, project_root_path: Path
356358
) -> FunctionSource:
359+
import jedi
360+
357361
# Use jedi to find function to optimize
358362
script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path))
359363

@@ -389,6 +393,8 @@ def get_function_to_optimize_as_function_source(
389393
def get_function_sources_from_jedi(
390394
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
391395
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
396+
import jedi
397+
392398
file_path_to_function_source = defaultdict(set)
393399
function_source_list: list[FunctionSource] = []
394400
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():

codeflash/discovery/discover_unit_tests.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pathlib import Path
1313
from typing import TYPE_CHECKING, Callable, Optional
1414

15-
import jedi
1615
import pytest
1716
from pydantic.dataclasses import dataclass
1817

@@ -281,6 +280,8 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
281280
def process_test_files(
282281
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
283282
) -> dict[str, list[FunctionCalledInTest]]:
283+
import jedi
284+
284285
project_root_path = cfg.project_root_path
285286
test_framework = cfg.test_framework
286287

codeflash/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from codeflash.cli_cmds.console import paneled_text
1212
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
1313
from codeflash.code_utils.config_parser import parse_config_file
14-
from codeflash.optimization import optimizer
1514
from codeflash.telemetry import posthog_cf
1615
from codeflash.telemetry.sentry import init_sentry
1716

@@ -41,6 +40,9 @@ def main() -> None:
4140
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
4241
init_sentry(not args.disable_telemetry, exclude_errors=True)
4342
posthog_cf.initialize_posthog(not args.disable_telemetry)
43+
44+
from codeflash.optimization import optimizer
45+
4446
optimizer.run_with_args(args)
4547

4648

codeflash/optimization/optimizer.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,20 @@
99
from typing import TYPE_CHECKING
1010

1111
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
12-
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
13-
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
14-
from codeflash.benchmarking.replay_test import generate_replay_test
15-
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
16-
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
1712
from codeflash.cli_cmds.console import console, logger, progress_bar
1813
from codeflash.code_utils import env_utils
19-
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
20-
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
21-
from codeflash.code_utils.code_utils import cleanup_paths
22-
from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
23-
from codeflash.discovery.discover_unit_tests import discover_unit_tests
24-
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
2514
from codeflash.either import is_successful
2615
from codeflash.models.models import ValidCode
27-
from codeflash.optimization.function_optimizer import FunctionOptimizer
2816
from codeflash.telemetry.posthog_cf import ph
2917
from codeflash.verification.verification_utils import TestConfig
3018

3119
if TYPE_CHECKING:
3220
from argparse import Namespace
3321

22+
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
3423
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
3524
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
25+
from codeflash.optimization.function_optimizer import FunctionOptimizer
3626

3727

3828
class Optimizer:
@@ -63,6 +53,8 @@ def create_function_optimizer(
6353
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
6454
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
6555
) -> FunctionOptimizer:
56+
from codeflash.optimization.function_optimizer import FunctionOptimizer
57+
6658
return FunctionOptimizer(
6759
function_to_optimize=function_to_optimize,
6860
test_cfg=self.test_cfg,
@@ -77,6 +69,16 @@ def create_function_optimizer(
7769
)
7870

7971
def run(self) -> None:
72+
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
73+
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
74+
from codeflash.code_utils.code_utils import cleanup_paths
75+
from codeflash.code_utils.static_analysis import (
76+
analyze_imported_modules,
77+
get_first_top_level_function_or_method_ast,
78+
)
79+
from codeflash.discovery.discover_unit_tests import discover_unit_tests
80+
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
81+
8082
ph("cli-optimize-run-start")
8183
logger.info("Running optimizer.")
8284
console.rule()
@@ -102,6 +104,12 @@ def run(self) -> None:
102104
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
103105
total_benchmark_timings: dict[BenchmarkKey, int] = {}
104106
if self.args.benchmark and num_optimizable_functions > 0:
107+
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
108+
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
109+
from codeflash.benchmarking.replay_test import generate_replay_test
110+
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
111+
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
112+
105113
with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
106114
# Insert decorator
107115
file_path_to_source_code = defaultdict(str)

0 commit comments

Comments
 (0)