Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,6 @@ async def _handle_tool_execution(

validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
if not tool_uses:
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
return

if agent._interrupt_state.activated:
tool_results.extend(agent._interrupt_state.context["tool_results"])
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/mocked_model_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union
from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union

from pydantic import BaseModel

Expand All @@ -25,8 +25,8 @@ class MockedModelProvider(Model):
to stream mock responses as events.
"""

def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]):
self.agent_responses = agent_responses
def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]):
self.agent_responses = [*agent_responses]
self.index = 0

def format_chunk(self, event: Any) -> StreamEvent:
Expand Down
47 changes: 47 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,3 +2065,50 @@ def test_agent_tool_caller_interrupt(user):
exp_message = r"cannot directly call tool during interrupt"
with pytest.raises(RuntimeError, match=exp_message):
agent.tool.test_tool()


def test_agent__call__invalid_tool_name():
@strands.tool
def shell(command: str):
pass

model = MockedModelProvider(
[
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tool_use_id",
"name": "invalid tool",
"input": "{}",
}
}
],
},
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
]
)

agent = Agent(tools=[shell], model=model)
result = agent("Test")

# Ensure the stop_reason is
assert result.stop_reason == "end_turn"

# Assert that there exists a message with a toolResponse
assert agent.messages[-2] == {
"content": [
{
"toolResult": {
"content": [{"text": "Error: tool_name=<invalid tool> | invalid tool name pattern"}],
"status": "error",
"toolUseId": "tool_use_id",
}
}
],
"role": "user",
}

# And that it continued to the LLM call
assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"}
55 changes: 54 additions & 1 deletion tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import concurrent
import unittest.mock
from unittest.mock import MagicMock, call, patch
from unittest.mock import ANY, MagicMock, call, patch

import pytest

Expand All @@ -18,13 +18,15 @@
from strands.telemetry.metrics import EventLoopMetrics
from strands.tools.executors import SequentialToolExecutor
from strands.tools.registry import ToolRegistry
from strands.types._events import EventLoopStopEvent
from strands.types.exceptions import (
ContextWindowOverflowException,
EventLoopException,
MaxTokensReachedException,
ModelThrottledException,
)
from tests.fixtures.mock_hook_provider import MockHookProvider
from tests.fixtures.mocked_model_provider import MockedModelProvider


@pytest.fixture
Expand Down Expand Up @@ -744,6 +746,8 @@ async def test_event_loop_cycle_with_parent_span(
async def test_request_state_initialization(alist):
# Create a mock agent
mock_agent = MagicMock()
# not setting this to False results in endless recursion
mock_agent._interrupt_state.activated = False
mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock())

# Call without providing request_state
Expand Down Expand Up @@ -1011,3 +1015,52 @@ def interrupt_callback(event):
"interrupts": {},
}
assert tru_state == exp_state


@pytest.mark.asyncio
async def test_invalid_tool_names_adds_tool_uses(agent, model, alist):
model.stream = MockedModelProvider(
[
{
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tool_use_id",
"name": "invalid tool",
"input": "{}",
}
}
],
},
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
]
).stream

stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
)
events = await alist(stream)

# ensure that we got end_turn and not tool_use
assert events[-1] == EventLoopStopEvent(
stop_reason="end_turn",
message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"},
metrics=ANY,
request_state={},
)

# Ensure that an "invalid tool name" message was added properly
assert agent.messages[-2] == {
"content": [
{
"toolResult": {
"content": [{"text": "Error: tool_name=<invalid tool> | invalid tool name pattern"}],
"status": "error",
"toolUseId": "tool_use_id",
}
}
],
"role": "user",
}
Empty file removed tests/strands/types/__init__.py
Empty file.