diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fd857707c..424e9a39e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -34,7 +34,7 @@ from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.models import Model -from ..types.tools import ToolConfig +from ..types.tools import ToolConfig, ToolResult from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -182,7 +182,7 @@ def caller(**kwargs: Any) -> Any: } # Execute the tool - tool_result = self._agent.tool_handler.process( + events = self._agent.tool_handler.process( tool=tool_use, model=self._agent.model, system_prompt=self._agent.system_prompt, @@ -194,6 +194,7 @@ def caller(**kwargs: Any) -> Any: agent=self._agent, **handler_kwargs, ) + tool_result = list(events)[-1] if record_direct_tool_call: # Create a record of this tool execution in the message history @@ -576,7 +577,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs def _record_tool_execution( self, tool: Dict[str, Any], - tool_result: Dict[str, Any], + tool_result: ToolResult, user_message_override: Optional[str], messages: List[Dict[str, Any]], ) -> None: diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 96e3637f0..b3b4a5444 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -12,7 +12,7 @@ import time import uuid from functools import partial -from typing import Any, Callable, Generator, Optional, cast +from typing import Any, Callable, Generator, Optional from ..telemetry.metrics import EventLoopMetrics, Trace from ..telemetry.tracer import get_tracer @@ -352,11 +352,10 @@ def _handle_tool_execution( **kwargs, ) - run_tools( + yield from run_tools( handler=tool_handler_process, tool_uses=tool_uses, event_loop_metrics=event_loop_metrics, - request_state=cast(Any, kwargs["request_state"]), invalid_tool_use_ids=invalid_tool_use_ids, tool_results=tool_results, cycle_trace=cycle_trace, diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py index bc4ec1ce9..9e76e5516 100644 --- a/src/strands/handlers/tool_handler.py +++ b/src/strands/handlers/tool_handler.py @@ -1,7 +1,7 @@ """This module provides handlers for managing tool invocations.""" import logging -from typing import Any, List, Optional +from typing import Any, Generator, Optional, Union from ..tools.registry import ToolRegistry from ..types.models import Model @@ -49,11 +49,11 @@ def process( *, model: Model, system_prompt: Optional[str], - messages: List[Any], + messages: list[Any], tool_config: Any, callback_handler: Any, **kwargs: Any, - ) -> Any: + ) -> Generator[Union[ToolResult, Any], None, None]: """Process a tool invocation. Looks up the tool in the registry and invokes it with the provided parameters. @@ -67,10 +67,10 @@ def process( callback_handler: Callback for processing events as they happen. **kwargs: Additional keyword arguments passed to the tool. - Returns: - The result of the tool invocation, or an error response if the tool fails or is not found. + Yields: + Events of the tool invocation. The final event is always the tool result. """ - logger.debug("tool=<%s> | invoking", tool) + logger.debug("tool=<%s> | streaming", tool) tool_use_id = tool["toolUseId"] tool_name = tool["name"] @@ -86,11 +86,13 @@ def process( tool_name, list(self.tool_registry.registry.keys()), ) - return { + yield { "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } + return + # Add standard arguments to kwargs for Python tools kwargs.update( { @@ -102,11 +104,11 @@ def process( } ) - return tool_func.invoke(tool, **kwargs) + yield from tool_func.stream(tool, **kwargs) except Exception as e: logger.exception("tool_name=<%s> | failed to process tool", tool_name) - return { + yield { "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}], diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 3d579c3a0..f536be2d0 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -46,7 +46,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: from typing import ( Any, Callable, - Dict, + Generator, Generic, Optional, ParamSpec, @@ -119,7 +119,7 @@ def _create_input_model(self) -> Type[BaseModel]: Returns: A Pydantic BaseModel class customized for the function's parameters. """ - field_definitions: Dict[str, Any] = {} + field_definitions: dict[str, Any] = {} for name, param in self.signature.parameters.items(): # Skip special parameters @@ -179,7 +179,7 @@ def extract_metadata(self) -> ToolSpec: return tool_spec - def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None: + def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: """Clean up Pydantic schema to match Strands' expected format. Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could @@ -227,7 +227,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None: if key in prop_schema: del prop_schema[key] - def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]: + def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: """Validate input data using the Pydantic model. This method ensures that the input data meets the expected schema before it's passed to the actual function. It @@ -353,12 +353,15 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: # This block is only for backwards compatability so we cast as any for now logger.warning( "issue=<%s> | " - "passing tool use into a function instead of using .invoke will be removed in a future release", + "passing tool use into a function instead of using .stream will be removed in a future release", "https://github.com/strands-agents/sdk-python/pull/258", ) tool_use = cast(Any, args[0]) - return cast(R, self.invoke(tool_use, **kwargs)) + events = self.stream(tool_use, **kwargs) + result = list(events)[-1] + + return cast(R, result) return self.original_function(*args, **kwargs) @@ -389,7 +392,8 @@ def tool_type(self) -> str: """ return "function" - def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + @override + def stream(self, tool: ToolUse, *args: Any, **kwargs: Any) -> Generator[Union[ToolResult, Any], None, None]: """Invoke the tool with a tool use specification. This method handles tool use invocations from a Strands Agent. It validates the input, @@ -408,8 +412,8 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes *args: Additional positional arguments (not typically used). **kwargs: Additional keyword arguments, may include 'agent' reference. - Returns: - A standardized tool result dictionary with status and content. + Yields: + Events of the tool invocation. The final event is always the tool result. """ # This is a tool use call - process accordingly tool_use = tool @@ -424,27 +428,35 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes if "agent" in kwargs and "agent" in self._metadata.signature.parameters: validated_input["agent"] = kwargs.get("agent") + # User will need to piece together a tool result themselves if using a generator + if inspect.isgeneratorfunction(self.original_function): + validated_input["tool_use_id"] = tool_use_id + # We get "too few arguments here" but because that's because fof the way we're calling it result = self.original_function(**validated_input) # type: ignore + if inspect.isgenerator(result): + yield from result + return # FORMAT THE RESULT for Strands Agent if isinstance(result, dict) and "status" in result and "content" in result: # Result is already in the expected format, just add toolUseId result["toolUseId"] = tool_use_id - return cast(ToolResult, result) - else: - # Wrap any other return value in the standard format - # Always include at least one content item for consistency - return { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": str(result)}], - } + yield result + return + + # Wrap any other return value in the standard format + # Always include at least one content item for consistency + yield { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } except ValueError as e: # Special handling for validation errors error_msg = str(e) - return { + yield { "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error_msg}"}], @@ -453,7 +465,7 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes # Return error result with exception details for any other error error_type = type(e).__name__ error_msg = str(e) - return { + yield { "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error_type} - {error_msg}"}], diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index c90202393..2be7dadb3 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -1,9 +1,10 @@ """Tool execution functionality for the event loop.""" import logging +import queue +import threading import time -from concurrent.futures import TimeoutError -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Generator, Optional, Union from opentelemetry import trace @@ -18,128 +19,108 @@ def run_tools( - handler: Callable[[ToolUse], ToolResult], - tool_uses: List[ToolUse], + handler: Callable[[ToolUse], Generator[Union[ToolResult, Any], None, None]], + tool_uses: list[ToolUse], event_loop_metrics: EventLoopMetrics, - request_state: Any, - invalid_tool_use_ids: List[str], - tool_results: List[ToolResult], + invalid_tool_use_ids: list[str], + tool_results: list[ToolResult], cycle_trace: Trace, parent_span: Optional[trace.Span] = None, parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None, -) -> bool: +) -> Generator[dict[str, Any], None, None]: """Execute tools either in parallel or sequentially. Args: handler: Tool handler processing function. tool_uses: List of tool uses to execute. event_loop_metrics: Metrics collection object. - request_state: Current request state. invalid_tool_use_ids: List of invalid tool use IDs. tool_results: List to populate with tool results. cycle_trace: Parent trace for the current cycle. parent_span: Parent span for the current cycle. parallel_tool_executor: Optional executor for parallel processing. - Returns: - bool: True if any tool failed, False otherwise. + Yields: + . """ - def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]: - result = None - tool_succeeded = False - + def handle(tool: ToolUse) -> Generator[dict[str, Any], None, None]: tracer = get_tracer() tool_call_span = tracer.start_tool_call_span(tool, parent_span) - try: - if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids: - tool_name = tool["name"] - tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) - tool_start_time = time.time() - result = handler(tool) - tool_success = result.get("status") == "success" - if tool_success: - tool_succeeded = True - - tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) - cycle_trace.add_child(tool_trace) - - if tool_call_span: - tracer.end_tool_call_span(tool_call_span, result) - except Exception as e: - if tool_call_span: - tracer.end_span_with_error(tool_call_span, str(e), e) - - return tool_succeeded, result - - any_tool_failed = False + tool_name = tool["name"] + tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) + tool_start_time = time.time() + + for event in handler(tool): + yield {"callback": event} + + result = event + + tool_success = result.get("status") == "success" + tool_duration = time.time() - tool_start_time + message = Message(role="user", content=[{"toolResult": result}]) + event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) + cycle_trace.add_child(tool_trace) + + if tool_call_span: + tracer.end_tool_call_span(tool_call_span, result) + + def work( + tool: ToolUse, + worker_id: int, + worker_queue: queue.Queue, + worker_event: threading.Event, + worker_lock: threading.Lock, + ) -> None: + for event in handle(tool): + worker_queue.put((worker_id, event)) + worker_event.wait() + + with worker_lock: + tool_results.append(event["callback"]) + + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] + if parallel_tool_executor: logger.debug( "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel", len(tool_uses), type(parallel_tool_executor).__name__, ) - # Submit all tasks with their associated tools - future_to_tool = { - parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses - } + + worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() + worker_events = [threading.Event() for _ in range(len(tool_uses))] + worker_lock = threading.Lock() + + workers = [ + parallel_tool_executor.submit( + work, tool_use, worker_id, worker_queue, worker_events[worker_id], worker_lock + ) + for worker_id, tool_use in enumerate(tool_uses) + ] logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses)) - # Collect results truly in parallel using the provided executor's as_completed method - completed_results = [] - try: - for future in parallel_tool_executor.as_completed(future_to_tool): - try: - succeeded, result = future.result() - if result is not None: - completed_results.append(result) - if not succeeded: - any_tool_failed = True - except Exception as e: - tool = future_to_tool[future] - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e) - any_tool_failed = True - except TimeoutError: - logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout) - # Process any completed tasks - for future in future_to_tool: - if future.done(): # type: ignore - try: - succeeded, result = future.result(timeout=0) - if result is not None: - completed_results.append(result) - except Exception as tool_e: - tool = future_to_tool[future] - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e) - else: - # This future didn't complete within the timeout - tool = future_to_tool[future] - logger.debug("tool_name=<%s> | tool execution timed out", tool["name"]) - - any_tool_failed = True - - # Add completed results to tool_results - tool_results.extend(completed_results) + while not all(worker.done() for worker in workers): + if not worker_queue.empty(): + worker_id, event = worker_queue.get() + yield event + worker_events[worker_id].set() + else: # Sequential execution fallback for tool_use in tool_uses: - succeeded, result = _handle_tool_execution(tool_use) - if result is not None: - tool_results.append(result) - if not succeeded: - any_tool_failed = True + for event in handle(tool_use): + yield event - return any_tool_failed + tool_results.append(event["callback"]) def validate_and_prepare_tools( message: Message, - tool_uses: List[ToolUse], - tool_results: List[ToolResult], - invalid_tool_use_ids: List[str], + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + invalid_tool_use_ids: list[str], ) -> None: """Validate tool uses and prepare them for execution. diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 1b3cfddbc..fd47cf2a7 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -191,7 +191,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: if not callable(tool_func): raise TypeError(f"Tool {tool_name} function is not callable") - return PythonAgentTool(tool_name, tool_spec, callback=tool_func) + return PythonAgentTool(tool_name, tool_spec, tool_func) except Exception: logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index e24c30b48..5fde8719b 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -6,9 +6,10 @@ """ import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generator, Union from mcp.types import Tool as MCPTool +from typing_extensions import override from ...types.tools import AgentTool, ToolResult, ToolSpec, ToolUse @@ -73,13 +74,14 @@ def tool_type(self) -> str: """ return "python" - def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + @override + def stream(self, tool: ToolUse, *args: Any, **kwargs: Any) -> Generator[Union[ToolResult, Any], None, None]: """Invoke the MCP tool. This method delegates the tool invocation to the MCP server connection, passing the tool use ID, tool name, and input arguments. """ logger.debug("invoking MCP tool '%s' with tool_use_id=%s", self.tool_name, tool["toolUseId"]) - return self.mcp_client.call_tool_sync( + yield self.mcp_client.call_tool_sync( tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"] ) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index acaf6e368..4ded2e766 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -349,7 +349,7 @@ def reload_tool(self, tool_name: str) -> None: new_tool = PythonAgentTool( tool_name=tool_name, tool_spec=module.TOOL_SPEC, - callback=tool_function, + tool_func=tool_function, ) # Register the tool @@ -433,7 +433,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: tool = PythonAgentTool( tool_name=tool_name, tool_spec=tool_spec, - callback=tool_function, + tool_func=tool_function, ) self.register_tool(tool) successful_loads += 1 @@ -465,7 +465,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: tool = PythonAgentTool( tool_name=tool_name, tool_spec=tool_spec, - callback=tool_function, + tool_func=tool_function, ) self.register_tool(tool) successful_loads += 1 diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 8dde1d09e..09e347b8d 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -7,11 +7,11 @@ import inspect import logging import re -from typing import Any, Callable, Dict, Optional, cast +from typing import Any, Generator, Optional, Union, cast -from typing_extensions import Unpack +from typing_extensions import override -from ..types.tools import AgentTool, ToolResult, ToolSpec, ToolUse +from ..types.tools import AgentTool, ToolFunc, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -63,7 +63,7 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) -def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: +def _normalize_property(prop_name: str, prop_def: Any) -> dict[str, Any]: """Normalize a single property definition. Args: @@ -86,7 +86,7 @@ def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: return normalized_prop -def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: +def normalize_schema(schema: dict[str, Any]) -> dict[str, Any]: """Normalize a JSON schema to match expectations. This function recursively processes nested objects to preserve the complete schema structure. @@ -152,11 +152,11 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: return super().__new__(cls) - def __init__(self, func: Callable[[ToolUse, Unpack[Any]], ToolResult], tool_name: Optional[str] = None) -> None: + def __init__(self, tool_func: ToolFunc, tool_name: Optional[str] = None) -> None: """Initialize a function-based tool. Args: - func: The decorated function. + tool_func: The decorated function. tool_name: Optional tool name (defaults to function name). Raises: @@ -164,19 +164,19 @@ def __init__(self, func: Callable[[ToolUse, Unpack[Any]], ToolResult], tool_name """ super().__init__() - self._func = func + self._tool_func = tool_func # Get TOOL_SPEC from the decorated function - if hasattr(func, "TOOL_SPEC") and isinstance(func.TOOL_SPEC, dict): - self._tool_spec = cast(ToolSpec, func.TOOL_SPEC) + if hasattr(tool_func, "TOOL_SPEC") and isinstance(tool_func.TOOL_SPEC, dict): + self._tool_spec = cast(ToolSpec, tool_func.TOOL_SPEC) # Use name from tool spec if available, otherwise use function name or passed tool_name - name = self._tool_spec.get("name", tool_name or func.__name__) + name = self._tool_spec.get("name", tool_name or tool_func.__name__) if isinstance(name, str): self._name = name else: raise ValueError(f"Tool name must be a string, got {type(name)}") else: - raise ValueError(f"Function {func.__name__} is not decorated with @tool") + raise ValueError(f"Function {tool_func.__name__} is not decorated with @tool") @property def tool_name(self) -> str: @@ -214,7 +214,8 @@ def supports_hot_reload(self) -> bool: """ return True - def invoke(self, tool: ToolUse, *args: Any, **kwargs: Any) -> ToolResult: + @override + def stream(self, tool: ToolUse, *args: Any, **kwargs: Any) -> Generator[Union[ToolResult, Any], None, None]: """Execute the function with the given tool use request. Args: @@ -222,37 +223,39 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: Any) -> ToolResult: *args: Additional positional arguments to pass to the function. **kwargs: Additional keyword arguments to pass to the function. - Returns: - A ToolResult containing the status and content from the function execution. + Yields: + Events of the tool invocation. The final event is always the tool result. """ - # Make sure to pass through all kwargs, including 'agent' if provided + # Check if the function accepts agent as a keyword argument + sig = inspect.signature(self._tool_func) + if "agent" not in sig.parameters: + # Skip passing agent if function doesn't accept it + kwargs = {k: v for k, v in kwargs.items() if k != "agent"} + try: - # Check if the function accepts agent as a keyword argument - sig = inspect.signature(self._func) - if "agent" in sig.parameters: - # Pass agent if function accepts it - return self._func(tool, **kwargs) + result = self._tool_func(tool, *args, **kwargs) + if inspect.isgenerator(result): + yield from result else: - # Skip passing agent if function doesn't accept it - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "agent"} - return self._func(tool, **filtered_kwargs) + yield result + except Exception as e: - return { + yield { "toolUseId": tool.get("toolUseId", "unknown"), "status": "error", "content": [{"text": f"Error executing function: {str(e)}"}], } @property - def original_function(self) -> Callable: + def original_function(self) -> ToolFunc: """Get the original function (without wrapper). Returns: Undecorated function. """ - if hasattr(self._func, "original_function"): - return cast(Callable, self._func.original_function) - return self._func + if hasattr(self._tool_func, "original_function"): + return cast(ToolFunc, self._tool_func.original_function) + return self._tool_func def get_display_properties(self) -> dict[str, str]: """Get properties to display in UI representations. @@ -272,25 +275,23 @@ class PythonAgentTool(AgentTool): as SDK tools. """ - _callback: Callable[[ToolUse, Any, dict[str, Any]], ToolResult] + _tool_func: ToolFunc _tool_name: str _tool_spec: ToolSpec - def __init__( - self, tool_name: str, tool_spec: ToolSpec, callback: Callable[[ToolUse, Any, dict[str, Any]], ToolResult] - ) -> None: + def __init__(self, tool_name: str, tool_spec: ToolSpec, tool_func: ToolFunc) -> None: """Initialize a Python-based tool. Args: tool_name: Unique identifier for the tool. tool_spec: Tool specification defining parameters and behavior. - callback: Python function to execute when the tool is invoked. + tool_func: Python function to execute when the tool is invoked. """ super().__init__() self._tool_name = tool_name self._tool_spec = tool_spec - self._callback = callback + self._tool_func = tool_func @property def tool_name(self) -> str: @@ -319,7 +320,8 @@ def tool_type(self) -> str: """ return "python" - def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + @override + def stream(self, tool: ToolUse, *args: Any, **kwargs: Any) -> Generator[Union[ToolResult, Any], None, None]: """Execute the Python function with the given tool use request. Args: @@ -327,7 +329,11 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes *args: Additional positional arguments to pass to the underlying callback function. **kwargs: Additional keyword arguments to pass to the underlying callback function. - Returns: - A ToolResult containing the status and content from the callback execution. + Yields: + Events of the tool invocation. The final event is always the tool result. """ - return self._callback(tool, *args, **kwargs) + result = self._tool_func(tool, *args, **kwargs) + if inspect.isgenerator(result): + yield from result + else: + yield result diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index bbf4df95b..08ad8dc0d 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -65,6 +65,9 @@ def result(self, timeout: Optional[int] = None) -> Any: Any: The result of the asynchronous operation. """ + def done(self) -> bool: + """Returns true if future is done executing.""" + @runtime_checkable class ParallelToolExecutorInterface(Protocol): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index ef536f01d..c70972321 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Generator, Literal, Optional, Protocol, Union from typing_extensions import TypedDict @@ -90,7 +90,7 @@ class ToolResult(TypedDict): toolUseId: The unique identifier of the tool use request that produced this result. """ - content: List[ToolResultContent] + content: list[ToolResultContent] status: ToolResultStatus toolUseId: str @@ -122,9 +122,9 @@ class ToolChoiceTool(TypedDict): ToolChoice = Union[ - Dict[Literal["auto"], ToolChoiceAuto], - Dict[Literal["any"], ToolChoiceAny], - Dict[Literal["tool"], ToolChoiceTool], + dict[Literal["auto"], ToolChoiceAuto], + dict[Literal["any"], ToolChoiceAny], + dict[Literal["tool"], ToolChoiceTool], ] """ Configuration for how the model should choose tools. @@ -143,10 +143,32 @@ class ToolConfig(TypedDict): toolChoice: Configuration for how the model should choose tools. """ - tools: List[Tool] + tools: list[Tool] toolChoice: ToolChoice +class ToolFunc(Protocol): + """Python based tool function.""" + + __name__: str + + def __call__( + self, *args: Any, **kwargs: Any + ) -> Union[ + ToolResult, + Coroutine[Any, Any, ToolResult], + Generator[Union[ToolResult, Any], None, None], + AsyncGenerator[Union[ToolResult, Any], None], + ]: + """Function signature. + + Returns: + If a generator, yields invocation events with the last being a tool result. Non-generators return a tool + result directly. + """ + ... + + class AgentTool(ABC): """Abstract base class for all SDK tools. @@ -195,7 +217,7 @@ def supports_hot_reload(self) -> bool: @abstractmethod # pragma: no cover - def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + def stream(self, tool: ToolUse, *args: Any, **kwargs: Any) -> Generator[Union[ToolResult, Any], None, None]: """Execute the tool's functionality with the given tool use request. Args: @@ -203,8 +225,8 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes *args: Positional arguments to pass to the tool. **kwargs: Keyword arguments to pass to the tool. - Returns: - The result of the tool execution. + Yields: + Events of the tool invocation. The final event is always the tool result. """ pass @@ -272,7 +294,7 @@ def process( tool_config: ToolConfig, callback_handler: Any, **kwargs: Any, - ) -> ToolResult: + ) -> Generator[Union[ToolResult, Any], None, None]: """Process a tool use request and execute the tool. Args: @@ -284,7 +306,7 @@ def process( callback_handler: Callback for processing events as they happen. **kwargs: Additional context-specific arguments. - Returns: - The result of the tool execution. + Yields: + Events of the tool invocation. The final event is always the tool result. """ ... diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c813a1a91..c9f75e6f9 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -775,6 +775,7 @@ def test_agent_init_with_no_model_or_model_id(): def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint): agent.tool_handler = unittest.mock.Mock() + agent.tool_handler.process.return_value = [{}] @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: @@ -805,6 +806,7 @@ def function(system_prompt: str) -> str: def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint): agent.tool_handler = unittest.mock.Mock() + agent.tool_handler.process.return_value = [{}] tool_name = "system-prompter" diff --git a/tests/strands/handlers/test_tool_handler.py b/tests/strands/handlers/test_tool_handler.py index 0684fbe0c..98a7bb151 100644 --- a/tests/strands/handlers/test_tool_handler.py +++ b/tests/strands/handlers/test_tool_handler.py @@ -45,7 +45,7 @@ def test_preprocess(tool_handler, tool_use_identity): def test_process(tool_handler, tool_use_identity): - tru_result = tool_handler.process( + stream = tool_handler.process( tool_use_identity, model=unittest.mock.Mock(), system_prompt="p1", @@ -53,13 +53,15 @@ def test_process(tool_handler, tool_use_identity): tool_config={}, callback_handler=unittest.mock.Mock(), ) + + tru_result = list(stream)[-1] exp_result = {"toolUseId": "identity", "status": "success", "content": [{"text": "1"}]} assert tru_result == exp_result def test_process_missing_tool(tool_handler): - tru_result = tool_handler.process( + stream = tool_handler.process( tool={"toolUseId": "missing", "name": "missing", "input": {}}, model=unittest.mock.Mock(), system_prompt="p1", @@ -67,6 +69,8 @@ def test_process_missing_tool(tool_handler): tool_config={}, callback_handler=unittest.mock.Mock(), ) + + tru_result = list(stream)[-1] exp_result = { "toolUseId": "missing", "status": "error", diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index eba4ad6c2..ab2af5273 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -57,10 +57,10 @@ def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client): assert tool_spec["description"] == "Tool which performs test_tool" -def test_invoke(mcp_agent_tool, mock_mcp_client): +def test_stream(mcp_agent_tool, mock_mcp_client): tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} - result = mcp_agent_tool.invoke(tool_use) + result = list(mcp_agent_tool.stream(tool_use))[-1] mock_mcp_client.call_tool_sync.assert_called_once_with( tool_use_id="test-123", name="test_tool", arguments={"param": "value"} diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 50333474c..64cc019dc 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -50,7 +50,7 @@ def test_tool(param1: str, param2: int) -> str: # Test actual usage tool_use = {"toolUseId": "test-id", "input": {"param1": "hello", "param2": 42}} - result = test_tool.invoke(tool_use) + result = list(test_tool.stream(tool_use))[-1] assert result["toolUseId"] == "test-id" assert result["status"] == "success" assert result["content"][0]["text"] == "Result: hello 42" @@ -98,14 +98,14 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: # Test with only required param tool_use = {"toolUseId": "test-id", "input": {"required": "hello"}} - result = test_tool.invoke(tool_use) + result = list(test_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "Result: hello" # Test with both params tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "optional": 42}} - result = test_tool.invoke(tool_use) + result = list(test_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "Result: hello 42" @@ -123,7 +123,7 @@ def test_tool(required: str) -> str: # Test with missing required param tool_use = {"toolUseId": "test-id", "input": {}} - result = test_tool.invoke(tool_use) + result = list(test_tool.stream(tool_use))[-1] assert result["status"] == "error" assert "validation error for test_tooltool\nrequired\n" in result["content"][0]["text"].lower(), ( "Validation error should indicate which argument is missing" @@ -132,7 +132,7 @@ def test_tool(required: str) -> str: # Test with exception in tool function tool_use = {"toolUseId": "test-id", "input": {"required": "error"}} - result = test_tool.invoke(tool_use) + result = list(test_tool.stream(tool_use))[-1] assert result["status"] == "error" assert "test error" in result["content"][0]["text"].lower(), ( "Runtime error should contain the original error message" @@ -176,11 +176,11 @@ def test_tool(param: str, agent=None) -> str: tool_use = {"toolUseId": "test-id", "input": {"param": "test"}} # Test without agent - result = test_tool.invoke(tool_use) + result = list(test_tool.stream(tool_use))[-1] assert result["content"][0]["text"] == "Param: test" # Test with agent - result = test_tool.invoke(tool_use, agent=mock_agent) + result = list(test_tool.stream(tool_use, agent=mock_agent))[-1] assert "Agent:" in result["content"][0]["text"] assert "test" in result["content"][0]["text"] @@ -231,18 +231,18 @@ def none_return_tool(param: str) -> None: # Test the dict return - should preserve dict format but add toolUseId tool_use: ToolUse = {"toolUseId": "test-id", "input": {"param": "test"}} - result = dict_return_tool.invoke(tool_use) + result = list(dict_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "Result: test" assert result["toolUseId"] == "test-id" # Test the string return - should wrap in standard format - result = string_return_tool.invoke(tool_use) + result = list(string_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "Result: test" # Test None return - should still create valid ToolResult with "None" text - result = none_return_tool.invoke(tool_use) + result = list(none_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "None" @@ -277,7 +277,7 @@ def test_method(self, param: str) -> str: # Test tool-style call tool_use = {"toolUseId": "test-id", "input": {"param": "tool-value"}} - result = instance.test_method.invoke(tool_use) + result = list(instance.test_method.stream(tool_use))[-1] assert "Test: tool-value" in result["content"][0]["text"] @@ -294,7 +294,7 @@ class MyThing: ... result = instance.field("example") assert result == "param: example" - result2 = instance.field.invoke({"toolUseId": "test-id", "input": {"param": "example"}}) + result2 = list(instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}))[-1] assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} @@ -314,7 +314,7 @@ def test_method(param: str) -> str: result = instance.field("example") assert result == "param: example" - result2 = instance.field.invoke({"toolUseId": "test-id", "input": {"param": "example"}}) + result2 = list(instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}))[-1] assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} @@ -341,12 +341,12 @@ def tool_with_defaults(required: str, optional: str = "default", number: int = 4 # Call with just required parameter tool_use = {"toolUseId": "test-id", "input": {"required": "hello"}} - result = tool_with_defaults.invoke(tool_use) + result = list(tool_with_defaults.stream(tool_use))[-1] assert result["content"][0]["text"] == "hello default 42" # Call with some but not all optional parameters tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "number": 100}} - result = tool_with_defaults.invoke(tool_use) + result = list(tool_with_defaults.stream(tool_use))[-1] assert result["content"][0]["text"] == "hello default 100" @@ -359,12 +359,12 @@ def test_tool(required: str) -> str: return f"Got: {required}" # Test with completely empty tool use - result = test_tool.invoke({}) + result = list(test_tool.stream({}))[-1] assert result["status"] == "error" assert "unknown" in result["toolUseId"] # Test with missing input - result = test_tool.invoke({"toolUseId": "test-id"}) + result = list(test_tool.stream({"toolUseId": "test-id"}))[-1] assert result["status"] == "error" assert "test-id" in result["toolUseId"] @@ -388,7 +388,7 @@ def add_numbers(a: int, b: int) -> int: # Call through tool interface tool_use = {"toolUseId": "test-id", "input": {"a": 2, "b": 3}} - result = add_numbers.invoke(tool_use) + result = list(add_numbers.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "5" @@ -421,7 +421,7 @@ def multi_default_tool( # Test calling with only required parameter tool_use = {"toolUseId": "test-id", "input": {"required_param": "hello"}} - result = multi_default_tool.invoke(tool_use) + result = list(multi_default_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "hello, default_str, 42, True, 3.14" in result["content"][0]["text"] @@ -430,7 +430,7 @@ def multi_default_tool( "toolUseId": "test-id", "input": {"required_param": "hello", "optional_int": 100, "optional_float": 2.718}, } - result = multi_default_tool.invoke(tool_use) + result = list(multi_default_tool.stream(tool_use))[-1] assert "hello, default_str, 100, True, 2.718" in result["content"][0]["text"] @@ -454,7 +454,7 @@ def int_return_tool(param: str) -> int: # Test with return that matches declared type tool_use = {"toolUseId": "test-id", "input": {"param": "valid"}} - result = int_return_tool.invoke(tool_use) + result = list(int_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "42" @@ -462,13 +462,13 @@ def int_return_tool(param: str) -> int: # Note: This should still work because Python doesn't enforce return types at runtime # but the function will return a string instead of an int tool_use = {"toolUseId": "test-id", "input": {"param": "invalid_type"}} - result = int_return_tool.invoke(tool_use) + result = list(int_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "not an int" # Test with None return from a non-None return type tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} - result = int_return_tool.invoke(tool_use) + result = list(int_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "None" @@ -489,17 +489,17 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: # Test with each possible return type in the Union tool_use = {"toolUseId": "test-id", "input": {"param": "dict"}} - result = union_return_tool.invoke(tool_use) + result = list(union_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "{'key': 'value'}" in result["content"][0]["text"] or '{"key": "value"}' in result["content"][0]["text"] tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} - result = union_return_tool.invoke(tool_use) + result = list(union_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "string result" tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} - result = union_return_tool.invoke(tool_use) + result = list(union_return_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "None" @@ -520,7 +520,7 @@ def no_params_tool() -> str: # Test tool use call tool_use = {"toolUseId": "test-id", "input": {}} - result = no_params_tool.invoke(tool_use) + result = list(no_params_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "Success - no parameters needed" @@ -546,7 +546,7 @@ def complex_type_tool(config: Dict[str, Any]) -> str: # Call via tool use tool_use = {"toolUseId": "test-id", "input": {"config": nested_dict}} - result = complex_type_tool.invoke(tool_use) + result = list(complex_type_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "Got config with 3 keys" in result["content"][0]["text"] @@ -573,7 +573,7 @@ def custom_result_tool(param: str) -> Dict[str, Any]: # Test via tool use tool_use = {"toolUseId": "custom-id", "input": {"param": "test"}} - result = custom_result_tool.invoke(tool_use) + result = list(custom_result_tool.stream(tool_use))[-1] # The wrapper should preserve our format and just add the toolUseId assert result["status"] == "success" @@ -646,7 +646,7 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: "bool_param": True, }, } - result = validation_tool.invoke(tool_use) + result = list(validation_tool.stream(tool_use))[-1] assert result["status"] == "error" assert "int_param" in result["content"][0]["text"] @@ -659,7 +659,7 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: "bool_param": True, }, } - result = validation_tool.invoke(tool_use) + result = list(validation_tool.stream(tool_use))[-1] assert result["status"] == "error" assert "int_param" in result["content"][0]["text"] @@ -680,20 +680,20 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: # Test with None value tool_use = {"toolUseId": "test-id", "input": {"param": None}} - result = edge_case_tool.invoke(tool_use) + result = list(edge_case_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "None" # Test with empty dict tool_use = {"toolUseId": "test-id", "input": {"param": {}}} - result = edge_case_tool.invoke(tool_use) + result = list(edge_case_tool.stream(tool_use))[-1] assert result["status"] == "success" assert result["content"][0]["text"] == "{}" # Test with a complex nested dictionary nested_dict = {"key1": {"nested": [1, 2, 3]}, "key2": None} tool_use = {"toolUseId": "test-id", "input": {"param": nested_dict}} - result = edge_case_tool.invoke(tool_use) + result = list(edge_case_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "key1" in result["content"][0]["text"] assert "nested" in result["content"][0]["text"] @@ -740,7 +740,7 @@ def test_method(self): assert instance.test_method("test") == "Method Got: test" # Test direct function call - direct_result = instance.test_method.invoke({"toolUseId": "test-id", "input": {"param": "direct"}}) + direct_result = list(instance.test_method.stream({"toolUseId": "test-id", "input": {"param": "direct"}}))[-1] assert direct_result["status"] == "success" assert direct_result["content"][0]["text"] == "Method Got: direct" @@ -760,7 +760,7 @@ def standalone_tool(p1: str, p2: str = "default") -> str: assert result == "Standalone: param1, param2" # And that it works with tool use call too - tool_use_result = standalone_tool.invoke({"toolUseId": "test-id", "input": {"p1": "value1"}}) + tool_use_result = list(standalone_tool.stream({"toolUseId": "test-id", "input": {"p1": "value1"}}))[-1] assert tool_use_result["status"] == "success" assert tool_use_result["content"][0]["text"] == "Standalone: value1, default" @@ -789,7 +789,7 @@ def failing_tool(param: str) -> str: error_types = ["value_error", "type_error", "attribute_error", "key_error"] for error_type in error_types: tool_use = {"toolUseId": "test-id", "input": {"param": error_type}} - result = failing_tool.invoke(tool_use) + result = list(failing_tool.stream(tool_use))[-1] assert result["status"] == "error" error_message = result["content"][0]["text"] @@ -821,25 +821,25 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] # Test with a list tool_use = {"toolUseId": "test-id", "input": {"union_param": [1, 2, 3]}} - result = complex_schema_tool.invoke(tool_use) + result = list(complex_schema_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "list: [1, 2, 3]" in result["content"][0]["text"] # Test with a dict tool_use = {"toolUseId": "test-id", "input": {"union_param": {"key": "value"}}} - result = complex_schema_tool.invoke(tool_use) + result = list(complex_schema_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "dict:" in result["content"][0]["text"] assert "key" in result["content"][0]["text"] # Test with a string tool_use = {"toolUseId": "test-id", "input": {"union_param": "test_string"}} - result = complex_schema_tool.invoke(tool_use) + result = list(complex_schema_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "str: test_string" in result["content"][0]["text"] # Test with None tool_use = {"toolUseId": "test-id", "input": {"union_param": None}} - result = complex_schema_tool.invoke(tool_use) + result = list(complex_schema_tool.stream(tool_use))[-1] assert result["status"] == "success" assert "NoneType: None" in result["content"][0]["text"] diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index 4b2387923..8e2a39a35 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -18,7 +18,7 @@ def moto_autouse(moto_env): @pytest.fixture def tool_handler(request): def handler(tool_use): - return { + yield { **params, "toolUseId": tool_use["toolUseId"], } @@ -54,11 +54,6 @@ def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() -@pytest.fixture -def request_state(): - return {} - - @pytest.fixture def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] @@ -92,24 +87,22 @@ def test_run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert not failed + list(stream) tru_results = tool_results exp_results = [ @@ -132,24 +125,22 @@ def test_run_tools_invalid_tool( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert failed + list(stream) tru_results = tool_results exp_results = [] @@ -162,24 +153,22 @@ def test_run_tools_failed_tool( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert failed + list(stream) tru_results = tool_results exp_results = [ @@ -222,23 +211,21 @@ def test_run_tools_sequential( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, None, # parallel_tool_executor ) - assert failed + list(stream) tru_results = tool_results exp_results = [ @@ -311,7 +298,6 @@ def test_run_tools_creates_and_ends_span_on_success( tool_uses, mock_metrics_client, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -329,17 +315,17 @@ def test_run_tools_creates_and_ends_span_on_success( tool_results = [] # Run the tool - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -359,7 +345,6 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -377,17 +362,17 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_results = [] # Run the tool - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -399,96 +384,6 @@ def test_run_tools_creates_and_ends_span_on_failure( assert args[1]["status"] == "failed" -@unittest.mock.patch("strands.tools.executor.get_tracer") -def test_run_tools_handles_exception_in_tool_execution( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - cycle_trace, - parallel_tool_executor, -): - """Test that run_tools properly handles exceptions during tool execution.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Make the tool handler throw an exception - exception = ValueError("Test tool execution error") - mock_handler = unittest.mock.MagicMock(side_effect=exception) - - tool_results = [] - - # Run the tool - the exception should be caught inside run_tools and not propagate - # because of the try-except block in the new implementation - failed = strands.tools.executor.run_tools( - mock_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, - parallel_tool_executor, - ) - - # Tool execution should have failed - assert failed - - # Verify span was created - mock_tracer.start_tool_call_span.assert_called_once() - - # Verify span was ended with the error - mock_tracer.end_span_with_error.assert_called_once_with(mock_span, str(exception), exception) - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -def test_run_tools_with_invalid_tool_use_id_still_creates_span( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - cycle_trace, - parallel_tool_executor, -): - """Test that run_tools creates a span even when the tool use ID is invalid.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Mark the tool use ID as invalid - invalid_tool_use_ids = [tool_uses[0]["toolUseId"]] - - tool_results = [] - - # Run the tool - strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, - parallel_tool_executor, - ) - - # Verify span was created - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], None) - - # Verify span was ended even though the tool wasn't executed - mock_tracer.end_tool_call_span.assert_called_once() - - @unittest.mock.patch("strands.tools.executor.get_tracer") @pytest.mark.parametrize( ("tool_uses", "invalid_tool_use_ids"), @@ -516,7 +411,6 @@ def test_run_tools_parallel_execution_with_spans( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -535,17 +429,17 @@ def test_run_tools_parallel_execution_with_spans( tool_results = [] # Run the tools - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify spans were created for both tools assert mock_tracer.start_tool_call_span.call_count == 2 diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 1b274f46b..31a152021 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -29,8 +29,8 @@ def test_process_tools_with_invalid_path(): def test_register_tool_with_similar_name_raises(): - tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None) - tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None) + tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), tool_func=lambda: None) + tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), tool_func=lambda: None) tool_registry = ToolRegistry() diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 37a0db2ee..26b550ebb 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -497,14 +497,16 @@ def test_get_display_properties(tool): assert tru_properties == exp_properties -def test_invoke(tool): - tru_output = tool.invoke({"input": {"a": 2}}) - exp_output = {"toolUseId": "unknown", "status": "success", "content": [{"text": "2"}]} +def test_stream(tool): + stream = tool.stream({"input": {"a": 2}}) - assert tru_output == exp_output + tru_result = list(stream)[-1] + exp_result = {"toolUseId": "unknown", "status": "success", "content": [{"text": "2"}]} + assert tru_result == exp_result -def test_invoke_with_agent(): + +def test_stream_with_agent(): @strands.tools.tool def identity(a: int, agent: dict = None): return a, agent @@ -513,14 +515,15 @@ def identity(a: int, agent: dict = None): # FunctionTool is a pass through for AgentTool instances until we remove it in a future release (#258) assert tool == identity - exp_output = {"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]} + stream = tool.stream({"input": {"a": 2}}, agent={"state": 1}) - tru_output = tool.invoke({"input": {"a": 2}}, agent={"state": 1}) + tru_result = list(stream)[-1] + exp_result = {"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]} - assert tru_output == exp_output + assert tru_result == exp_result -def test_invoke_exception(): +def test_stream_exception(): def identity(a: int): return a @@ -528,22 +531,24 @@ def identity(a: int): tool = FunctionTool(identity, tool_name="identity") - tru_output = tool.invoke({}, invalid=1) - exp_output = { + stream = tool.stream({}, invalid=1) + + tru_result = list(stream)[-1] + exp_result = { "toolUseId": "unknown", "status": "error", "content": [ { "text": ( "Error executing function: " - "test_invoke_exception..identity() " + "test_stream_exception..identity() " "got an unexpected keyword argument 'invalid'" ) } ], } - assert tru_output == exp_output + assert tru_result == exp_result # Tests from test_python_agent_tool.py @@ -566,7 +571,7 @@ def identity(tool_use, a): }, }, }, - callback=identity, + tool_func=identity, ) @@ -602,8 +607,10 @@ def test_python_tool_type(python_tool): assert tru_type == exp_type -def test_python_invoke(python_tool): - tru_output = python_tool.invoke({"tool_use": 1}, a=2) - exp_output = ({"tool_use": 1}, 2) +def test_python_stream(python_tool): + stream = python_tool.stream({"tool_use": 1}, a=2) + + tru_result = list(stream)[-1] + exp_result = ({"tool_use": 1}, 2) - assert tru_output == exp_output + assert tru_result == exp_result