diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 77be9d64e..b98e95a6e 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -8,17 +8,17 @@ Example Usage: ```python from strands.hooks import HookProvider, HookRegistry - from strands.hooks.events import StartRequestEvent, EndRequestEvent + from strands.hooks.events import BeforeInvocationEvent, AfterInvocationEvent class LoggingHooks(HookProvider): def register_hooks(self, registry: HookRegistry) -> None: - registry.add_callback(StartRequestEvent, self.log_start) - registry.add_callback(EndRequestEvent, self.log_end) + registry.add_callback(BeforeInvocationEvent, self.log_start) + registry.add_callback(AfterInvocationEvent, self.log_end) - def log_start(self, event: StartRequestEvent) -> None: + def log_start(self, event: BeforeInvocationEvent) -> None: print(f"Request started for {event.agent.name}") - def log_end(self, event: EndRequestEvent) -> None: + def log_end(self, event: AfterInvocationEvent) -> None: print(f"Request completed for {event.agent.name}") # Use with agent diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 2ce6d946f..8b218dfa1 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -53,6 +53,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: Type, TypeVar, Union, + cast, get_type_hints, overload, ) @@ -61,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse +from ..types._events import ToolResultEvent, ToolStreamEvent +from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -454,43 +456,67 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw # Inject special 
framework-provided parameters self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state) - # "Too few arguments" expected, hence the type ignore - if inspect.iscoroutinefunction(self._tool_func): + # Note: "Too few arguments" expected for the _tool_func calls, hence the type ignore + + # Async-generators, yield streaming events and final tool result + if inspect.isasyncgenfunction(self._tool_func): + sub_events = self._tool_func(**validated_input) # type: ignore + async for sub_event in sub_events: + yield ToolStreamEvent(tool_use, sub_event) + + # The last event is the result + yield self._wrap_tool_result(tool_use_id, sub_event) + + # Async functions, yield only the result + elif inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(**validated_input) # type: ignore - else: - result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) - # FORMAT THE RESULT for Strands Agent - if isinstance(result, dict) and "status" in result and "content" in result: - # Result is already in the expected format, just add toolUseId - result["toolUseId"] = tool_use_id - yield result + # Other functions, yield only the result else: - # Wrap any other return value in the standard format - # Always include at least one content item for consistency - yield { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": str(result)}], - } + result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) except ValueError as e: # Special handling for validation errors error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_msg}"}], + }, + ) except Exception as e: # Return 
error result with exception details for any other error error_type = type(e).__name__ error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_type} - {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_type} - {error_msg}"}], + }, + ) + + def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + # FORMAT THE RESULT for Strands Agent + if isinstance(result, dict) and "status" in result and "content" in result: + # Result is already in the expected format, just add toolUseId + result["toolUseId"] = tool_use_d + return ToolResultEvent(cast(ToolResult, result)) + else: + # Wrap any other return value in the standard format + # Always include at least one content item for consistency + return ToolResultEvent( + { + "toolUseId": tool_use_d, + "status": "success", + "content": [{"text": str(result)}], + } + ) @property def supports_hot_reload(self) -> bool: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 701a3bac0..5354991c3 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -119,7 +119,20 @@ async def _stream( return async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - yield ToolStreamEvent(tool_use, event) + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. 
+ # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last event is just the result + + if isinstance(event, ToolResultEvent): + # below the last "event" must point to the tool_result + event = event.tool_result + break + elif isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index f9c8d6061..f15bb1718 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -11,6 +11,7 @@ from mcp.types import Tool as MCPTool from typing_extensions import override +from ...types._events import ToolResultEvent from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse if TYPE_CHECKING: @@ -96,4 +97,4 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw name=self.tool_name, arguments=tool_use["input"], ) - yield result + yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 5fafed5dc..66eda08ae 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -9,7 +9,7 @@ from mcp.shared.message import SessionMessage from typing_extensions import NotRequired -from strands.types.tools import ToolResult +from ...types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. 
This abstracts diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index fd395ae77..471472a64 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -190,6 +190,13 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + # Check duplicate tool name, throw on duplicate tool names except if hot_reloading is enabled + if tool.tool_name in self.registry and not tool.supports_hot_reload: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists. Cannot register tools with exact same name." + ) + + # Check for normalized name conflicts (- vs _) if self.registry.get(tool.tool_name) is None: normalized_name = tool.tool_name.replace("-", "_") diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 465063095..9e1c0e608 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -12,6 +12,7 @@ from typing_extensions import override +from ..types._events import ToolResultEvent from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -211,7 +212,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ if inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(tool_use, **invocation_state) + yield ToolResultEvent(result) else: result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state) - - yield result + yield ToolResultEvent(result) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1a7f48d4b..5f316686c 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -275,25 +275,20 @@ def is_callback_event(self) -> bool: class ToolStreamEvent(TypedEvent): """Event emitted when a tool yields sub-events as part of tool execution.""" - def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None: + def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: """Initialize with 
tool streaming data. Args: tool_use: The tool invocation producing the stream - tool_sub_event: The yielded event from the tool execution + tool_stream_data: The yielded event from the tool execution """ - super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event}) + super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_data": tool_stream_data}) @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId")) - @property - @override - def is_callback_event(self) -> bool: - return False - class ModelMessageEvent(TypedEvent): """Event emitted when the model invocation has completed. diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 04b832259..231872183 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -260,18 +260,18 @@ async def test_stream_e2e_success(alist): "role": "assistant", } }, + { + "tool_stream_data": {"tool_streaming": True}, + "tool_stream_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + }, + { + "tool_stream_data": "Final result", + "tool_stream_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + }, { "message": { "content": [ - { - "toolResult": { - # TODO update this text when we get tool streaming implemented; right now this - # TODO is of the form '' - "content": [{"text": ANY}], - "status": "success", - "toolUseId": "12345", - } - }, + {"toolResult": {"content": [{"text": "Final result"}], "status": "success", "toolUseId": "12345"}} ], "role": "user", } diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 140537add..f7fc64b25 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,7 +1,7 @@ 
import pytest from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -22,13 +22,11 @@ async def test_concurrent_executor_execute( tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 56caa950a..903a11e5a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -1,4 +1,5 @@ import unittest.mock +from unittest.mock import MagicMock import pytest @@ -39,7 +40,6 @@ async def test_executor_stream_yields_result( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events @@ -67,6 +67,76 @@ async def test_executor_stream_yields_result( assert tru_hook_events == exp_hook_events +@pytest.mark.asyncio +async def test_executor_stream_wraps_results( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, 
alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + weather_tool.stream.return_value = agenerator( + ["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}] + ) + + tru_events = await alist(stream) + exp_events = [ + ToolStreamEvent(tool_use, "value 1"), + ToolStreamEvent(tool_use, {"nested": True}), + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_executor_stream_passes_through_typed_events( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + event_1 = ToolStreamEvent(tool_use, "value 1") + event_2 = ToolStreamEvent(tool_use, {"nested": True}) + event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}) + weather_tool.stream.return_value = agenerator( + [ + event_1, + event_2, + event_3, + ] + ) + + tru_events = await alist(stream) + assert tru_events[0] is event_1 + assert tru_events[1] is event_2 + + # ToolResults are not passed through directly, they're unwrapped then wrapped again + assert tru_events[2] == event_3 + + +@pytest.mark.asyncio +async def test_executor_stream_wraps_stream_events_if_no_result( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) 
+ + weather_tool.stream = MagicMock() + last_event = ToolStreamEvent(tool_use, "value 1") + # Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent + weather_tool.stream.return_value = agenerator( + [ + last_event, + ] + ) + + tru_events = await alist(stream) + exp_events = [last_event, ToolResultEvent(last_event)] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_executor_stream_yields_tool_error( executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist @@ -129,7 +199,6 @@ async def test_executor_stream_with_trace( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index d4e98223e..37e098142 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent @pytest.fixture @@ -21,13 +21,11 @@ async def test_sequential_executor_execute( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + 
exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 874006683..1c025f5f2 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -4,6 +4,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPAgentTool, MCPClient +from strands.types._events import ToolResultEvent @pytest.fixture @@ -62,7 +63,7 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} tru_events = await alist(mcp_agent_tool.stream(tool_use, {})) - exp_events = [mock_mcp_client.call_tool_async.return_value] + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 02e7eb445..5b4b5cdda 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,13 +3,14 @@ """ from asyncio import Queue -from typing import Any, Dict, Optional, Union +from typing import Any, AsyncGenerator, Dict, Optional, Union from unittest.mock import MagicMock import pytest import strands from strands import Agent +from strands.types._events import ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -117,7 +118,7 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"toolUseId": "t1", "input": {"a": 2}}, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]}] + exp_events = [ToolResultEvent({"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]})] assert tru_events == exp_events @@ -131,7 +132,9 @@ def 
identity(a: int, agent: dict = None): stream = identity.stream({"input": {"a": 2}}, {"agent": {"state": 1}}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ] assert tru_events == exp_events @@ -180,7 +183,9 @@ def test_tool(param1: str, param2: int) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] assert tru_events == exp_events # Make sure these are set properly @@ -229,7 +234,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}) + ] assert tru_events == exp_events # Test with both params @@ -237,7 +244,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] @pytest.mark.asyncio @@ -256,8 +265,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "validation error for test_tooltool\nrequired\n" in result["content"][0]["text"].lower(), ( + assert 
result["tool_result"]["status"] == "error" + assert "validation error for test_tooltool\nrequired\n" in result["tool_result"]["content"][0]["text"].lower(), ( "Validation error should indicate which argument is missing" ) @@ -266,8 +275,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test error" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "test error" in result["tool_result"]["content"][0]["text"].lower(), ( "Runtime error should contain the original error message" ) @@ -313,14 +322,14 @@ def test_tool(param: str, agent=None) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "Param: test" + assert result["tool_result"]["content"][0]["text"] == "Param: test" # Test with agent stream = test_tool.stream(tool_use, {"agent": mock_agent}) result = (await alist(stream))[-1] - assert "Agent:" in result["content"][0]["text"] - assert "test" in result["content"][0]["text"] + assert "Agent:" in result["tool_result"]["content"][0]["text"] + assert "test" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -350,23 +359,23 @@ def none_return_tool(param: str) -> None: stream = dict_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" - assert result["toolUseId"] == "test-id" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" + assert result["tool_result"]["toolUseId"] == "test-id" # Test the string return - should wrap in standard format stream = string_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" + assert 
result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" # Test None return - should still create valid ToolResult with "None" text stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -403,7 +412,7 @@ def test_method(self, param: str) -> str: stream = instance.test_method.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "Test: tool-value" in result["content"][0]["text"] + assert "Test: tool-value" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -422,7 +431,9 @@ class MyThing: ... stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -444,7 +455,9 @@ def test_method(param: str) -> str: stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -474,14 +487,14 @@ def tool_with_defaults(required: str, optional: str = "default", number: int = 4 stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 42" + assert result["tool_result"]["content"][0]["text"] == "hello default 42" # Call 
with some but not all optional parameters tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "number": 100}} stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 100" + assert result["tool_result"]["content"][0]["text"] == "hello default 100" @pytest.mark.asyncio @@ -496,14 +509,15 @@ def test_tool(required: str) -> str: # Test with completely empty tool use stream = test_tool.stream({}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "unknown" in result["toolUseId"] + print(result) + assert result["tool_result"]["status"] == "error" + assert "unknown" in result["tool_result"]["toolUseId"] # Test with missing input stream = test_tool.stream({"toolUseId": "test-id"}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test-id" in result["toolUseId"] + assert result["tool_result"]["status"] == "error" + assert "test-id" in result["tool_result"]["toolUseId"] @pytest.mark.asyncio @@ -529,8 +543,8 @@ def add_numbers(a: int, b: int) -> int: stream = add_numbers.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "5" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "5" @pytest.mark.asyncio @@ -565,8 +579,8 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "hello, default_str, 42, True, 3.14" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "hello, default_str, 42, True, 3.14" in result["tool_result"]["content"][0]["text"] # Test calling with some optional parameters tool_use = { @@ -576,7 +590,7 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - 
assert "hello, default_str, 100, True, 2.718" in result["content"][0]["text"] + assert "hello, default_str, 100, True, 2.718" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -603,8 +617,8 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "42" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "42" # Test with return that doesn't match declared type # Note: This should still work because Python doesn't enforce return types at runtime @@ -613,16 +627,16 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "not an int" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "not an int" # Test with None return from a non-None return type tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Define tool with Union return type @strands.tool @@ -644,22 +658,25 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "{'key': 'value'}" in result["content"][0]["text"] or '{"key": "value"}' in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert ( + "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] + or '{"key": "value"}' in 
result["tool_result"]["content"][0]["text"] + ) tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "string result" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "string result" tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -682,8 +699,8 @@ def no_params_tool() -> str: stream = no_params_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Success - no parameters needed" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Success - no parameters needed" # Test direct call direct_result = no_params_tool() @@ -711,8 +728,8 @@ def complex_type_tool(config: Dict[str, Any]) -> str: stream = complex_type_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "Got config with 3 keys" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "Got config with 3 keys" in result["tool_result"]["content"][0]["text"] # Direct call direct_result = complex_type_tool(nested_dict) @@ -742,12 +759,12 @@ def custom_result_tool(param: str) -> Dict[str, Any]: # The wrapper should preserve our format and just add the toolUseId result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["toolUseId"] == "custom-id" - assert len(result["content"]) == 2 - assert 
result["content"][0]["text"] == "First line: test" - assert result["content"][1]["text"] == "Second line" - assert result["content"][1]["type"] == "markdown" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["toolUseId"] == "custom-id" + assert len(result["tool_result"]["content"]) == 2 + assert result["tool_result"]["content"][0]["text"] == "First line: test" + assert result["tool_result"]["content"][1]["text"] == "Second line" + assert result["tool_result"]["content"][1]["type"] == "markdown" def test_docstring_parsing(): @@ -816,8 +833,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] # Test missing required parameter tool_use = { @@ -831,8 +848,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -855,16 +872,16 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Test with empty dict tool_use = {"toolUseId": "test-id", "input": {"param": {}}} stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert 
result["content"][0]["text"] == "{}" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "{}" # Test with a complex nested dictionary nested_dict = {"key1": {"nested": [1, 2, 3]}, "key2": None} @@ -872,9 +889,9 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "key1" in result["content"][0]["text"] - assert "nested" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "key1" in result["tool_result"]["content"][0]["text"] + assert "nested" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -922,8 +939,8 @@ def test_method(self): stream = instance.test_method.stream({"toolUseId": "test-id", "input": {"param": "direct"}}, {}) direct_result = (await alist(stream))[-1] - assert direct_result["status"] == "success" - assert direct_result["content"][0]["text"] == "Method Got: direct" + assert direct_result["tool_result"]["status"] == "success" + assert direct_result["tool_result"]["content"][0]["text"] == "Method Got: direct" # Create a standalone function to test regular function calls @strands.tool @@ -944,8 +961,8 @@ def standalone_tool(p1: str, p2: str = "default") -> str: stream = standalone_tool.stream({"toolUseId": "test-id", "input": {"p1": "value1"}}, {}) tool_use_result = (await alist(stream))[-1] - assert tool_use_result["status"] == "success" - assert tool_use_result["content"][0]["text"] == "Standalone: value1, default" + assert tool_use_result["tool_result"]["status"] == "success" + assert tool_use_result["tool_result"]["content"][0]["text"] == "Standalone: value1, default" @pytest.mark.asyncio @@ -976,9 +993,9 @@ def failing_tool(param: str) -> str: stream = failing_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" + assert 
result["tool_result"]["status"] == "error" - error_message = result["content"][0]["text"] + error_message = result["tool_result"]["content"][0]["text"] # Check that error type is included if error_type == "value_error": @@ -1011,33 +1028,33 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "list: [1, 2, 3]" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "list: [1, 2, 3]" in result["tool_result"]["content"][0]["text"] # Test with a dict tool_use = {"toolUseId": "test-id", "input": {"union_param": {"key": "value"}}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "dict:" in result["content"][0]["text"] - assert "key" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "dict:" in result["tool_result"]["content"][0]["text"] + assert "key" in result["tool_result"]["content"][0]["text"] # Test with a string tool_use = {"toolUseId": "test-id", "input": {"union_param": "test_string"}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "str: test_string" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "str: test_string" in result["tool_result"]["content"][0]["text"] # Test with None tool_use = {"toolUseId": "test-id", "input": {"union_param": None}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "NoneType: None" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "NoneType: None" in result["tool_result"]["content"][0]["text"] async def _run_context_injection_test(context_tool: 
AgentTool, additional_context=None): @@ -1061,15 +1078,17 @@ async def _run_context_injection_test(context_tool: AgentTool, additional_contex assert len(tool_results) == 1 tool_result = tool_results[0] - assert tool_result == { - "status": "success", - "content": [ - {"text": "Tool 'context_tool' (ID: test-id)"}, - {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"}, - ], - "toolUseId": "test-id", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [ + {"text": "Tool 'context_tool' (ID: test-id)"}, + {"text": "injected agent 'test_agent' processed: some_message"}, + {"text": "context agent 'test_agent'"}, + ], + "toolUseId": "test-id", + } + ) @pytest.mark.asyncio @@ -1164,9 +1183,9 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: tool_result = tool_results[0] # Should get a validation error because tool_context is required but not provided - assert tool_result["status"] == "error" - assert "tool_context" in tool_result["content"][0]["text"].lower() - assert "validation" in tool_result["content"][0]["text"].lower() + assert tool_result["tool_result"]["status"] == "error" + assert "tool_context" in tool_result["tool_result"]["content"][0]["text"].lower() + assert "validation" in tool_result["tool_result"]["content"][0]["text"].lower() @pytest.mark.asyncio @@ -1196,8 +1215,151 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_result = tool_results[0] # Should succeed with the string parameter - assert tool_result == { - "status": "success", - "content": [{"text": "success"}], + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [{"text": "success"}], + "toolUseId": "test-id-2", + } + ) + + +@pytest.mark.asyncio +async def test_tool_async_generator(): + """Test that async generators yield results appropriately.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + 
"""Tool that expects tool_context as a regular string parameter.""" + yield 0 + yield "Value 1" + yield {"nested": "value"} + yield { + "status": "success", + "content": [{"text": "Looks like tool result"}], + "toolUseId": "test-id-2", + } + yield "final result" + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 0), + ToolStreamEvent(tool_use, "Value 1"), + ToolStreamEvent(tool_use, {"nested": "value"}), + ToolStreamEvent( + tool_use, + { + "status": "success", + "content": [{"text": "Looks like tool result"}], + "toolUseId": "test-id-2", + }, + ), + ToolStreamEvent(tool_use, "final result"), + ToolResultEvent( + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results + + +@pytest.mark.asyncio +async def test_tool_async_generator_exceptions_result_in_error(): + """Test that async generators handle exceptions.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 13 + raise ValueError("It's an error!") + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 13), + ToolResultEvent( + { + "status": "error", + "content": [{"text": "Error: It's an error!"}], + 
"toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results + + +@pytest.mark.asyncio +async def test_tool_async_generator_yield_object_result(): + """Test that a tool-result dict yielded last by an async generator becomes the final result.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 13 + yield { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + + tool: AgentTool = async_generator + tool_use: ToolUse = { "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 13), + ToolStreamEvent( + tool_use, + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + }, + ), + ToolResultEvent( + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 66494c987..ca3cded4c 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -120,3 +120,39 @@ def function() -> str: "tool_f", ] assert tru_tool_names == exp_tool_names + + +def test_register_tool_duplicate_name_without_hot_reload(): + """Test that registering a tool with duplicate name raises ValueError when hot reload is not supported.""" + tool_1 = PythonAgentTool(tool_name="duplicate_tool", tool_spec=MagicMock(), tool_func=lambda: None) + tool_2 = PythonAgentTool(tool_name="duplicate_tool", tool_spec=MagicMock(), tool_func=lambda: None) + + tool_registry = ToolRegistry() + tool_registry.register_tool(tool_1) + + with pytest.raises( +
ValueError, match="Tool name 'duplicate_tool' already exists. Cannot register tools with exact same name." + ): + tool_registry.register_tool(tool_2) + + +def test_register_tool_duplicate_name_with_hot_reload(): + """Test that registering a tool with duplicate name succeeds when hot reload is supported.""" + # Create mock tools with hot reload support + tool_1 = MagicMock(spec=PythonAgentTool) + tool_1.tool_name = "hot_reload_tool" + tool_1.supports_hot_reload = True + tool_1.is_dynamic = False + + tool_2 = MagicMock(spec=PythonAgentTool) + tool_2.tool_name = "hot_reload_tool" + tool_2.supports_hot_reload = True + tool_2.is_dynamic = False + + tool_registry = ToolRegistry() + tool_registry.register_tool(tool_1) + + tool_registry.register_tool(tool_2) + + # Verify the second tool replaced the first + assert tool_registry.registry["hot_reload_tool"] == tool_2 diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 240c24717..b305a1a90 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -9,6 +9,7 @@ validate_tool_use, validate_tool_use_name, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -506,5 +507,5 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"tool_use": 1}, {"a": 2}) tru_events = await alist(stream) - exp_events = [({"tool_use": 1}, 2)] + exp_events = [ToolResultEvent(({"tool_use": 1}, 2))] assert tru_events == exp_events diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 4683918cb..e25bf3cca 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -138,9 +138,25 @@ def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processi response1 = agent("Say the word.") response2 = agent("Hello!") assert response1.stop_reason == "guardrail_intervened" - assert BLOCKED_OUTPUT in str(response1) - assert 
response2.stop_reason != "guardrail_intervened" - assert BLOCKED_OUTPUT not in str(response2) + + """ + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. + This means the response may be returned before Guardrails has finished processing. + As a result, we cannot guarantee that BLOCKED_OUTPUT is in the response + """ + if processing_mode == "sync": + assert BLOCKED_OUTPUT in str(response1) + assert response2.stop_reason != "guardrail_intervened" + assert BLOCKED_OUTPUT not in str(response2) + else: + cactus_returned_in_response1_blocked_by_input_guardrail = BLOCKED_INPUT in str(response2) + cactus_blocked_in_response1_allows_next_response = ( + BLOCKED_OUTPUT not in str(response2) and response2.stop_reason != "guardrail_intervened" + ) + assert ( + cactus_returned_in_response1_blocked_by_input_guardrail or cactus_blocked_in_response1_allows_next_response + ) @pytest.mark.parametrize("processing_mode", ["sync", "async"]) @@ -164,10 +180,27 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi response1 = agent("Say the word.") response2 = agent("Hello!") + assert response1.stop_reason == "guardrail_intervened" - assert REDACT_MESSAGE in str(response1) - assert response2.stop_reason != "guardrail_intervened" - assert REDACT_MESSAGE not in str(response2) + + """ + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. + This means the response may be returned before Guardrails has finished processing.
+ As a result, we cannot guarantee that the REDACT_MESSAGE is in the response + """ + if processing_mode == "sync": + assert REDACT_MESSAGE in str(response1) + assert response2.stop_reason != "guardrail_intervened" + assert REDACT_MESSAGE not in str(response2) + else: + cactus_returned_in_response1_blocked_by_input_guardrail = BLOCKED_INPUT in str(response2) + cactus_blocked_in_response1_allows_next_response = ( + REDACT_MESSAGE not in str(response2) and response2.stop_reason != "guardrail_intervened" + ) + assert ( + cactus_returned_in_response1_blocked_by_input_guardrail or cactus_blocked_in_response1_allows_next_response + ) def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir):