diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 116f7956d..5ea062283 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -427,9 +427,6 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - if not tool_uses: - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - return if agent._interrupt_state.activated: tool_results.extend(agent._interrupt_state.context["tool_results"]) diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index 4523a8352..56817a6e4 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union from pydantic import BaseModel @@ -25,8 +25,8 @@ class MockedModelProvider(Model): to stream mock responses as events. """ - def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]): - self.agent_responses = agent_responses + def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + self.agent_responses = [*agent_responses] self.index = 0 def format_chunk(self, event: Any) -> StreamEvent: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 9d490c0de..892ff86d1 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2065,3 +2065,50 @@ def test_agent_tool_caller_interrupt(user): exp_message = r"cannot directly call tool during interrupt" with pytest.raises(RuntimeError, match=exp_message): agent.tool.test_tool() + + +def test_agent__call__invalid_tool_name(): + @strands.tool + def shell(command: str): + pass + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_use_id", + "name": "invalid tool", + "input": "{}", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ) + + agent = Agent(tools=[shell], model=model) + result = agent("Test") + + # Ensure the stop_reason is + assert result.stop_reason == "end_turn" + + # Assert that there exists a message with a toolResponse + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "Error: tool_name= | invalid tool name pattern"}], + "status": "error", + "toolUseId": "tool_use_id", + } + } + ], + "role": "user", + } + + # And that it continued to the LLM call + assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"} diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2d9af1741..72c63e897 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,6 +1,6 @@ import concurrent import unittest.mock -from unittest.mock import MagicMock, call, patch +from unittest.mock import ANY, MagicMock, call, patch import pytest @@ -18,6 +18,7 @@ from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry +from strands.types._events import EventLoopStopEvent from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -25,6 +26,7 @@ ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -744,6 +746,8 @@ async def test_event_loop_cycle_with_parent_span( async def test_request_state_initialization(alist): # Create a mock agent mock_agent = MagicMock() + # not setting this to False results in endless recursion + mock_agent._interrupt_state.activated = False mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) # Call without providing request_state @@ -1011,3 +1015,52 @@ def interrupt_callback(event): "interrupts": {}, } assert tru_state == exp_state + + +@pytest.mark.asyncio +async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): + model.stream = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_use_id", + "name": "invalid tool", + "input": "{}", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ).stream + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + + # ensure that we got end_turn and not tool_use + assert events[-1] == EventLoopStopEvent( + stop_reason="end_turn", + message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"}, + metrics=ANY, + request_state={}, + ) + + # Ensure that an "invalid tool name" message was added properly + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "Error: tool_name= | invalid tool name pattern"}], + "status": "error", + "toolUseId": "tool_use_id", + } + } + ], + "role": "user", + } diff --git a/tests/strands/types/__init__.py b/tests/strands/types/__init__.py deleted file mode 100644 index e69de29bb..000000000