diff --git a/src/core/services/tool_call_handlers/pytest_full_suite_handler.py b/src/core/services/tool_call_handlers/pytest_full_suite_handler.py index c15a82a6..aad48ce9 100644 --- a/src/core/services/tool_call_handlers/pytest_full_suite_handler.py +++ b/src/core/services/tool_call_handlers/pytest_full_suite_handler.py @@ -13,6 +13,7 @@ import logging import re +import shlex from dataclasses import dataclass from pathlib import Path from typing import Any @@ -27,7 +28,20 @@ # Matches commands invoking pytest (pytest, python -m pytest, py.test, etc.) -_PYTEST_ROOT_PATTERN = re.compile(r"\b(pytest|py\.test)(?:\b|\.py\b)", re.IGNORECASE) +_PYTEST_ROOT_PATTERN = re.compile( + r"^(pytest|py\.test)(?:$|\.py$|\.exe$|\.bat$)", + re.IGNORECASE, +) + + +def _token_invokes_pytest(token: str) -> bool: + """Return True when the token represents a pytest executable.""" + + if not token: + return False + + executable_name = Path(token).name + return bool(_PYTEST_ROOT_PATTERN.fullmatch(executable_name)) DEFAULT_STEERING_MESSAGE = ( @@ -91,6 +105,56 @@ def _normalize_whitespace(command: str) -> str: return " ".join(command.strip().split()) +def _split_command_tokens(command: str) -> list[str]: + try: + return shlex.split(command, posix=True) + except ValueError: + return command.split() + + +def _command_invokes_pytest(command: str) -> bool: + tokens = _split_command_tokens(command) + if not tokens: + return False + + separators = {"&&", ";", "||", "|"} + last_separator_index = -1 + + for index, token in enumerate(tokens): + if token in separators: + last_separator_index = index + continue + + if not _token_invokes_pytest(token): + continue + + segment_start = last_separator_index + 1 + segment_tokens = tokens[segment_start : index + 1] + + if _segment_represents_installation(segment_tokens[:-1]): + continue + + return True + + return False + + +def _segment_represents_installation(tokens: list[str]) -> bool: + if not tokens: + return False + + installation_keywords = { + "install", + "add", + "remove", + "uninstall", + "update", + "upgrade", + } + + return any(token.lower() in installation_keywords for token in tokens) + + def _looks_like_full_suite(command: str) -> bool: """Determine if the pytest command targets the entire suite. @@ -101,7 +165,7 @@ def _looks_like_full_suite(command: str) -> bool: """ normalized = _normalize_whitespace(command) - if not _PYTEST_ROOT_PATTERN.search(normalized): + if not normalized: return False tokens = normalized.split() @@ -110,7 +174,7 @@ def _looks_like_full_suite(command: str) -> bool: # Identify index where pytest command appears and inspect subsequent tokens. try: pytest_index = next( - i for i, tok in enumerate(tokens) if _PYTEST_ROOT_PATTERN.search(tok) + i for i, tok in enumerate(tokens) if _token_invokes_pytest(tok) ) except StopIteration: return False @@ -254,14 +318,16 @@ def _extract_pytest_command(self, context: ToolCallContext) -> str | None: command = _extract_command(arguments) - if normalized_tool_name in shell_tools: + if command and normalized_tool_name in shell_tools: + if not _command_invokes_pytest(command): + return None return command # Some providers map pytest directly as function name - if _PYTEST_ROOT_PATTERN.search(tool_name): + if _PYTEST_ROOT_PATTERN.fullmatch(tool_name): return command or tool_name - if command and _PYTEST_ROOT_PATTERN.search(command): + if command and _command_invokes_pytest(command): prefix = tool_name if prefix: return f"{prefix} {command}".strip() diff --git a/tests/unit/core/services/tool_call_handlers/test_pytest_full_suite_handler.py b/tests/unit/core/services/tool_call_handlers/test_pytest_full_suite_handler.py index 11ddf82c..34665771 100644 --- a/tests/unit/core/services/tool_call_handlers/test_pytest_full_suite_handler.py +++ b/tests/unit/core/services/tool_call_handlers/test_pytest_full_suite_handler.py @@ -14,6 +14,8 @@ ("pytest", True), ("python -m pytest", True), ("py.test", True), + ("./.venv/bin/pytest", True), + (r"C:\\venv\\Scripts\\pytest.exe", True), ("pytest -q", True), ("pytest --maxfail=1", True), ("pytest tests/unit", False), @@ -22,6 +24,7 @@ ("pytest some/test/path", False), ("pytest some/test/path::TestSuite::test_case", False), ("pytest tests.unit.test_example", False), + ("./.venv/bin/pytest tests/unit", False), ("pytest .", True), ("pytest ./tests", False), ], @@ -143,6 +146,17 @@ async def test_handler_detects_list_based_command() -> None: assert result.should_swallow is True +@pytest.mark.asyncio +async def test_handler_detects_path_qualified_pytest_invocation() -> None: + handler = PytestFullSuiteHandler(enabled=True) + context = _build_context("./.venv/bin/pytest") + + assert await handler.can_handle(context) is True + result = await handler.handle(context) + + assert result.should_swallow is True + + @pytest.mark.asyncio async def test_handler_enabled_flag_controls_behavior() -> None: handler = PytestFullSuiteHandler(enabled=False) @@ -184,3 +198,25 @@ async def test_handler_allows_targeted_python_pytest_invocation() -> None: result = await handler.handle(context) assert result.should_swallow is False + + +@pytest.mark.asyncio +async def test_handler_ignores_pytest_installation_commands() -> None: + handler = PytestFullSuiteHandler(enabled=True) + context = _build_context("pip install pytest") + + assert await handler.can_handle(context) is False + result = await handler.handle(context) + + assert result.should_swallow is False + + +@pytest.mark.asyncio +async def test_handler_ignores_pytest_plugin_installation() -> None: + handler = PytestFullSuiteHandler(enabled=True) + context = _build_context("pip install pytest-cov") + + assert await handler.can_handle(context) is False + result = await handler.handle(context) + + assert result.should_swallow is False