Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/core/services/tool_call_handlers/pytest_full_suite_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,14 @@ async def can_handle(self, context: ToolCallContext) -> bool:
if not _looks_like_full_suite(normalized):
return False

state = self._session_state.get(context.session_id)
session_key = context.session_id.strip() if context.session_id else None
if not session_key:
# Without a stable session identifier we cannot track per-session
# steering state. Treat the command as new so that we do not leak
# steering behaviour across unrelated requests.
return True

state = self._session_state.get(session_key)
return not (state and state.last_command == normalized)

async def handle(self, context: ToolCallContext) -> ToolCallReactionResult:
Expand All @@ -213,15 +220,36 @@ async def handle(self, context: ToolCallContext) -> ToolCallReactionResult:
if not _looks_like_full_suite(normalized):
return ToolCallReactionResult(should_swallow=False)

state = self._session_state.setdefault(context.session_id, _SessionState())
session_key = context.session_id.strip() if context.session_id else None

if session_key is None:
# Without a reliable session identifier we cannot remember previous
# steering decisions. Swallow this invocation but avoid mutating the
# global state dictionary so that other requests are unaffected.
logger.info(
"Steering full-suite pytest command without session id: %s",
normalized,
)
return ToolCallReactionResult(
should_swallow=True,
replacement_response=self._message,
metadata={
"handler": self.name,
"tool_name": context.tool_name,
"command": normalized,
"source": "pytest_full_suite_steering",
},
)

state = self._session_state.setdefault(session_key, _SessionState())
if state.last_command == normalized:
return ToolCallReactionResult(should_swallow=False)

state.last_command = normalized

logger.info(
"Steering full-suite pytest command in session %s: %s",
context.session_id,
session_key,
normalized,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ async def test_handler_allows_second_session_immediately() -> None:
assert await handler.can_handle(second) is True


@pytest.mark.asyncio
async def test_handler_handles_missing_session_id_without_state_leak() -> None:
handler = PytestFullSuiteHandler(enabled=True)
first = _build_context("pytest", session_id="")
second = _build_context("pytest", session_id="")

assert await handler.can_handle(first) is True
first_result = await handler.handle(first)
assert first_result.should_swallow is True

assert await handler.can_handle(second) is True
second_result = await handler.handle(second)
assert second_result.should_swallow is True


@pytest.mark.asyncio
async def test_handler_passes_through_targeted_pytest() -> None:
handler = PytestFullSuiteHandler(enabled=True)
Expand Down