diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 000000000..feca86135 Binary files /dev/null and b/.DS_Store differ diff --git a/pyproject.toml b/pyproject.toml index 8fb3ab743..6244b89bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dependencies = [ "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", - "opentelemetry-instrumentation-threading>=0.51b0,<1.00b0", ] [project.urls] @@ -83,12 +82,8 @@ openai = [ otel = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] -writer = [ - "writer-sdk>=2.2.0,<3.0.0" -] - a2a = [ - "a2a-sdk[sql]>=0.2.11", + "a2a-sdk>=0.2.6", "uvicorn>=0.34.2", "httpx>=0.28.1", "fastapi>=0.115.12", @@ -100,7 +95,7 @@ a2a = [ source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -124,7 +119,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -140,7 +135,7 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel","mistral"] [tool.hatch.envs.a2a] dev-mode = true @@ -148,10 +143,10 @@ features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"] [tool.hatch.envs.a2a.scripts] run = [ - "pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a {args}" + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}" ] run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a --cov --cov-config=pyproject.toml {args}" + "pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}" ] lint-check = [ "ruff check", @@ -164,11 +159,11 @@ python = ["3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] run = [ # excluding due to A2A and OTEL http exporter dependency conflict - "pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/strands/multiagent/a2a" + "pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a" ] run-cov = [ # excluding due to A2A and OTEL http exporter dependency conflict - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/strands/multiagent/a2a" + "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a" ] cov-combine = [] @@ -195,7 +190,7 @@ test = [ "hatch test --cover --cov-report html --cov-report xml {args}" ] test-integ = [ - "hatch test tests_integ {args}" + "hatch test tests-integ {args}" ] prepare = [ "hatch fmt --linter", @@ -230,7 +225,7 @@ ignore_missing_imports = true [tool.ruff] line-length = 120 -include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"] +include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"] [tool.ruff.lint] select = [ @@ -290,4 +285,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 000000000..b1a31ea5d Binary files 
/dev/null and b/src/.DS_Store differ diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index cbe36d2f9..cc11be043 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,6 +12,7 @@ import asyncio import json import logging +import os import random from concurrent.futures import ThreadPoolExecutor from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast @@ -30,7 +31,7 @@ from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.models import Model -from ..types.tools import ToolResult, ToolUse +from ..types.tools import ToolConfig, ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -127,18 +128,14 @@ def caller( "input": kwargs.copy(), } - async def acall() -> ToolResult: - async for event in run_tool(self._agent, tool_use, kwargs): - _ = event + # Execute the tool + events = run_tool(agent=self._agent, tool=tool_use, kwargs=kwargs) - return cast(ToolResult, event) - - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() + try: + while True: + next(events) + except StopIteration as stop: + tool_result = cast(ToolResult, stop.value) if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -189,6 +186,7 @@ def __init__( Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] ] = _DEFAULT_CALLBACK_HANDLER, conversation_manager: Optional[ConversationManager] = None, + max_parallel_tools: int = os.cpu_count() or 1, record_direct_tool_call: bool = True, load_tools_from_directory: bool = True, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, @@ -221,6 +219,8 @@ def __init__( If explicitly set to None, null_callback_handler is used. conversation_manager: Manager for conversation history and context window. Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None. + max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls. + Defaults to os.cpu_count() or 1. record_direct_tool_call: Whether to record direct tool calls in message history. Defaults to True. load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. @@ -232,6 +232,9 @@ def __init__( Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. + + Raises: + ValueError: If max_parallel_tools is less than 1. 
""" self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] @@ -260,6 +263,14 @@ def __init__( ): self.trace_attributes[k] = v + # If max_parallel_tools is 1, we execute tools sequentially + self.thread_pool = None + self.thread_pool_wrapper = None + if max_parallel_tools > 1: + self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools) + elif max_parallel_tools < 1: + raise ValueError("max_parallel_tools must be greater than 0") + self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory @@ -324,14 +335,32 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult: + @property + def tool_config(self) -> ToolConfig: + """Get the tool configuration for this agent. + + Returns: + The complete tool configuration. + """ + return self.tool_registry.initialize_tool_config() + + def __del__(self) -> None: + """Clean up resources when Agent is garbage collected. + + Ensures proper shutdown of the thread pool executor if one exists. + """ + if self.thread_pool: + self.thread_pool.shutdown(wait=False) + logger.debug("thread pool executor shutdown complete") + + def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to the conversation history, processes it through the model, executes any tool calls, and returns the final result. Args: - prompt: User input as text or list of ContentBlock objects for multi-modal content. + prompt: The natural language prompt from the user. **kwargs: Additional parameters to pass through the event loop. Returns: @@ -350,14 +379,14 @@ def execute() -> AgentResult: future = executor.submit(execute) return future.result() - async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult: + async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to the conversation history, processes it through the model, executes any tool calls, and returns the final result. Args: - prompt: User input as text or list of ContentBlock objects for multi-modal content. + prompt: The natural language prompt from the user. **kwargs: Additional parameters to pass through the event loop. Returns: @@ -436,7 +465,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[ finally: self._hooks.invoke_callbacks(EndRequestEvent(agent=self)) - async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]: + async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. This method provides an asynchronous interface for streaming agent events, allowing @@ -445,7 +474,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A async environments. Args: - prompt: User input as text or list of ContentBlock objects for multi-modal content. + prompt: The natural language prompt from the user. 
**kwargs: Additional parameters to pass to the event loop. Returns: @@ -468,13 +497,10 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A """ callback_handler = kwargs.get("callback_handler", self.callback_handler) - content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt - message: Message = {"role": "user", "content": content} - - self._start_agent_trace_span(message) + self._start_agent_trace_span(prompt) try: - events = self._run_loop(message, kwargs) + events = self._run_loop(prompt, kwargs) async for event in events: if "callback" in event: callback_handler(**event["callback"]) @@ -490,22 +516,18 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A self._end_agent_trace_span(error=e) raise - async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: - """Execute the agent's event loop with the given message and parameters. - - Args: - message: The user message to add to the conversation. - kwargs: Additional parameters to pass to the event loop. - - Yields: - Events from the event loop cycle. - """ + async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: + """Execute the agent's event loop with the given prompt and parameters.""" self._hooks.invoke_callbacks(StartRequestEvent(agent=self)) try: + # Extract key parameters yield {"callback": {"init_event_loop": True, **kwargs}} - self.messages.append(message) + # Set up the user message with optional knowledge base retrieval + message_content: list[ContentBlock] = [{"text": prompt}] + new_message: Message = {"role": "user", "content": message_content} + self.messages.append(new_message) # Execute the event loop cycle with retry logic for context limits events = self._execute_event_loop_cycle(kwargs) @@ -600,16 +622,16 @@ def _record_tool_execution( messages.append(tool_result_msg) messages.append(assistant_msg) - def _start_agent_trace_span(self, message: Message) -> None: + def _start_agent_trace_span(self, prompt: str) -> None: """Starts a trace span for the agent. Args: - message: The user message. + prompt: The natural language prompt from the user. 
""" model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None self.trace_span = self.tracer.start_agent_span( - message=message, + prompt=prompt, agent_name=self.name, model_id=model_id, tools=self.tool_names, diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 0c7cb4124..effb32e54 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,17 +11,15 @@ import logging import time import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator -from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent -from ..experimental.hooks.registry import get_registry from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException from ..types.streaming import Metrics, StopReason -from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ..types.tools import ToolGenerator, ToolResult, ToolUse from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages @@ -58,7 +56,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener - event_loop_cycle_span: Current tracing Span for this cycle Yields: - Model and tool stream events. The last event is a tuple containing: + Model and tool invocation events. The last event is a tuple containing: - StopReason: Reason the model stopped generating (e.g., "tool_use") - Message: The generated message from the model @@ -112,12 +110,10 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener model_id=model_id, ) - tool_specs = agent.tool_registry.get_all_tool_specs() - try: # TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding # to the callback handler. This will be revisited when migrating to strongly typed events. - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): + async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, agent.tool_config): if "callback" in event: yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}} @@ -174,6 +170,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener # If the model is requesting to use tools if stop_reason == "tool_use": + if agent.tool_config is None: + raise EventLoopException( + Exception("Model requested tool use but no tool config provided"), + kwargs["request_state"], + ) + # Handle tool execution events = _handle_tool_execution( stop_reason, @@ -252,119 +254,64 @@ async def recurse_event_loop(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGen recursive_trace.end() -async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator: +def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGenerator: """Process a tool invocation. - Looks up the tool in the registry and streams it with the provided parameters. + Looks up the tool in the registry and invokes it with the provided parameters. Args: agent: The agent for which the tool is being executed. - tool_use: The tool object to process, containing name and parameters. 
+ tool: The tool object to process, containing name and parameters. kwargs: Additional keyword arguments passed to the tool. Yields: - Tool events with the last being the tool result. + Events of the tool invocation. + + Returns: + The final tool result or an error response if the tool fails or is not found. """ - logger.debug("tool_use=<%s> | streaming", tool_use) - tool_name = tool_use["name"] + logger.debug("tool=<%s> | invoking", tool) + tool_use_id = tool["toolUseId"] + tool_name = tool["name"] # Get the tool info tool_info = agent.tool_registry.dynamic_tools.get(tool_name) tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) - # Add standard arguments to kwargs for Python tools - kwargs.update( - { - "model": agent.model, - "system_prompt": agent.system_prompt, - "messages": agent.messages, - "tool_config": ToolConfig( # for backwards compatability - tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], - toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), - ), - } - ) - - before_event = get_registry(agent).invoke_callbacks( - BeforeToolInvocationEvent( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - kwargs=kwargs, - ) - ) - try: - selected_tool = before_event.selected_tool - tool_use = before_event.tool_use - # Check if tool exists - if not selected_tool: - if tool_func == selected_tool: - logger.error( - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", - tool_name, - list(agent.tool_registry.registry.keys()), - ) - else: - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - tool_name, - str(tool_use.get("toolUseId")), - ) - - result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), + if not tool_func: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + return { + "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - # for every Before event call, we need to have an AfterEvent call - after_event = get_registry(agent).invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - kwargs=kwargs, - result=result, - ) - ) - yield after_event.result - return - - async for event in selected_tool.stream(tool_use, kwargs): - yield event - - result = event - - after_event = get_registry(agent).invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - kwargs=kwargs, - result=result, - ) + # Add standard arguments to kwargs for Python tools + kwargs.update( + { + "model": agent.model, + "system_prompt": agent.system_prompt, + "messages": agent.messages, + "tool_config": agent.tool_config, + } ) - yield after_event.result + + result = tool_func.invoke(tool, **kwargs) + yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from + return result except Exception as e: logger.exception("tool_name=<%s> | failed to process tool", tool_name) - error_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), + return { + "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event = get_registry(agent).invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - kwargs=kwargs, - result=error_result, - exception=e, - ) - ) 
- yield after_event.result async def _handle_tool_execution( @@ -394,8 +341,8 @@ async def _handle_tool_execution( kwargs: Additional keyword arguments, including request state. Yields: - Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple - containing: + Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a + tuple containing: - The stop reason, - The updated message, - The updated event loop metrics, @@ -408,7 +355,7 @@ async def _handle_tool_execution( return def tool_handler(tool_use: ToolUse) -> ToolGenerator: - return run_tool(agent, tool_use, kwargs) + return run_tool(agent=agent, kwargs=kwargs, tool=tool_use) tool_events = run_tools( handler=tool_handler, @@ -418,8 +365,9 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: tool_results=tool_results, cycle_trace=cycle_trace, parent_span=cycle_span, + thread_pool=agent.thread_pool, ) - async for tool_event in tool_events: + for tool_event in tool_events: yield tool_event # Store parent cycle ID for the next cycle diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 777c3a064..6ecc3e270 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -19,7 +19,7 @@ StreamEvent, Usage, ) -from ..types.tools import ToolSpec, ToolUse +from ..types.tools import ToolConfig, ToolUse logger = logging.getLogger(__name__) @@ -304,7 +304,7 @@ async def stream_messages( model: Model, system_prompt: Optional[str], messages: Messages, - tool_specs: list[ToolSpec], + tool_config: Optional[ToolConfig], ) -> AsyncGenerator[dict[str, Any], None]: """Streams messages to the model and processes the response. @@ -312,7 +312,7 @@ async def stream_messages( model: Model provider. system_prompt: The system prompt to send. messages: List of messages to send. - tool_specs: The list of tool specs. + tool_config: Configuration for the tools to use. Returns: The reason for stopping, the final message, and the usage metrics @@ -320,7 +320,8 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = remove_blank_messages_content_text(messages) + tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None - chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt) + chunks = model.converse(messages, tool_specs, system_prompt) async for event in process_stream(chunks, messages): yield event diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 61bd6ac3e..3ec805137 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -29,21 +29,13 @@ def log_end(self, event: EndRequestEvent) -> None: type-safe system that supports multiple subscribers per event type. 
""" -from .events import ( - AfterToolInvocationEvent, - AgentInitializedEvent, - BeforeToolInvocationEvent, - EndRequestEvent, - StartRequestEvent, -) +from .events import AgentInitializedEvent, EndRequestEvent, StartRequestEvent from .registry import HookCallback, HookEvent, HookProvider, HookRegistry __all__ = [ "AgentInitializedEvent", "StartRequestEvent", "EndRequestEvent", - "BeforeToolInvocationEvent", - "AfterToolInvocationEvent", "HookEvent", "HookProvider", "HookCallback", diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 559f1051d..c42b82d54 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -4,9 +4,7 @@ """ from dataclasses import dataclass -from typing import Any, Optional -from ...types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -58,63 +56,9 @@ class EndRequestEvent(HookEvent): @property def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BeforeToolInvocationEvent(HookEvent): - """Event triggered before a tool is invoked. - - This event is fired just before the agent executes a tool, allowing hook - providers to inspect, modify, or replace the tool that will be executed. - The selected_tool can be modified by hook callbacks to change which tool - gets executed. - - Attributes: - selected_tool: The tool that will be invoked. Can be modified by hooks - to change which tool gets executed. This may be None if tool lookup failed. - tool_use: The tool parameters that will be passed to selected_tool. - kwargs: Keyword arguments that will be passed to the tool. - """ - - selected_tool: Optional[AgentTool] - tool_use: ToolUse - kwargs: dict[str, Any] - - def _can_write(self, name: str) -> bool: - return name in ["selected_tool", "tool_use"] - - -@dataclass -class AfterToolInvocationEvent(HookEvent): - """Event triggered after a tool invocation completes. + """Return True to invoke callbacks in reverse order for proper cleanup. - This event is fired after the agent has finished executing a tool, - regardless of whether the execution was successful or resulted in an error. - Hook providers can use this event for cleanup, logging, or post-processing. - - Note: This event uses reverse callback ordering, meaning callbacks registered - later will be invoked first during cleanup. - - Attributes: - selected_tool: The tool that was invoked. It may be None if tool lookup failed. - tool_use: The tool parameters that were passed to the tool invoked. - kwargs: Keyword arguments that were passed to the tool - result: The result of the tool invocation. Either a ToolResult on success - or an Exception if the tool execution failed. - """ - - selected_tool: Optional[AgentTool] - tool_use: ToolUse - kwargs: dict[str, Any] - result: ToolResult - exception: Optional[Exception] = None - - def _can_write(self, name: str) -> bool: - return name == "result" - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" + Returns: + True, indicating callbacks should be invoked in reverse order. 
+ """ return True diff --git a/src/strands/experimental/hooks/registry.py b/src/strands/experimental/hooks/registry.py index befa6c397..4b3eceb4b 100644 --- a/src/strands/experimental/hooks/registry.py +++ b/src/strands/experimental/hooks/registry.py @@ -8,7 +8,7 @@ """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar if TYPE_CHECKING: from ...agent import Agent @@ -34,43 +34,9 @@ def should_reverse_callbacks(self) -> bool: """ return False - def _can_write(self, name: str) -> bool: - """Check if the given property can be written to. - - Args: - name: The name of the property to check. - - Returns: - True if the property can be written to, False otherwise. - """ - return False - - def __post_init__(self) -> None: - """Disallow writes to non-approved properties.""" - # This is needed as otherwise the class can't be initialized at all, so we trigger - # this after class initialization - super().__setattr__("_disallow_writes", True) - - def __setattr__(self, name: str, value: Any) -> None: - """Prevent setting attributes on hook events. - - Raises: - AttributeError: Always raised to prevent setting attributes on hook events. - """ - # Allow setting attributes: - # - during init (when __dict__) doesn't exist - # - if the subclass specifically said the property is writable - if not hasattr(self, "_disallow_writes") or self._can_write(name): - return super().__setattr__(name, value) - - raise AttributeError(f"Property {name} is not writable") - +T = TypeVar("T", bound=Callable) TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True) -"""Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes.""" - -TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent) -"""Generic for invoking events - non-contravariant to enable returning events.""" class HookProvider(Protocol): @@ -178,7 +144,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: + def invoke_callbacks(self, event: TEvent) -> None: """Invoke all registered callbacks for the given event. This method finds all callbacks registered for the event's type and @@ -191,9 +157,6 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: Raises: Any exceptions raised by callback functions will propagate to the caller. - Returns: - The event dispatched to registered callbacks. - Example: ```python event = StartRequestEvent(agent=my_agent) @@ -203,8 +166,6 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: for callback in self.get_callbacks_for(event): callback(event) - return event - def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]: """Get callbacks registered for the given event in the appropriate order. @@ -232,18 +193,3 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No yield from reversed(callbacks) else: yield from callbacks - - -def get_registry(agent: "Agent") -> HookRegistry: - """*Experimental*: Get the hooks registry for the provided agent. - - This function is available while hooks are in experimental preview. - - Args: - agent: The agent whose hook registry should be returned. - - Returns: - The HookRegistry for the given agent. 
- - """ - return agent._hooks diff --git a/src/strands/handlers/__init__.py b/src/strands/handlers/__init__.py index fc1a56910..6a56201c7 100644 --- a/src/strands/handlers/__init__.py +++ b/src/strands/handlers/__init__.py @@ -2,6 +2,7 @@ Examples include: +- Processing tool invocations - Displaying events from the event stream """ diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index be96d55e2..02c3d9089 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -72,7 +72,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf logger.debug("config=<%s> | initializing", self.config) client_args = client_args or {} - self.client = anthropic.AsyncAnthropic(**client_args) + self.client = anthropic.Anthropic(**client_args) @override def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # type: ignore[override] @@ -358,8 +358,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] ModelThrottledException: If the request is throttled by Anthropic. """ try: - async with self.client.messages.stream(**request) as stream: - async for event in stream: + with self.client.messages.stream(**request) as stream: + for event in stream: if event.type in AnthropicModel.EVENT_TYPES: yield event.model_dump() diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 521d4491e..6f8492b79 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -8,7 +8,7 @@ import logging from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union -import mistralai +from mistralai import Mistral from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override @@ -94,7 +94,7 @@ def __init__( if api_key: client_args["api_key"] = api_key - self.client = mistralai.Mistral(**client_args) + self.client = Mistral(**client_args) @override def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: ignore @@ -408,21 +408,21 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] try: if not self.config.get("stream", True): # Use non-streaming API - response = await self.client.chat.complete_async(**request) + response = self.client.chat.complete(**request) for event in self._handle_non_streaming_response(response): yield event return # Use the streaming API - stream_response = await self.client.chat.stream_async(**request) + stream_response = self.client.chat.stream(**request) yield {"chunk_type": "message_start"} content_started = False - tool_calls: dict[str, list[Any]] = {} + current_tool_calls: dict[str, dict[str, str]] = {} accumulated_text = "" - async for chunk in stream_response: + for chunk in stream_response: if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices: choice = chunk.data.choices[0] @@ -440,23 +440,24 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] if hasattr(delta, "tool_calls") and delta.tool_calls: for tool_call in delta.tool_calls: tool_id = tool_call.id - tool_calls.setdefault(tool_id, []).append(tool_call) - if hasattr(choice, "finish_reason") and choice.finish_reason: - if content_started: - yield {"chunk_type": "content_stop", "data_type": "text"} - - for tool_deltas in tool_calls.values(): - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + if tool_id not in current_tool_calls: + yield {"chunk_type": "content_start", "data_type": 
"tool", "data": tool_call} + current_tool_calls[tool_id] = {"name": tool_call.function.name, "arguments": ""} - for tool_delta in tool_deltas: - if hasattr(tool_delta.function, "arguments"): + if hasattr(tool_call.function, "arguments"): + current_tool_calls[tool_id]["arguments"] += tool_call.function.arguments yield { "chunk_type": "content_delta", "data_type": "tool", - "data": tool_delta.function.arguments, + "data": tool_call.function.arguments, } + if hasattr(choice, "finish_reason") and choice.finish_reason: + if content_started: + yield {"chunk_type": "content_stop", "data_type": "text"} + + for _ in current_tool_calls: yield {"chunk_type": "content_stop", "data_type": "tool"} yield {"chunk_type": "message_stop", "data": choice.finish_reason} @@ -498,7 +499,7 @@ async def structured_output( formatted_request["tool_choice"] = "any" formatted_request["parallel_tool_calls"] = False - response = await self.client.chat.complete_async(**formatted_request) + response = self.client.chat.complete(**formatted_request) if response.choices and response.choices[0].message.tool_calls: tool_call = response.choices[0].message.tool_calls[0] diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index ae70d2e77..707672498 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -7,7 +7,7 @@ import logging from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast -import ollama +from ollama import Client as OllamaClient from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override @@ -74,7 +74,7 @@ def __init__( ollama_client_args = ollama_client_args if ollama_client_args is not None else {} - self.client = ollama.AsyncClient(host, **ollama_client_args) + self.client = OllamaClient(host, **ollama_client_args) @override def update_config(self, **model_config: Unpack[OllamaConfig]) -> None: # type: ignore @@ -296,12 +296,12 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] """ tool_requested = False - response = await self.client.chat(**request) + response = self.client.chat(**request) yield {"chunk_type": "message_start"} yield {"chunk_type": "content_start", "data_type": "text"} - async for event in response: + for event in response: for tool_call in event.message.tool_calls or []: yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call} yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_call} @@ -330,7 +330,7 @@ async def structured_output( formatted_request = self.format_request(messages=prompt) formatted_request["format"] = output_model.model_json_schema() formatted_request["stream"] = False - response = await self.client.chat(**formatted_request) + response = self.client.chat(**formatted_request) try: content = response.message.content.strip() diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py deleted file mode 100644 index 0a5ca4a95..000000000 --- a/src/strands/models/writer.py +++ /dev/null @@ -1,431 +0,0 @@ -"""Writer model provider. 
- -- Docs: https://dev.writer.com/home/introduction -""" - -import base64 -import json -import logging -import mimetypes -from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast - -import writerai -from pydantic import BaseModel -from typing_extensions import Unpack, override - -from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException -from ..types.models import Model -from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse - -logger = logging.getLogger(__name__) - -T = TypeVar("T", bound=BaseModel) - - -class WriterModel(Model): - """Writer API model provider implementation.""" - - class WriterConfig(TypedDict, total=False): - """Configuration options for Writer API. - - Attributes: - model_id: Model name to use (e.g. palmyra-x5, palmyra-x4, etc.). - max_tokens: Maximum number of tokens to generate. - stop: Default stop sequences. - stream_options: Additional options for streaming. - temperature: What sampling temperature to use. - top_p: Threshold for 'nucleus sampling' - """ - - model_id: str - max_tokens: Optional[int] - stop: Optional[Union[str, List[str]]] - stream_options: Dict[str, Any] - temperature: Optional[float] - top_p: Optional[float] - - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]): - """Initialize provider instance. - - Args: - client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout, etc.). - **model_config: Configuration options for the Writer model. - """ - self.config = WriterModel.WriterConfig(**model_config) - - logger.debug("config=<%s> | initializing", self.config) - - client_args = client_args or {} - self.client = writerai.AsyncClient(**client_args) - - @override - def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type: ignore[override] - """Update the Writer Model configuration with the provided arguments. - - Args: - **model_config: Configuration overrides. - """ - self.config.update(model_config) - - @override - def get_config(self) -> WriterConfig: - """Get the Writer model configuration. - - Returns: - The Writer model configuration. - """ - return self.config - - def _format_request_message_contents_vision(self, contents: list[ContentBlock]) -> list[dict[str, Any]]: - def _format_content_vision(content: ContentBlock) -> dict[str, Any]: - """Format a Writer content block for Palmyra V5 request. - - - NOTE: "reasoningContent", "document" and "video" are not supported currently. - - Args: - content: Message content. - - Returns: - Writer formatted content block for models, which support vision content format. - - Raises: - TypeError: If the content block type cannot be converted to a Writer-compatible format. 
- """ - if "text" in content: - return {"text": content["text"], "type": "text"} - - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") - - return { - "image_url": { - "url": f"data:{mime_type};base64,{image_data}", - }, - "type": "image_url", - } - - raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - - return [ - _format_content_vision(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - - def _format_request_message_contents(self, contents: list[ContentBlock]) -> str: - def _format_content(content: ContentBlock) -> str: - """Format a Writer content block for Palmyra models (except V5) request. - - - NOTE: "reasoningContent", "document", "video" and "image" are not supported currently. - - Args: - content: Message content. - - Returns: - Writer formatted content block. - - Raises: - TypeError: If the content block type cannot be converted to a Writer-compatible format. - """ - if "text" in content: - return content["text"] - - raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - - content_blocks = list( - filter( - lambda content: content.get("text") - and not any(block_type in content for block_type in ["toolResult", "toolUse"]), - contents, - ) - ) - - if len(content_blocks) > 1: - raise ValueError( - f"Model with name {self.get_config().get('model_id', 'N/A')} doesn't support multiple contents" - ) - elif len(content_blocks) == 1: - return _format_content(content_blocks[0]) - else: - return "" - - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: - """Format a Writer tool call. - - Args: - tool_use: Tool use requested by the model. - - Returns: - Writer formatted tool call. - """ - return { - "function": { - "arguments": json.dumps(tool_use["input"]), - "name": tool_use["name"], - }, - "id": tool_use["toolUseId"], - "type": "function", - } - - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: - """Format a Writer tool message. - - Args: - tool_result: Tool result collected from a tool execution. - - Returns: - Writer formatted tool message. - """ - contents = cast( - list[ContentBlock], - [ - {"text": json.dumps(content["json"])} if "json" in content else content - for content in tool_result["content"] - ], - ) - - if self.get_config().get("model_id", "") == "palmyra-x5": - formatted_contents = self._format_request_message_contents_vision(contents) - else: - formatted_contents = self._format_request_message_contents(contents) # type: ignore [assignment] - - return { - "role": "tool", - "tool_call_id": tool_result["toolUseId"], - "content": formatted_contents, - } - - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format a Writer compatible messages array. - - Args: - messages: List of message objects to be processed by the model. - system_prompt: System prompt to provide context to the model. - - Returns: - Writer compatible messages array. - """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] - - for message in messages: - contents = message["content"] - - # Only palmyra V5 support multiple content. 
Other models support only '{"content": "text_content"}' - if self.get_config().get("model_id", "") == "palmyra-x5": - formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents) - else: - formatted_contents = self._format_request_message_contents(contents) - - formatted_tool_calls = [ - self._format_request_message_tool_call(content["toolUse"]) - for content in contents - if "toolUse" in content - ] - formatted_tool_messages = [ - self._format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - formatted_message = { - "role": message["role"], - "content": formatted_contents if len(formatted_contents) > 0 else "", - **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), - } - formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) - - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] - - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> Any: - """Format a streaming request to the underlying model. - - Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. - - Returns: - The formatted request. - """ - request = { - **{k: v for k, v in self.config.items()}, - "messages": self._format_request_messages(messages, system_prompt), - "stream": True, - } - try: - request["model"] = request.pop( - "model_id" - ) # To be consisted with other models WriterConfig use 'model_id' arg, but Writer API wait for 'model' arg - except KeyError as e: - raise KeyError("Please specify a model ID. Use 'model_id' keyword argument.") from e - - # Writer don't support empty tools attribute - if tool_specs: - request["tools"] = [ - { - "type": "function", - "function": { - "name": tool_spec["name"], - "description": tool_spec["description"], - "parameters": tool_spec["inputSchema"]["json"], - }, - } - for tool_spec in tool_specs - ] - - return request - - @override - def format_chunk(self, event: Any) -> StreamEvent: - """Format the model response events into standardized message chunks. - - Args: - event: A response event from the model. - - Returns: - The formatted chunk. 
- """ - match event.get("chunk_type", ""): - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_block_start": - if event["data_type"] == "text": - return {"contentBlockStart": {"start": {}}} - - return { - "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function.name, - "toolUseId": event["data"].id, - } - } - } - } - - case "content_block_delta": - if event["data_type"] == "text": - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} - - case "content_block_stop": - return {"contentBlockStop": {}} - - case "message_stop": - match event["data"]: - case "tool_calls": - return {"messageStop": {"stopReason": "tool_use"}} - case "length": - return {"messageStop": {"stopReason": "max_tokens"}} - case _: - return {"messageStop": {"stopReason": "end_turn"}} - - case "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens if event["data"] else 0, - "outputTokens": event["data"].completion_tokens if event["data"] else 0, - "totalTokens": event["data"].total_tokens if event["data"] else 0, - }, # If 'stream_options' param is unset, empty metadata will be provided. - # To avoid errors replacing expected fields with default zero value - "metrics": { - "latencyMs": 0, # All palmyra models don't provide 'latency' metadata - }, - }, - } - - case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") - - @override - async def stream(self, request: Any) -> AsyncGenerator[Any, None]: - """Send the request to the model and get a streaming response. - - Args: - request: The formatted request to send to the model. - - Returns: - The model's response. - - Raises: - ModelThrottledException: When the model service is throttling requests from the client. - """ - try: - response = await self.client.chat.chat(**request) - except writerai.RateLimitError as e: - raise ModelThrottledException(str(e)) from e - - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_block_start", "data_type": "text"} - - tool_calls: dict[int, list[Any]] = {} - - async for chunk in response: - if not getattr(chunk, "choices", None): - continue - choice = chunk.choices[0] - - if choice.delta.content: - yield {"chunk_type": "content_block_delta", "data_type": "text", "data": choice.delta.content} - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - if choice.finish_reason: - break - - yield {"chunk_type": "content_block_stop", "data_type": "text"} - - for tool_deltas in tool_calls.values(): - tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:] - yield {"chunk_type": "content_block_start", "data_type": "tool", "data": tool_start} - - for tool_delta in tool_deltas: - yield {"chunk_type": "content_block_delta", "data_type": "tool", "data": tool_delta} - - yield {"chunk_type": "content_block_stop", "data_type": "tool"} - - yield {"chunk_type": "message_stop", "data": choice.finish_reason} - - # Iterating until the end to fetch metadata chunk - async for chunk in response: - _ = chunk - - yield {"chunk_type": "metadata", "data": chunk.usage} - - @override - async def structured_output( - self, output_model: Type[T], prompt: Messages - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: - """Get structured output from the model. - - Args: - output_model(Type[BaseModel]): The output model to use for the agent. 
- prompt(Messages): The prompt messages to use for the agent. - """ - formatted_request = self.format_request(messages=prompt) - formatted_request["response_format"] = { - "type": "json_schema", - "json_schema": {"schema": output_model.model_json_schema()}, - } - formatted_request["stream"] = False - formatted_request.pop("stream_options", None) - - response = await self.client.chat.chat(**formatted_request) - - try: - content = response.choices[0].message.content.strip() - yield {"output": output_model.model_validate_json(content)} - except Exception as e: - raise ValueError(f"Failed to parse or load content into model: {e}") from e diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 61d767858..b7a7af091 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -2,40 +2,31 @@ This module provides the StrandsA2AExecutor class, which adapts a Strands Agent to be used as an executor in the A2A protocol. It handles the execution of agent -requests and the conversion of Strands Agent streamed responses to A2A events. - -The A2A AgentExecutor ensures clients recieve responses for synchronous and -streamed requests to the A2AServer. +requests and the conversion of Strands Agent responses to A2A events. """ import logging -from typing import Any from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue -from a2a.server.tasks import TaskUpdater -from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError -from a2a.utils import new_agent_text_message, new_task +from a2a.types import UnsupportedOperationError +from a2a.utils import new_agent_text_message from a2a.utils.errors import ServerError from ...agent.agent import Agent as SAAgent -from ...agent.agent import AgentResult as SAAgentResult +from ...agent.agent_result import AgentResult as SAAgentResult -logger = logging.getLogger(__name__) +log = logging.getLogger(__name__) class StrandsA2AExecutor(AgentExecutor): - """Executor that adapts a Strands Agent to the A2A protocol. - - This executor uses streaming mode to handle the execution of agent requests - and converts Strands Agent responses to A2A protocol events. - """ + """Executor that adapts a Strands Agent to the A2A protocol.""" def __init__(self, agent: SAAgent): """Initialize a StrandsA2AExecutor. Args: - agent: The Strands Agent instance to adapt to the A2A protocol. + agent: The Strands Agent to adapt to the A2A protocol. """ self.agent = agent @@ -46,97 +37,24 @@ async def execute( ) -> None: """Execute a request using the Strands Agent and send the response as A2A events. - This method executes the user's input using the Strands Agent in streaming mode - and converts the agent's response to A2A events. - - Args: - context: The A2A request context, containing the user's input and task metadata. - event_queue: The A2A event queue used to send response events back to the client. - - Raises: - ServerError: If an error occurs during agent execution - """ - task = context.current_task - if not task: - task = new_task(context.message) # type: ignore - await event_queue.enqueue_event(task) - - updater = TaskUpdater(event_queue, task.id, task.contextId) - - try: - await self._execute_streaming(context, updater) - except Exception as e: - raise ServerError(error=InternalError()) from e - - async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: - """Execute request in streaming mode. 
- - Streams the agent's response in real-time, sending incremental updates - as they become available from the agent. + This method executes the user's input using the Strands Agent and converts + the agent's response to A2A events, which are then sent to the event queue. Args: context: The A2A request context, containing the user's input and other metadata. - updater: The task updater for managing task state and sending updates. - """ - logger.info("Executing request in streaming mode") - user_input = context.get_user_input() - try: - async for event in self.agent.stream_async(user_input): - await self._handle_streaming_event(event, updater) - except Exception: - logger.exception("Error in streaming execution") - raise - - async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: - """Handle a single streaming event from the Strands Agent. - - Processes streaming events from the agent, converting data chunks to A2A - task updates and handling the final result when streaming is complete. - - Args: - event: The streaming event from the agent, containing either 'data' for - incremental content or 'result' for the final response. - updater: The task updater for managing task state and sending updates. - """ - logger.debug("Streaming event: %s", event) - if "data" in event: - if text_content := event["data"]: - await updater.update_status( - TaskState.working, - new_agent_text_message( - text_content, - updater.context_id, - updater.task_id, - ), - ) - elif "result" in event: - await self._handle_agent_result(event["result"], updater) - else: - logger.warning("Unexpected streaming event: %s", event) - - async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: - """Handle the final result from the Strands Agent. - - Processes the agent's final result, extracts text content from the response, - and adds it as an artifact to the task before marking the task as complete. - - Args: - result: The agent result object containing the final response, or None if no result. - updater: The task updater for managing task state and adding the final artifact. + event_queue: The A2A event queue, used to send response events. """ - if final_content := str(result): - await updater.add_artifact( - [Part(root=TextPart(text=final_content))], - name="agent_response", - ) - await updater.complete() + result: SAAgentResult = self.agent(context.get_user_input()) + if result.message and "content" in result.message: + for content_block in result.message["content"]: + if "text" in content_block: + await event_queue.enqueue_event(new_agent_text_message(content_block["text"])) async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - This method is called when a request cancellation is requested. Currently, - cancellation is not supported by the Strands Agent executor, so this method - always raises an UnsupportedOperationError. + This method is called when a request is cancelled. Currently, cancellation + is not supported, so this method raises an UnsupportedOperationError. Args: context: The A2A request context. @@ -146,5 +64,4 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None ServerError: Always raised with an UnsupportedOperationError, as cancellation is not currently supported. 
""" - logger.warning("Cancellation requested but not supported") raise ServerError(error=UnsupportedOperationError()) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index 9442c34db..0e271b1cf 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -34,10 +34,12 @@ def __init__( version: str = "0.0.1", skills: list[AgentSkill] | None = None, ): - """Initialize an A2A-compatible server from a Strands agent. + """Initialize an A2A-compatible agent from a Strands agent. Args: agent: The Strands Agent to wrap with A2A compatibility. + name: The name of the agent, used in the AgentCard. + description: A description of the agent's capabilities, used in the AgentCard. host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". port: The port to bind the A2A server to. Defaults to 9000. version: The version of the agent. Defaults to "0.0.1". @@ -50,7 +52,8 @@ def __init__( self.strands_agent = agent self.name = self.strands_agent.name self.description = self.strands_agent.description - self.capabilities = AgentCapabilities(streaming=True) + # TODO: enable configurable capabilities and request handler + self.capabilities = AgentCapabilities() self.request_handler = DefaultRequestHandler( agent_executor=StrandsA2AExecutor(self.strands_agent), task_store=InMemoryTaskStore(), diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 10d230811..7f8abb1e6 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -10,7 +10,6 @@ from typing import Any, Dict, Mapping, Optional import opentelemetry.trace as trace_api -from opentelemetry.instrumentation.threading import ThreadingInstrumentor from opentelemetry.trace import Span, StatusCode from ..agent.agent_result import AgentResult @@ -90,7 +89,6 @@ def __init__( self.tracer_provider = trace_api.get_tracer_provider() self.tracer = self.tracer_provider.get_tracer(self.service_name) - ThreadingInstrumentor().instrument() def _start_span( self, @@ -407,7 +405,7 @@ def end_event_loop_cycle_span( def start_agent_span( self, - message: Message, + prompt: str, agent_name: str, model_id: Optional[str] = None, tools: Optional[list] = None, @@ -417,7 +415,7 @@ def start_agent_span( """Start a new span for an agent invocation. Args: - message: The user message being sent to the agent. + prompt: The user prompt being sent to the agent. agent_name: Name of the agent. model_id: Optional model identifier. tools: Optional list of tools being used. 
@@ -454,7 +452,7 @@ def start_agent_span( span, "gen_ai.user.message", event_attributes={ - "content": serialize(message["content"]), + "content": prompt, }, ) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index a91d6c255..46a6320ad 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -40,19 +40,20 @@ def my_tool(param1: str, param2: int = 42) -> dict: ``` """ -import asyncio import functools import inspect import logging from typing import ( Any, Callable, + Dict, Generic, Optional, ParamSpec, Type, TypeVar, Union, + cast, get_type_hints, overload, ) @@ -61,7 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse +from ..types.tools import AgentTool, JSONSchema, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -118,7 +119,7 @@ def _create_input_model(self) -> Type[BaseModel]: Returns: A Pydantic BaseModel class customized for the function's parameters. """ - field_definitions: dict[str, Any] = {} + field_definitions: Dict[str, Any] = {} for name, param in self.signature.parameters.items(): # Skip special parameters @@ -178,7 +179,7 @@ def extract_metadata(self) -> ToolSpec: return tool_spec - def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: + def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None: """Clean up Pydantic schema to match Strands' expected format. Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could @@ -226,7 +227,7 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: if key in prop_schema: del prop_schema[key] - def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: + def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Validate input data using the Pydantic model. This method ensures that the input data meets the expected schema before it's passed to the actual function. It @@ -269,32 +270,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]): _tool_name: str _tool_spec: ToolSpec - _tool_func: Callable[P, R] _metadata: FunctionToolMetadata + original_function: Callable[P, R] def __init__( self, + function: Callable[P, R], tool_name: str, tool_spec: ToolSpec, - tool_func: Callable[P, R], metadata: FunctionToolMetadata, ): """Initialize the decorated function tool. Args: + function: The original function being decorated. tool_name: The name to use for the tool (usually the function name). tool_spec: The tool specification containing metadata for Agent integration. - tool_func: The original function being decorated. metadata: The FunctionToolMetadata object with extracted function information. """ super().__init__() - self._tool_name = tool_name + self.original_function = function self._tool_spec = tool_spec - self._tool_func = tool_func self._metadata = metadata + self._tool_name = tool_name - functools.update_wrapper(wrapper=self, wrapped=self._tool_func) + functools.update_wrapper(wrapper=self, wrapped=self.original_function) def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. 
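For context, the reordered constructor above changes how the @tool decorator builds its wrapper. A minimal sketch of the decorator path under this patch; the echo tool is hypothetical and defined only for illustration:

from strands import tool

@tool
def echo(text: str) -> str:
    """Return the input text unchanged."""
    return text

# The decorator now constructs the wrapper with keyword arguments, roughly:
#   DecoratedFunctionTool(function=echo, tool_name="echo", tool_spec=..., metadata=...)
# A plain call still routes straight to the original function:
assert echo("hi") == "hi"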
@@ -322,10 +323,12 @@ def my_tool():
             tool = instance.my_tool
             ```
         """
-        if instance is not None and not inspect.ismethod(self._tool_func):
+        if instance is not None and not inspect.ismethod(self.original_function):
             # Create a bound method
-            tool_func = self._tool_func.__get__(instance, instance.__class__)
-            return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata)
+            new_callback = self.original_function.__get__(instance, instance.__class__)
+            return DecoratedFunctionTool(
+                function=new_callback, tool_name=self.tool_name, tool_spec=self.tool_spec, metadata=self._metadata
+            )
 
         return self
 
@@ -342,7 +345,22 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         Returns:
             The result of the original function call.
         """
-        return self._tool_func(*args, **kwargs)
+        if (
+            len(args) > 0
+            and isinstance(args[0], dict)
+            and (not args[0] or "toolUseId" in args[0] or "input" in args[0])
+        ):
+            # This block is only for backwards compatibility, so we cast to Any for now
+            logger.warning(
+                "issue=<%s> | "
+                "passing tool use into a function instead of using .invoke will be removed in a future release",
+                "https://github.com/strands-agents/sdk-python/pull/258",
+            )
+            tool_use = cast(Any, args[0])
+
+            return cast(R, self.invoke(tool_use, **kwargs))
+
+        return self.original_function(*args, **kwargs)
 
     @property
     def tool_name(self) -> str:
@@ -371,11 +389,10 @@ def tool_type(self) -> str:
         """
         return "function"
 
-    @override
-    async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
-        """Stream the tool with a tool use specification.
+    def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
+        """Invoke the tool with a tool use specification.
 
-        This method handles tool use streams from a Strands Agent. It validates the input,
+        This method handles tool use invocations from a Strands Agent. It validates the input,
         calls the function, and formats the result according to the expected tool result format.
 
         Key operations:
@@ -387,13 +404,15 @@ async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerat
             5. Handle and format any errors that occur
 
         Args:
-            tool_use: The tool use specification from the Agent.
-            kwargs: Additional keyword arguments, may include 'agent' reference.
+            tool: The tool use specification from the Agent.
+            *args: Additional positional arguments (not typically used).
+            **kwargs: Additional keyword arguments, may include 'agent' reference.
 
-        Yields:
-            Tool events with the last being the tool result.
+        Returns:
+            A standardized tool result dictionary with status and content.
""" # This is a tool use call - process accordingly + tool_use = tool tool_use_id = tool_use.get("toolUseId", "unknown") tool_input = tool_use.get("input", {}) @@ -405,21 +424,18 @@ async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerat if "agent" in kwargs and "agent" in self._metadata.signature.parameters: validated_input["agent"] = kwargs.get("agent") - # "Too few arguments" expected, hence the type ignore - if inspect.iscoroutinefunction(self._tool_func): - result = await self._tool_func(**validated_input) # type: ignore - else: - result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + # We get "too few arguments here" but because that's because fof the way we're calling it + result = self.original_function(**validated_input) # type: ignore # FORMAT THE RESULT for Strands Agent if isinstance(result, dict) and "status" in result and "content" in result: # Result is already in the expected format, just add toolUseId result["toolUseId"] = tool_use_id - yield result + return cast(ToolResult, result) else: # Wrap any other return value in the standard format # Always include at least one content item for consistency - yield { + return { "toolUseId": tool_use_id, "status": "success", "content": [{"text": str(result)}], @@ -428,7 +444,7 @@ async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerat except ValueError as e: # Special handling for validation errors error_msg = str(e) - yield { + return { "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error_msg}"}], @@ -437,7 +453,7 @@ async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerat # Return error result with exception details for any other error error_type = type(e).__name__ error_msg = str(e) - yield { + return { "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error_type} - {error_msg}"}], @@ -460,7 +476,7 @@ def get_display_properties(self) -> dict[str, str]: Function properties (e.g., function name). """ properties = super().get_display_properties() - properties["Function"] = self._tool_func.__name__ + properties["Function"] = self.original_function.__name__ return properties @@ -557,7 +573,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]": if not isinstance(tool_name, str): raise ValueError(f"Tool name must be a string, got {type(tool_name)}") - return DecoratedFunctionTool(tool_name, tool_spec, f, tool_meta) + return DecoratedFunctionTool(function=f, tool_name=tool_name, tool_spec=tool_spec, metadata=tool_meta) # Handle both @tool and @tool() syntax if func is None: diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 5c17f2be6..2291e0ff4 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -1,9 +1,11 @@ """Tool execution functionality for the event loop.""" -import asyncio import logging +import queue +import threading import time -from typing import Any, Optional, cast +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Generator, Optional, cast from opentelemetry import trace @@ -16,7 +18,7 @@ logger = logging.getLogger(__name__) -async def run_tools( +def run_tools( handler: RunToolHandler, tool_uses: list[ToolUse], event_loop_metrics: EventLoopMetrics, @@ -24,8 +26,9 @@ async def run_tools( tool_results: list[ToolResult], cycle_trace: Trace, parent_span: Optional[trace.Span] = None, -) -> ToolGenerator: - """Execute tools concurrently. 
+ thread_pool: Optional[ThreadPoolExecutor] = None, +) -> Generator[dict[str, Any], None, None]: + """Execute tools either in parallel or sequentially. Args: handler: Tool handler processing function. @@ -35,38 +38,26 @@ async def run_tools( tool_results: List to populate with tool results. cycle_trace: Parent trace for the current cycle. parent_span: Parent span for the current cycle. + thread_pool: Optional thread pool for parallel processing. Yields: - Events of the tool stream. Tool results are appended to `tool_results`. + Events of the tool invocations. Tool results are appended to `tool_results`. """ - async def work( - tool_use: ToolUse, - worker_id: int, - worker_queue: asyncio.Queue, - worker_event: asyncio.Event, - stop_event: object, - ) -> ToolResult: + def handle(tool: ToolUse) -> ToolGenerator: tracer = get_tracer() - tool_call_span = tracer.start_tool_call_span(tool_use, parent_span) + tool_call_span = tracer.start_tool_call_span(tool, parent_span) - tool_name = tool_use["name"] + tool_name = tool["name"] tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) tool_start_time = time.time() - try: - async for event in handler(tool_use): - worker_queue.put_nowait((worker_id, event)) - await worker_event.wait() - - result = cast(ToolResult, event) - finally: - worker_queue.put_nowait((worker_id, stop_event)) + result = yield from handler(tool) tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) if tool_call_span: @@ -74,27 +65,52 @@ async def work( return result + def work( + tool: ToolUse, + worker_id: int, + worker_queue: queue.Queue, + worker_event: threading.Event, + ) -> ToolResult: + events = handle(tool) + + try: + while True: + event = next(events) + worker_queue.put((worker_id, event)) + worker_event.wait() + + except StopIteration as stop: + return cast(ToolResult, stop.value) + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue() - worker_events = [asyncio.Event() for _ in tool_uses] - stop_event = object() - - workers = [ - asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event)) - for worker_id, tool_use in enumerate(tool_uses) - ] - - worker_count = len(workers) - while worker_count: - worker_id, event = await worker_queue.get() - if event is stop_event: - worker_count -= 1 - continue - - yield event - worker_events[worker_id].set() - - tool_results.extend([worker.result() for worker in workers]) + + if thread_pool: + logger.debug("tool_count=<%s> | executing tools in parallel", len(tool_uses)) + + worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() + worker_events = [threading.Event() for _ in range(len(tool_uses))] + + workers = [ + thread_pool.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id]) + for worker_id, tool_use in enumerate(tool_uses) + ] + logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses)) + + while not all(worker.done() for worker in workers): + if not worker_queue.empty(): + worker_id, event = worker_queue.get() + yield event + worker_events[worker_id].set() + + 
time.sleep(0.001) + + tool_results.extend([worker.result() for worker in workers]) + + else: + # Sequential execution fallback + for tool_use in tool_uses: + result = yield from handle(tool_use) + tool_results.append(result) def validate_and_prepare_tools( diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 56433324e..7bf5c5e75 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -108,7 +108,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: if not callable(tool_func): raise TypeError(f"Tool {tool_name} function is not callable") - return PythonAgentTool(tool_name, tool_spec, tool_func) + return PythonAgentTool(tool_name, tool_spec, callback=tool_func) except Exception: logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index 40119f9db..e24c30b48 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -5,14 +5,12 @@ It allows MCP tools to be seamlessly integrated and used within the agent ecosystem. """ -import asyncio import logging from typing import TYPE_CHECKING, Any from mcp.types import Tool as MCPTool -from typing_extensions import override -from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse +from ...types.tools import AgentTool, ToolResult, ToolSpec, ToolUse if TYPE_CHECKING: from .mcp_client import MCPClient @@ -75,22 +73,13 @@ def tool_type(self) -> str: """ return "python" - @override - async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator: - """Stream the MCP tool. + def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + """Invoke the MCP tool. - This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and - input arguments. - - Yields: - Tool events with the last being the tool result. + This method delegates the tool invocation to the MCP server connection, + passing the tool use ID, tool name, and input arguments. 
""" - logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"]) - - result = await asyncio.to_thread( - self.mcp_client.call_tool_sync, - tool_use_id=tool_use["toolUseId"], - name=self.tool_name, - arguments=tool_use["input"], + logger.debug("invoking MCP tool '%s' with tool_use_id=%s", self.tool_name, tool["toolUseId"]) + return self.mcp_client.call_tool_sync( + tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"] ) - yield result diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index b0d84946d..5ab611e0c 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -17,7 +17,7 @@ from strands.tools.decorator import DecoratedFunctionTool -from ..types.tools import AgentTool, ToolSpec +from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -347,7 +347,11 @@ def reload_tool(self, tool_name: str) -> None: # Validate tool spec self.validate_tool_spec(module.TOOL_SPEC) - new_tool = PythonAgentTool(tool_name, module.TOOL_SPEC, tool_function) + new_tool = PythonAgentTool( + tool_name=tool_name, + tool_spec=module.TOOL_SPEC, + callback=tool_function, + ) # Register the tool self.register_tool(new_tool) @@ -427,7 +431,11 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: continue tool_spec = module.TOOL_SPEC - tool = PythonAgentTool(tool_name, tool_spec, tool_function) + tool = PythonAgentTool( + tool_name=tool_name, + tool_spec=tool_spec, + callback=tool_function, + ) self.register_tool(tool) successful_loads += 1 @@ -455,7 +463,11 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: continue tool_spec = module.TOOL_SPEC - tool = PythonAgentTool(tool_name, tool_spec, tool_function) + tool = PythonAgentTool( + tool_name=tool_name, + tool_spec=tool_spec, + callback=tool_function, + ) self.register_tool(tool) successful_loads += 1 @@ -472,15 +484,20 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: for tool_name, error in tool_import_errors.items(): logger.debug("tool_name=<%s> | import error | %s", tool_name, error) - def get_all_tool_specs(self) -> list[ToolSpec]: - """Get all the tool specs for all tools in this registry.. + def initialize_tool_config(self) -> ToolConfig: + """Initialize tool configuration from tool handler with optional filtering. Returns: - A list of ToolSpecs. + Tool config. """ all_tools = self.get_all_tools_config() - tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] - return tools + + tools: List[Tool] = [{"toolSpec": tool_spec} for tool_spec in all_tools.values()] + + return ToolConfig( + tools=tools, + toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), + ) def validate_tool_spec(self, tool_spec: ToolSpec) -> None: """Validate tool specification against required schema. diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 1d05bfa6f..1694f98c4 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -4,15 +4,11 @@ Python module-based tools, as well as utilities for validating tool uses and normalizing tool schemas. 
""" -import asyncio -import inspect import logging import re -from typing import Any +from typing import Any, Callable, Dict -from typing_extensions import override - -from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse +from ..types.tools import AgentTool, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -64,7 +60,7 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) -def _normalize_property(prop_name: str, prop_def: Any) -> dict[str, Any]: +def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: """Normalize a single property definition. Args: @@ -92,7 +88,7 @@ def _normalize_property(prop_name: str, prop_def: Any) -> dict[str, Any]: return normalized_prop -def normalize_schema(schema: dict[str, Any]) -> dict[str, Any]: +def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Normalize a JSON schema to match expectations. This function recursively processes nested objects to preserve the complete schema structure. @@ -152,23 +148,25 @@ class PythonAgentTool(AgentTool): as SDK tools. """ + _callback: Callable[[ToolUse, Any, dict[str, Any]], ToolResult] _tool_name: str _tool_spec: ToolSpec - _tool_func: ToolFunc - def __init__(self, tool_name: str, tool_spec: ToolSpec, tool_func: ToolFunc) -> None: + def __init__( + self, tool_name: str, tool_spec: ToolSpec, callback: Callable[[ToolUse, Any, dict[str, Any]], ToolResult] + ) -> None: """Initialize a Python-based tool. Args: tool_name: Unique identifier for the tool. tool_spec: Tool specification defining parameters and behavior. - tool_func: Python function to execute when the tool is invoked. + callback: Python function to execute when the tool is invoked. """ super().__init__() self._tool_name = tool_name self._tool_spec = tool_spec - self._tool_func = tool_func + self._callback = callback @property def tool_name(self) -> str: @@ -197,20 +195,15 @@ def tool_type(self) -> str: """ return "python" - @override - async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator: - """Stream the Python function with the given tool use request. + def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + """Execute the Python function with the given tool use request. Args: - tool_use: The tool use request. - kwargs: Additional keyword arguments to pass to the underlying tool function. + tool: The tool use request. + *args: Additional positional arguments to pass to the underlying callback function. + **kwargs: Additional keyword arguments to pass to the underlying callback function. - Yields: - Tool events with the last being the tool result. + Returns: + A ToolResult containing the status and content from the callback execution. """ - if inspect.iscoroutinefunction(self._tool_func): - result = await self._tool_func(tool_use, **kwargs) - else: - result = await asyncio.to_thread(self._tool_func, tool_use, **kwargs) - - yield result + return self._callback(tool, *args, **kwargs) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 09d24bd80..30971c2ba 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -34,6 +34,32 @@ class OpenAIModel(Model, abc.ABC): config: dict[str, Any] + @staticmethod + def b64encode(data: bytes) -> bytes: + """Base64 encode the provided data. + + If the data is already base64 encoded, we do nothing. 
+ Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future + versions, images and documents will be base64 encoded on behalf of customers for consistency with the other + providers and general convenience. + + Args: + data: Data to encode. + + Returns: + Base64 encoded data. + """ + try: + base64.b64decode(data, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images and documents will not be accepted in future versions", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + data = base64.b64encode(data) + + return data + @classmethod def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format an OpenAI compatible content block. @@ -60,7 +86,7 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") return { "image_url": { diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index e2895f2dd..798cbc185 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, Callable, Generator, Literal, Union from typing_extensions import TypedDict @@ -130,11 +130,11 @@ class ToolChoiceTool(TypedDict): - "tool": The model must use the specified tool """ -RunToolHandler = Callable[[ToolUse], AsyncGenerator[dict[str, Any], None]] +RunToolHandler = Callable[[ToolUse], Generator[dict[str, Any], None, ToolResult]] """Callback that runs a single tool and streams back results.""" -ToolGenerator = AsyncGenerator[Any, None] -"""Generator of tool events with the last being the tool result.""" +ToolGenerator = Generator[dict[str, Any], None, ToolResult] +"""Generator of tool events and a returned tool result.""" class ToolConfig(TypedDict): @@ -149,30 +149,11 @@ class ToolConfig(TypedDict): toolChoice: ToolChoice -class ToolFunc(Protocol): - """Function signature for Python decorated and module based tools.""" - - __name__: str - - def __call__( - self, *args: Any, **kwargs: Any - ) -> Union[ - ToolResult, - Awaitable[ToolResult], - ]: - """Function signature for Python decorated and module based tools. - - Returns: - Tool result or awaitable tool result. - """ - ... - - class AgentTool(ABC): """Abstract base class for all SDK tools. This class defines the interface that all tool implementations must follow. Each tool must provide its name, - specification, and implement a stream method that executes the tool's functionality. + specification, and implement an invoke method that executes the tool's functionality. """ _is_dynamic: bool @@ -216,17 +197,18 @@ def supports_hot_reload(self) -> bool: @abstractmethod # pragma: no cover - def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator: - """Stream tool events and return the final result. + def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + """Execute the tool's functionality with the given tool use request. Args: - tool_use: The tool use request containing tool ID and parameters. - kwargs: Keyword arguments to pass to the tool. 
+ tool: The tool use request containing tool ID and parameters. + *args: Positional arguments to pass to the tool. + **kwargs: Keyword arguments to pass to the tool. - Yield: - Tool events with the last being the tool result. + Returns: + The result of the tool execution. """ - ... + pass @property def is_dynamic(self) -> bool: diff --git a/tests-integ/__init__.py b/tests-integ/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-integ/conftest.py b/tests-integ/conftest.py new file mode 100644 index 000000000..4b38540c5 --- /dev/null +++ b/tests-integ/conftest.py @@ -0,0 +1,20 @@ +import pytest + +## Async + + +@pytest.fixture(scope="session") +def agenerator(): + async def agenerator(items): + for item in items: + yield item + + return agenerator + + +@pytest.fixture(scope="session") +def alist(): + async def alist(items): + return [item async for item in items] + + return alist diff --git a/tests-integ/echo_server.py b/tests-integ/echo_server.py new file mode 100644 index 000000000..d309607a8 --- /dev/null +++ b/tests-integ/echo_server.py @@ -0,0 +1,39 @@ +""" +Echo Server for MCP Integration Testing + +This module implements a simple echo server using the Model Context Protocol (MCP). +It provides a basic tool that echoes back any input string, which is useful for +testing the MCP communication flow and validating that messages are properly +transmitted between the client and server. + +The server runs with stdio transport, making it suitable for integration tests +where the client can spawn this process and communicate with it through standard +input/output streams. + +Usage: + Run this file directly to start the echo server: + $ python echo_server.py +""" + +from mcp.server import FastMCP + + +def start_echo_server(): + """ + Initialize and start the MCP echo server. + + Creates a FastMCP server instance with a single 'echo' tool that returns + any input string back to the caller. The server uses stdio transport + for communication. + """ + mcp = FastMCP("Echo Server") + + @mcp.tool(description="Echos response back to the user") + def echo(to_echo: str) -> str: + return to_echo + + mcp.run(transport="stdio") + + +if __name__ == "__main__": + start_echo_server() diff --git a/tests-integ/test_agent_async.py b/tests-integ/test_agent_async.py new file mode 100644 index 000000000..597ba13f7 --- /dev/null +++ b/tests-integ/test_agent_async.py @@ -0,0 +1,22 @@ +import pytest + +import strands + + +@pytest.fixture +def agent(): + return strands.Agent() + + +@pytest.mark.asyncio +async def test_stream_async(agent): + stream = agent.stream_async("hello") + + exp_message = "" + async for event in stream: + if "event" in event and "contentBlockDelta" in event["event"]: + exp_message += event["event"]["contentBlockDelta"]["delta"]["text"] + + tru_message = agent.messages[-1]["content"][0]["text"] + + assert tru_message == exp_message diff --git a/tests-integ/test_bedrock_cache_point.py b/tests-integ/test_bedrock_cache_point.py new file mode 100644 index 000000000..82bca22a2 --- /dev/null +++ b/tests-integ/test_bedrock_cache_point.py @@ -0,0 +1,31 @@ +from strands import Agent +from strands.types.content import Messages + + +def test_bedrock_cache_point(): + messages: Messages = [ + { + "role": "user", + "content": [ + { + "text": "Some really long text!" 
* 1000 # Minimum token count for cachePoint is 1024 tokens + }, + {"cachePoint": {"type": "default"}}, + ], + }, + {"role": "assistant", "content": [{"text": "Blue!"}]}, + ] + + cache_point_usage = 0 + + def cache_point_callback_handler(**kwargs): + nonlocal cache_point_usage + if "event" in kwargs and kwargs["event"] and "metadata" in kwargs["event"] and kwargs["event"]["metadata"]: + metadata = kwargs["event"]["metadata"] + if "usage" in metadata and metadata["usage"]: + if "cacheReadInputTokens" in metadata["usage"] or "cacheWriteInputTokens" in metadata["usage"]: + cache_point_usage += 1 + + agent = Agent(messages=messages, callback_handler=cache_point_callback_handler, load_tools_from_directory=False) + agent("What is favorite color?") + assert cache_point_usage > 0 diff --git a/tests-integ/test_bedrock_guardrails.py b/tests-integ/test_bedrock_guardrails.py new file mode 100644 index 000000000..bf0be7068 --- /dev/null +++ b/tests-integ/test_bedrock_guardrails.py @@ -0,0 +1,160 @@ +import time + +import boto3 +import pytest + +from strands import Agent +from strands.models.bedrock import BedrockModel + +BLOCKED_INPUT = "BLOCKED_INPUT" +BLOCKED_OUTPUT = "BLOCKED_OUTPUT" + + +@pytest.fixture(scope="module") +def boto_session(): + return boto3.Session(region_name="us-east-1") + + +@pytest.fixture(scope="module") +def bedrock_guardrail(boto_session): + """ + Fixture that creates a guardrail before tests if it doesn't already exist." + """ + + client = boto_session.client("bedrock") + + guardrail_name = "test-guardrail-block-cactus" + guardrail_id = get_guardrail_id(client, guardrail_name) + + if guardrail_id: + print(f"Guardrail {guardrail_name} already exists with ID: {guardrail_id}") + else: + print(f"Creating guardrail {guardrail_name}") + response = client.create_guardrail( + name=guardrail_name, + description="Testing Guardrail", + wordPolicyConfig={ + "wordsConfig": [ + { + "text": "CACTUS", + "inputAction": "BLOCK", + "outputAction": "BLOCK", + "inputEnabled": True, + "outputEnabled": True, + }, + ], + }, + blockedInputMessaging=BLOCKED_INPUT, + blockedOutputsMessaging=BLOCKED_OUTPUT, + ) + guardrail_id = response.get("guardrailId") + print(f"Created test guardrail with ID: {guardrail_id}") + wait_for_guardrail_active(client, guardrail_id) + return guardrail_id + + +def get_guardrail_id(client, guardrail_name): + """ + Retrieves the ID of a guardrail by its name. + + Args: + client: The Bedrock client instance + guardrail_name: Name of the guardrail to look up + + Returns: + str: The ID of the guardrail if found, None otherwise + """ + response = client.list_guardrails() + for guardrail in response.get("guardrails", []): + if guardrail["name"] == guardrail_name: + return guardrail["id"] + return None + + +def wait_for_guardrail_active(bedrock_client, guardrail_id, max_attempts=10, delay=5): + """ + Wait for the guardrail to become active + """ + for _ in range(max_attempts): + response = bedrock_client.get_guardrail(guardrailIdentifier=guardrail_id) + status = response.get("status") + + if status == "READY": + print(f"Guardrail {guardrail_id} is now active") + return True + + print(f"Waiting for guardrail to become active. 
Current status: {status}") + time.sleep(delay) + + print(f"Guardrail did not become active within {max_attempts * delay} seconds.") + raise RuntimeError("Guardrail did not become active.") + + +def test_guardrail_input_intervention(boto_session, bedrock_guardrail): + bedrock_model = BedrockModel( + guardrail_id=bedrock_guardrail, + guardrail_version="DRAFT", + boto_session=boto_session, + ) + + agent = Agent(model=bedrock_model, system_prompt="You are a helpful assistant.", callback_handler=None) + + response1 = agent("CACTUS") + response2 = agent("Hello!") + + assert response1.stop_reason == "guardrail_intervened" + assert str(response1).strip() == BLOCKED_INPUT + assert response2.stop_reason != "guardrail_intervened" + assert str(response2).strip() != BLOCKED_INPUT + + +@pytest.mark.parametrize("processing_mode", ["sync", "async"]) +def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processing_mode): + bedrock_model = BedrockModel( + guardrail_id=bedrock_guardrail, + guardrail_version="DRAFT", + guardrail_redact_output=False, + guardrail_stream_processing_mode=processing_mode, + boto_session=boto_session, + ) + + agent = Agent( + model=bedrock_model, + system_prompt="When asked to say the word, say CACTUS.", + callback_handler=None, + load_tools_from_directory=False, + ) + + response1 = agent("Say the word.") + response2 = agent("Hello!") + assert response1.stop_reason == "guardrail_intervened" + assert BLOCKED_OUTPUT in str(response1) + assert response2.stop_reason != "guardrail_intervened" + assert BLOCKED_OUTPUT not in str(response2) + + +@pytest.mark.parametrize("processing_mode", ["sync", "async"]) +def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode): + REDACT_MESSAGE = "Redacted." + bedrock_model = BedrockModel( + guardrail_id=bedrock_guardrail, + guardrail_version="DRAFT", + guardrail_stream_processing_mode=processing_mode, + guardrail_redact_output=True, + guardrail_redact_output_message=REDACT_MESSAGE, + region_name="us-east-1", + ) + + agent = Agent( + model=bedrock_model, + system_prompt="When asked to say the word, say CACTUS.", + callback_handler=None, + load_tools_from_directory=False, + ) + + response1 = agent("Say the word.") + response2 = agent("Hello!") + assert response1.stop_reason == "guardrail_intervened" + assert REDACT_MESSAGE in str(response1) + assert response2.stop_reason != "guardrail_intervened" + assert REDACT_MESSAGE not in str(response2) diff --git a/tests-integ/test_context_overflow.py b/tests-integ/test_context_overflow.py new file mode 100644 index 000000000..16dc3c4b8 --- /dev/null +++ b/tests-integ/test_context_overflow.py @@ -0,0 +1,13 @@ +from strands import Agent +from strands.types.content import Messages + + +def test_context_window_overflow(): + messages: Messages = [ + {"role": "user", "content": [{"text": "Too much text!" 
* 100000}]}, + {"role": "assistant", "content": [{"text": "That was a lot of text!"}]}, + ] + + agent = Agent(messages=messages, load_tools_from_directory=False) + agent("Hi!") + assert len(agent.messages) == 2 diff --git a/tests-integ/test_function_tools.py b/tests-integ/test_function_tools.py new file mode 100644 index 000000000..835dccf5d --- /dev/null +++ b/tests-integ/test_function_tools.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +""" +Test script for function-based tools +""" + +import logging +from typing import Optional + +from strands import Agent, tool + +logging.getLogger("strands").setLevel(logging.DEBUG) +logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) + + +@tool +def word_counter(text: str) -> str: + """ + Count words in text. + + Args: + text: Text to analyze + """ + count = len(text.split()) + return f"Word count: {count}" + + +@tool(name="count_chars", description="Count characters in text") +def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: + """ + Count characters in text. + + Args: + text: Text to analyze + include_spaces: Whether to include spaces in the count + """ + if not include_spaces: + text = text.replace(" ", "") + return f"Character count: {len(text)}" + + +# Initialize agent with function tools +agent = Agent(tools=[word_counter, count_chars]) + +print("\n===== Testing Direct Tool Access =====") +# Use the tools directly +word_result = agent.tool.word_counter(text="Hello world, this is a test") +print(f"\nWord counter result: {word_result}") + +char_result = agent.tool.count_chars(text="Hello world!", include_spaces=False) +print(f"\nCharacter counter result: {char_result}") + +print("\n===== Testing Natural Language Access =====") +# Use through natural language +nl_result = agent("Count the words in this sentence: 'The quick brown fox jumps over the lazy dog'") +print(f"\nNL Result: {nl_result}") diff --git a/tests-integ/test_hot_tool_reload_decorator.py b/tests-integ/test_hot_tool_reload_decorator.py new file mode 100644 index 000000000..0a15a2be7 --- /dev/null +++ b/tests-integ/test_hot_tool_reload_decorator.py @@ -0,0 +1,143 @@ +""" +Integration test for hot tool reloading functionality with the @tool decorator. + +This test verifies that the Strands Agent can automatically detect and load +new tools created with the @tool decorator when they are added to a tools directory. +""" + +import logging +import os +import time +from pathlib import Path + +from strands import Agent + +logging.getLogger("strands").setLevel(logging.DEBUG) +logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) + + +def test_hot_reload_decorator(): + """ + Test that the Agent automatically loads tools created with @tool decorator + when added to the current working directory's tools folder. 
+ """ + # Set up the tools directory in current working directory + tools_dir = Path.cwd() / "tools" + os.makedirs(tools_dir, exist_ok=True) + + # Tool path that will need cleanup + test_tool_path = tools_dir / "uppercase.py" + + try: + # Create an Agent instance without any tools + agent = Agent() + + # Create a test tool using @tool decorator + with open(test_tool_path, "w") as f: + f.write(""" +from strands import tool + +@tool +def uppercase(text: str) -> str: + \"\"\"Convert text to uppercase.\"\"\" + return f"Input: {text}, Output: {text.upper()}" +""") + + # Wait for tool detection + time.sleep(3) + + # Verify the tool was automatically loaded + assert "uppercase" in agent.tool_names, "Agent should have detected and loaded the uppercase tool" + + # Test calling the dynamically loaded tool + result = agent.tool.uppercase(text="hello world") + + # Check that the result is successful + assert result.get("status") == "success", "Tool call should be successful" + + # Check the content of the response + content_list = result.get("content", []) + assert len(content_list) > 0, "Tool response should have content" + + # Check that the expected message is in the content + text_content = next((item.get("text") for item in content_list if "text" in item), "") + assert "Input: hello world, Output: HELLO WORLD" in text_content + + finally: + # Clean up - remove the test file + if test_tool_path.exists(): + os.remove(test_tool_path) + + +def test_hot_reload_decorator_update(): + """ + Test that the Agent detects updates to tools created with @tool decorator. + """ + # Set up the tools directory in current working directory + tools_dir = Path.cwd() / "tools" + os.makedirs(tools_dir, exist_ok=True) + + # Tool path that will need cleanup - make sure filename matches function name + test_tool_path = tools_dir / "greeting.py" + + try: + # Create an Agent instance + agent = Agent() + + # Create the initial version of the tool + with open(test_tool_path, "w") as f: + f.write(""" +from strands import tool + +@tool +def greeting(name: str) -> str: + \"\"\"Generate a simple greeting.\"\"\" + return f"Hello, {name}!" +""") + + # Wait for tool detection + time.sleep(3) + + # Verify the tool was loaded + assert "greeting" in agent.tool_names, "Agent should have detected and loaded the greeting tool" + + # Test calling the tool + result1 = agent.tool.greeting(name="Strands") + text_content1 = next((item.get("text") for item in result1.get("content", []) if "text" in item), "") + assert "Hello, Strands!" in text_content1, "Tool should return simple greeting" + + # Update the tool with new functionality + with open(test_tool_path, "w") as f: + f.write(""" +from strands import tool +import datetime + +@tool +def greeting(name: str, formal: bool = False) -> str: + \"\"\"Generate a greeting with optional formality.\"\"\" + current_hour = datetime.datetime.now().hour + time_of_day = "morning" if current_hour < 12 else "afternoon" if current_hour < 18 else "evening" + + if formal: + return f"Good {time_of_day}, {name}. It's a pleasure to meet you." + else: + return f"Hey {name}! How's your {time_of_day} going?" 
+""") + + # Wait for hot reload to detect the change + time.sleep(3) + + # Test calling the updated tool + result2 = agent.tool.greeting(name="Strands", formal=True) + text_content2 = next((item.get("text") for item in result2.get("content", []) if "text" in item), "") + assert "Good" in text_content2 and "Strands" in text_content2 and "pleasure to meet you" in text_content2 + + # Test with informal parameter + result3 = agent.tool.greeting(name="Strands", formal=False) + text_content3 = next((item.get("text") for item in result3.get("content", []) if "text" in item), "") + assert "Hey Strands!" in text_content3 and "going" in text_content3 + + finally: + # Clean up - remove the test file + if test_tool_path.exists(): + os.remove(test_tool_path) diff --git a/tests-integ/test_image.png b/tests-integ/test_image.png new file mode 100644 index 000000000..9caac13be Binary files /dev/null and b/tests-integ/test_image.png differ diff --git a/tests-integ/test_mcp_client.py b/tests-integ/test_mcp_client.py new file mode 100644 index 000000000..8b1dade33 --- /dev/null +++ b/tests-integ/test_mcp_client.py @@ -0,0 +1,130 @@ +import base64 +import os +import threading +import time +from typing import List, Literal + +import pytest +from mcp import StdioServerParameters, stdio_client +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import ImageContent as MCPImageContent + +from strands import Agent +from strands.tools.mcp.mcp_client import MCPClient +from strands.tools.mcp.mcp_types import MCPTransport +from strands.types.content import Message +from strands.types.tools import ToolUse + + +def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int): + """ + Initialize and start an MCP calculator server for integration testing. + + This function creates a FastMCP server instance that provides a simple + calculator tool for performing addition operations. The server uses + Server-Sent Events (SSE) transport for communication, making it accessible + over HTTP. 
+ """ + from mcp.server import FastMCP + + mcp = FastMCP("Calculator Server", port=port) + + @mcp.tool(description="Calculator tool which performs calculations") + def calculator(x: int, y: int) -> int: + return x + y + + @mcp.tool(description="Generates a custom image") + def generate_custom_image() -> MCPImageContent: + try: + with open("tests-integ/test_image.png", "rb") as image_file: + encoded_image = base64.b64encode(image_file.read()) + return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") + except Exception as e: + print("Error while generating custom image: {}".format(e)) + + mcp.run(transport=transport) + + +def test_mcp_client(): + """ + Test should yield output similar to the following + {'role': 'user', 'content': [{'text': 'add 1 and 2, then echo the result back to me'}]} + {'role': 'assistant', 'content': [{'text': "I'll help you add 1 and 2 and then echo the result back to you.\n\nFirst, I'll calculate 1 + 2:"}, {'toolUse': {'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'name': 'calculator', 'input': {'x': 1, 'y': 2}}}]} + {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_17ptaKUxQB20ySZxwgiI_w', 'content': [{'text': '3'}]}}]} + {'role': 'assistant', 'content': [{'text': "\n\nNow I'll echo the result back to you:"}, {'toolUse': {'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'name': 'echo', 'input': {'to_echo': '3'}}}]} + {'role': 'user', 'content': [{'toolResult': {'status': 'success', 'toolUseId': 'tooluse_GlOc5SN8TE6ti8jVZJMBOg', 'content': [{'text': '3'}]}}]} + {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]} + """ # noqa: E501 + + server_thread = threading.Thread( + target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True + ) + server_thread.start() + time.sleep(2) # wait for server to startup completely + + sse_mcp_client = MCPClient(lambda: sse_client("http://127.0.0.1:8000/sse")) + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"])) + ) + with sse_mcp_client, stdio_mcp_client: + agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync()) + agent("add 1 and 2, then echo the result back to me") + + tool_use_content_blocks = _messages_to_content_blocks(agent.messages) + assert any([block["name"] == "echo" for block in tool_use_content_blocks]) + assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) + + image_prompt = """ + Generate a custom image, then tell me if the image is red, blue, yellow, pink, orange, or green. 
+ RESPOND ONLY WITH THE COLOR + """ + assert any( + [ + "yellow".casefold() in block["text"].casefold() + for block in agent(image_prompt).message["content"] + if "text" in block + ] + ) + + +def test_can_reuse_mcp_client(): + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests-integ/echo_server.py"])) + ) + with stdio_mcp_client: + stdio_mcp_client.list_tools_sync() + pass + with stdio_mcp_client: + agent = Agent(tools=stdio_mcp_client.list_tools_sync()) + agent("echo the following to me DOG") + + tool_use_content_blocks = _messages_to_content_blocks(agent.messages) + assert any([block["name"] == "echo" for block in tool_use_content_blocks]) + + +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", +) +def test_streamable_http_mcp_client(): + server_thread = threading.Thread( + target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + ) + server_thread.start() + time.sleep(2) # wait for server to startup completely + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url="http://127.0.0.1:8001/mcp") + + streamable_http_client = MCPClient(transport_callback) + with streamable_http_client: + agent = Agent(tools=streamable_http_client.list_tools_sync()) + agent("add 1 and 2 using a calculator") + + tool_use_content_blocks = _messages_to_content_blocks(agent.messages) + assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) + + +def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: + return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py new file mode 100644 index 000000000..50033f8f7 --- /dev/null +++ b/tests-integ/test_model_anthropic.py @@ -0,0 +1,63 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-7-sonnet-20250219", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant." 
+ + +@pytest.fixture +def agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py new file mode 100644 index 000000000..120f4036b --- /dev/null +++ b/tests-integ/test_model_bedrock.py @@ -0,0 +1,153 @@ +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models import BedrockModel + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that uses & instead of ." + + +@pytest.fixture +def streaming_model(): + return BedrockModel( + model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", + streaming=True, + ) + + +@pytest.fixture +def non_streaming_model(): + return BedrockModel( + model_id="us.meta.llama3-2-90b-instruct-v1:0", + streaming=False, + ) + + +@pytest.fixture +def streaming_agent(streaming_model, system_prompt): + return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +@pytest.fixture +def non_streaming_agent(non_streaming_model, system_prompt): + return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +def test_streaming_agent(streaming_agent): + """Test agent with streaming model.""" + result = streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_non_streaming_agent(non_streaming_agent): + """Test agent with non-streaming model.""" + result = non_streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +@pytest.mark.asyncio +async def test_streaming_model_events(streaming_model, alist): + """Test streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = await alist(streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +@pytest.mark.asyncio +async def test_non_streaming_model_events(non_streaming_model, alist): + """Test non-streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = await alist(non_streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_tool_use_streaming(streaming_model): + """Test tool use with streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the 
result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False) + result = agent("What is 123 + 456?") + + # Print the full message content for debugging + print("\nFull message content:") + import json + + print(json.dumps(result.message["content"], indent=2)) + + assert tool_was_called + + +def test_tool_use_non_streaming(non_streaming_model): + """Test tool use with non-streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) + agent("What is 123 + 456?") + + assert tool_was_called + + +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_litellm.py b/tests-integ/test_model_litellm.py new file mode 100644 index 000000000..01a3e1211 --- /dev/null +++ b/tests-integ/test_model_litellm.py @@ -0,0 +1,49 @@ +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.litellm import LiteLLMModel + + +@pytest.fixture +def model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(model): + class Weather(BaseModel): + time: str + weather: str + + agent_no_tools = Agent(model=model) + + result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_llamaapi.py b/tests-integ/test_model_llamaapi.py new file mode 100644 index 000000000..dad6919e2 --- /dev/null +++ b/tests-integ/test_model_llamaapi.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import os + +import pytest + +import strands +from strands import Agent +from strands.models.llamaapi import LlamaAPIModel + + +@pytest.fixture +def model(): + return LlamaAPIModel( + model_id="Llama-4-Maverick-17B-128E-Instruct-FP8", + client_args={ + "api_key": os.getenv("LLAMA_API_KEY"), + }, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.mark.skipif( + "LLAMA_API_KEY" not in os.environ, + reason="LLAMA_API_KEY environment variable missing", +) +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) diff --git a/tests-integ/test_model_mistral.py b/tests-integ/test_model_mistral.py new file mode 100644 index 000000000..f2664f7fd --- /dev/null +++ b/tests-integ/test_model_mistral.py @@ -0,0 +1,157 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.mistral import MistralModel + + +@pytest.fixture +def streaming_model(): + return MistralModel( + model_id="mistral-medium-latest", + api_key=os.getenv("MISTRAL_API_KEY"), + stream=True, + temperature=0.7, + max_tokens=1000, + top_p=0.9, + ) + + +@pytest.fixture +def non_streaming_model(): + return MistralModel( + model_id="mistral-medium-latest", + api_key=os.getenv("MISTRAL_API_KEY"), + stream=False, + temperature=0.7, + max_tokens=1000, + top_p=0.9, + ) + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that provides helpful and accurate information." 
+ + +@pytest.fixture +def calculator_tool(): + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + return eval(expression) + + return calculator + + +@pytest.fixture +def weather_tools(): + @strands.tool + def tool_time() -> str: + """Get the current time.""" + return "12:00" + + @strands.tool + def tool_weather() -> str: + """Get the current weather.""" + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def streaming_agent(streaming_model): + return Agent(model=streaming_model) + + +@pytest.fixture +def non_streaming_agent(non_streaming_model): + return Agent(model=non_streaming_model) + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_streaming_agent_basic(streaming_agent): + """Test basic streaming agent functionality.""" + result = streaming_agent("Tell me about Agentic AI in one sentence.") + + assert len(str(result)) > 0 + assert hasattr(result, "message") + assert "content" in result.message + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_non_streaming_agent_basic(non_streaming_agent): + """Test basic non-streaming agent functionality.""" + result = non_streaming_agent("Tell me about Agentic AI in one sentence.") + + assert len(str(result)) > 0 + assert hasattr(result, "message") + assert "content" in result.message + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_tool_use_streaming(streaming_model): + """Test tool use with streaming model.""" + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + return eval(expression) + + agent = Agent(model=streaming_model, tools=[calculator]) + result = agent("What is the square root of 1764") + + # Verify the result contains the calculation + text_content = str(result).lower() + assert "42" in text_content + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_tool_use_non_streaming(non_streaming_model): + """Test tool use with non-streaming model.""" + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + return eval(expression) + + agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) + result = agent("What is the square root of 1764") + + text_content = str(result).lower() + assert "42" in text_content + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_structured_output_streaming(streaming_model): + """Test structured output with streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=streaming_model) + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" + + +@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") +def test_structured_output_non_streaming(non_streaming_model): + """Test structured output with non-streaming model.""" + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model=non_streaming_model) + result = 
agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_ollama.py b/tests-integ/test_model_ollama.py new file mode 100644 index 000000000..38b46821d --- /dev/null +++ b/tests-integ/test_model_ollama.py @@ -0,0 +1,47 @@ +import pytest +import requests +from pydantic import BaseModel + +from strands import Agent +from strands.models.ollama import OllamaModel + + +def is_server_available() -> bool: + try: + return requests.get("http://localhost:11434").ok + except requests.exceptions.ConnectionError: + return False + + +@pytest.fixture +def model(): + return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b") + + +@pytest.fixture +def agent(model): + return Agent(model=model) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_agent(agent): + result = agent("Say 'hello world' with no other text") + assert isinstance(result.message["content"][0]["text"], str) + + +@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434") +def test_structured_output(agent): + class Weather(BaseModel): + """Extract the time and weather. + + Time format: HH:MM + Weather: sunny, cloudy, rainy, etc. + """ + + time: str + weather: str + + result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny") + assert isinstance(result, Weather) + assert result.time == "12:00" + assert result.weather == "sunny" diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py new file mode 100644 index 000000000..e0dfcb34b --- /dev/null +++ b/tests-integ/test_model_openai.py @@ -0,0 +1,121 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent, tool + +if "OPENAI_API_KEY" not in os.environ: + pytest.skip(allow_module_level=True, reason="OPENAI_API_KEY environment variable missing") + +from strands.models.openai import OpenAIModel + + +@pytest.fixture(scope="module") +def model(): + return OpenAIModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) + + +@pytest.fixture(scope="module") +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture(scope="module") +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.fixture(scope="module") +def weather(): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + +@pytest.fixture(scope="module") +def test_image_path(request): + return request.config.rootpath / "tests-integ" / "test_image.png" + + +def test_agent_invoke(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(agent): + result = await agent.invoke_async("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_stream_async(agent): + stream = 
agent.stream_async("What is the time and weather in New York?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_agent_structured_output(agent, weather): + tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather + + +@pytest.mark.asyncio +async def test_agent_structured_output_async(agent, weather): + tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather + + +def test_tool_returning_images(model, test_image_path): + @tool + def tool_with_image_return(): + with open(test_image_path, "rb") as image_file: + encoded_image = image_file.read() + + return { + "status": "success", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": encoded_image}, + } + }, + ], + } + + agent = Agent(model, tools=[tool_with_image_return]) + # NOTE - this currently fails with: "Invalid 'messages[3]'. Image URLs are only allowed for messages with role + # 'user', but this message with role 'tool' contains an image URL." + # See https://github.com/strands-agents/sdk-python/issues/320 for additional details + agent("Run the tool and analyze the image") diff --git a/tests-integ/test_stream_agent.py b/tests-integ/test_stream_agent.py new file mode 100644 index 000000000..01f203390 --- /dev/null +++ b/tests-integ/test_stream_agent.py @@ -0,0 +1,70 @@ +""" +Test script for Strands' custom callback handler functionality. +Demonstrates different patterns of callback handling and processing. +""" + +import logging + +from strands import Agent + +logging.getLogger("strands").setLevel(logging.DEBUG) +logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) + + +class ToolCountingCallbackHandler: + def __init__(self): + self.tool_count = 0 + self.message_count = 0 + + def callback_handler(self, **kwargs) -> None: + """ + Custom callback handler that processes and displays different types of events.
+ + Args: + **kwargs: Callback event data including: + - data: Regular output + - complete: Completion status + - message: Message processing + - current_tool_use: Tool execution + """ + # Extract event data + data = kwargs.get("data", "") + complete = kwargs.get("complete", False) + message = kwargs.get("message", {}) + current_tool_use = kwargs.get("current_tool_use", {}) + + # Handle regular data output + if data: + print(f"🔄 Data: {data}") + + # Handle tool execution events + if current_tool_use: + self.tool_count += 1 + tool_name = current_tool_use.get("name", "") + tool_input = current_tool_use.get("input", {}) + print(f"🛠️ Tool Execution #{self.tool_count}\nTool: {tool_name}\nInput: {tool_input}") + + # Handle message processing + if message: + self.message_count += 1 + print(f"📝 Message #{self.message_count}") + + # Handle completion (plain print; the handler has no console attribute) + if complete: + print("✨ Callback Complete") + + +def test_basic_interaction(): + """Test basic agent interaction with a custom callback handler.""" + print("\nTesting Basic Interaction") + + # Initialize agent with custom handler + agent = Agent( + callback_handler=ToolCountingCallbackHandler().callback_handler, + load_tools_from_directory=False, + ) + + # Simple prompt to exercise the callback handler + agent("Tell me a short joke from your general knowledge") + + print("\nBasic Interaction Complete") diff --git a/tests-integ/test_summarizing_conversation_manager_integration.py b/tests-integ/test_summarizing_conversation_manager_integration.py new file mode 100644 index 000000000..5dcf49443 --- /dev/null +++ b/tests-integ/test_summarizing_conversation_manager_integration.py @@ -0,0 +1,374 @@ +"""Integration tests for SummarizingConversationManager with actual AI models. + +These tests validate the end-to-end functionality of the SummarizingConversationManager +by testing with real AI models and API calls. They ensure that: + +1. **Real summarization** - Tests that actual model-generated summaries work correctly +2. **Context overflow handling** - Validates real context overflow scenarios and recovery +3. **Tool preservation** - Ensures ToolUse/ToolResult pairs survive real summarization +4. **Message structure** - Verifies real model outputs maintain proper message structure +5. **Agent integration** - Tests that conversation managers work with real Agent workflows + +These tests require API keys (`ANTHROPIC_API_KEY`) and make real API calls, so they should be run sparingly +and may be skipped in CI environments without proper credentials.
+""" + +import os + +import pytest + +import strands +from strands import Agent +from strands.agent.conversation_manager import SummarizingConversationManager +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + """Real Anthropic model for integration testing.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + max_tokens=1024, + ) + + +@pytest.fixture +def summarization_model(): + """Separate model instance for summarization to test dedicated agent functionality.""" + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-haiku-20240307", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + """Real tools for testing tool preservation during summarization.""" + + @strands.tool + def get_current_time() -> str: + """Get the current time.""" + return "2024-01-15 14:30:00" + + @strands.tool + def get_weather(city: str) -> str: + """Get weather information for a city.""" + return f"The weather in {city} is sunny and 72°F" + + @strands.tool + def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two numbers.""" + return a + b + + return [get_current_time, get_weather, calculate_sum] + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_summarization_with_context_overflow(model): + """Test that summarization works when context overflow occurs.""" + # Mock conversation data to avoid API calls + greeting_response = """ + Hello! I'm here to help you test your conversation manager. What specifically would you like + me to do as part of this test? I can respond to different types of prompts, maintain context + throughout our conversation, or demonstrate other capabilities of the AI assistant. Just let + me know what aspects you'd like to evaluate. + """.strip() + + computer_history_response = """ + # History of Computers + + The history of computers spans many centuries, evolving from simple calculating tools to + the powerful machines we use today. 
+ + ## Early Computing Devices + - **Ancient abacus** (3000 BCE): One of the earliest computing devices used for arithmetic calculations + - **Pascaline** (1642): Mechanical calculator invented by Blaise Pascal + - **Difference Engine** (1822): Designed by Charles Babbage to compute polynomial functions + - **Analytical Engine**: Babbage's more ambitious design, considered the first general-purpose computer concept + - **Hollerith's Tabulating Machine** (1890s): Used punch cards to process data for the US Census + + ## Early Electronic Computers + - **ENIAC** (1945): First general-purpose electronic computer, weighed 30 tons + - **EDVAC** (1949): Introduced the stored program concept + - **UNIVAC I** (1951): First commercial computer in the United States + """.strip() + + first_computers_response = """ + # The First Computers + + Early computers were dramatically different from today's machines in almost every aspect: + + ## Physical Characteristics + - **Enormous size**: Room-filling or even building-filling machines + - **ENIAC** (1945) weighed about 30 tons, occupied 1,800 square feet + - Consisted of large metal frames or cabinets filled with components + - Required special cooling systems due to excessive heat generation + + ## Technology and Components + - **Vacuum tubes**: Thousands of fragile glass tubes served as switches and amplifiers + - ENIAC contained over 17,000 vacuum tubes + - Generated tremendous heat and frequently failed + - **Memory**: Limited storage using delay lines, cathode ray tubes, or magnetic drums + """.strip() + + messages = [ + {"role": "user", "content": [{"text": "Hello, I'm testing a conversation manager."}]}, + {"role": "assistant", "content": [{"text": greeting_response}]}, + {"role": "user", "content": [{"text": "Can you tell me about the history of computers?"}]}, + {"role": "assistant", "content": [{"text": computer_history_response}]}, + {"role": "user", "content": [{"text": "What were the first computers like?"}]}, + {"role": "assistant", "content": [{"text": first_computers_response}]}, + ] + + # Create agent with very aggressive summarization settings and pre-built conversation + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, # Summarize 50% of messages + preserve_recent_messages=2, # Keep only 2 recent messages + ), + load_tools_from_directory=False, + messages=messages, + ) + + # Should have the pre-built conversation history + initial_message_count = len(agent.messages) + assert initial_message_count == 6 # 3 user + 3 assistant messages + + # Store the last 2 messages before summarization to verify they're preserved + messages_before_summary = agent.messages[-2:].copy() + + # Now manually trigger context reduction to test summarization + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < initial_message_count + # Should have: 1 summary + remaining messages + # With 6 messages, summary_ratio=0.5, preserve_recent_messages=2: + # messages_to_summarize = min(6 * 0.5, 6 - 2) = min(3, 4) = 3 + # So we summarize 3 messages, leaving 3 remaining + 1 summary = 4 total + expected_total_messages = 4 + assert len(agent.messages) == expected_total_messages + + # First message should be the summary (assistant message) + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + assert len(summary_message["content"]) > 0 + + # Verify the summary contains actual text content + summary_content = None + for 
content_block in summary_message["content"]: + if "text" in content_block: + summary_content = content_block["text"] + break + + assert summary_content is not None + assert len(summary_content) > 50 # Should be a substantial summary + + # Recent messages should be preserved - verify they're exactly the same + recent_messages = agent.messages[-2:] # Last 2 messages should be preserved + assert len(recent_messages) == 2 + assert recent_messages == messages_before_summary, "The last 2 messages should be preserved exactly as they were" + + # Agent should still be functional after summarization + post_summary_result = agent("That's very interesting, thank you!") + assert post_summary_result.message["role"] == "assistant" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_tool_preservation_during_summarization(model, tools): + """Test that ToolUse/ToolResult pairs are preserved during summarization.""" + agent = Agent( + model=model, + tools=tools, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.6, # Aggressive summarization + preserve_recent_messages=3, + ), + load_tools_from_directory=False, + ) + + # Mock conversation with tool usage to avoid API calls and speed up tests + greeting_text = """ + Hello! I'd be happy to help you with calculations. I have access to tools that can + help with math, time, and weather information. What would you like me to calculate for you? + """.strip() + + weather_response = "The weather in San Francisco is sunny and 72°F. Perfect weather for being outside!" + + tool_conversation_data = [ + # Initial greeting exchange + {"role": "user", "content": [{"text": "Hello, can you help me with some calculations?"}]}, + {"role": "assistant", "content": [{"text": greeting_text}]}, + # Time query with tool use/result pair + {"role": "user", "content": [{"text": "What's the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "time_001", "name": "get_current_time", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "time_001", + "content": [{"text": "2024-01-15 14:30:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 2024-01-15 14:30:00."}]}, + # Math calculation with tool use/result pair + {"role": "user", "content": [{"text": "What's 25 + 37?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "calc_001", "name": "calculate_sum", "input": {"a": 25, "b": 37}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "calc_001", "content": [{"text": "62"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "25 + 37 = 62"}]}, + # Weather query with tool use/result pair + {"role": "user", "content": [{"text": "What's the weather like in San Francisco?"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "weather_001", "name": "get_weather", "input": {"city": "San Francisco"}}} + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "weather_001", + "content": [{"text": "The weather in San Francisco is sunny and 72°F"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": weather_response}]}, + ] + + # Add all the mocked conversation messages to avoid real API calls + agent.messages.extend(tool_conversation_data) + + # Force summarization + agent.conversation_manager.reduce_context(agent) + + # 
Verify tool pairs are still balanced after summarization + post_summary_tool_use_count = 0 + post_summary_tool_result_count = 0 + + for message in agent.messages: + for content in message.get("content", []): + if "toolUse" in content: + post_summary_tool_use_count += 1 + if "toolResult" in content: + post_summary_tool_result_count += 1 + + # Tool uses and results should be balanced (no orphaned tools) + assert post_summary_tool_use_count == post_summary_tool_result_count, ( + "Tool use and tool result counts should be balanced after summarization" + ) + + # Agent should still be able to use tools after summarization + agent("Calculate 15 + 28 for me.") + + # Should have triggered the calculate_sum tool + found_calculation = False + for message in agent.messages[-2:]: # Check recent messages + for content in message.get("content", []): + if "toolResult" in content and "43" in str(content): # 15 + 28 = 43 + found_calculation = True + break + + assert found_calculation, "Tool should still work after summarization" + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_dedicated_summarization_agent(model, summarization_model): + """Test that a dedicated summarization agent works correctly.""" + # Create a dedicated summarization agent + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + load_tools_from_directory=False, + ) + + # Create main agent with dedicated summarization agent + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Mock conversation data for space exploration topic + space_intro_response = """ + Space exploration has been one of humanity's greatest achievements, beginning with early + satellite launches in the 1950s and progressing to human spaceflight, moon landings, and now + commercial space ventures. + """.strip() + + space_milestones_response = """ + Key milestones include Sputnik 1 (1957), Yuri Gagarin's first human spaceflight (1961), + the Apollo 11 moon landing (1969), the Space Shuttle program, and the International Space + Station construction. + """.strip() + + apollo_missions_response = """ + The Apollo program was NASA's lunar exploration program from 1961-1975. Apollo 11 achieved + the first moon landing in 1969 with Neil Armstrong and Buzz Aldrin, followed by five more + successful lunar missions through Apollo 17. + """.strip() + + spacex_response = """ + SpaceX has revolutionized space travel with reusable rockets, reducing launch costs dramatically. + They've achieved crew transportation to the ISS, satellite deployments, and are developing + Starship for Mars missions. 
+ """.strip() + + conversation_pairs = [ + ("I'm interested in learning about space exploration.", space_intro_response), + ("What were the key milestones in space exploration?", space_milestones_response), + ("Tell me about the Apollo missions.", apollo_missions_response), + ("What about modern space exploration with SpaceX?", spacex_response), + ] + + # Manually build the conversation history to avoid real API calls + for user_input, assistant_response in conversation_pairs: + agent.messages.append({"role": "user", "content": [{"text": user_input}]}) + agent.messages.append({"role": "assistant", "content": [{"text": assistant_response}]}) + + # Force summarization + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + # Verify summarization occurred + assert len(agent.messages) < original_length + + # Get the summary message + summary_message = agent.messages[0] + assert summary_message["role"] == "assistant" + + # Extract summary text + summary_text = None + for content in summary_message["content"]: + if "text" in content: + summary_text = content["text"] + break + + assert summary_text diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 7214ac490..7810c9ba0 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Tuple, Type +from collections import deque +from typing import Type from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry @@ -8,12 +9,12 @@ def __init__(self, event_types: list[Type]): self.events_received = [] self.events_types = event_types - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: - return len(self.events_received), iter(self.events_received) + def get_events(self) -> deque[HookEvent]: + return deque(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: for event_type in self.events_types: - registry.add_callback(event_type, self.add_event) + registry.add_callback(event_type, self._add_event) - def add_event(self, event: HookEvent) -> None: + def _add_event(self, event: HookEvent) -> None: self.events_received.append(event) diff --git a/tests/strands/multiagent/__init__.py b/tests/multiagent/__init__.py similarity index 100% rename from tests/strands/multiagent/__init__.py rename to tests/multiagent/__init__.py diff --git a/tests/strands/multiagent/a2a/__init__.py b/tests/multiagent/a2a/__init__.py similarity index 100% rename from tests/strands/multiagent/a2a/__init__.py rename to tests/multiagent/a2a/__init__.py diff --git a/tests/strands/multiagent/a2a/conftest.py b/tests/multiagent/a2a/conftest.py similarity index 90% rename from tests/strands/multiagent/a2a/conftest.py rename to tests/multiagent/a2a/conftest.py index e0061a025..a9730eacb 100644 --- a/tests/strands/multiagent/a2a/conftest.py +++ b/tests/multiagent/a2a/conftest.py @@ -22,10 +22,6 @@ def mock_strands_agent(): mock_result.message = {"content": [{"text": "Test response"}]} agent.return_value = mock_result - # Setup async methods - agent.invoke_async = AsyncMock(return_value=mock_result) - agent.stream_async = AsyncMock(return_value=iter([])) - # Setup mock tool registry mock_tool_registry = MagicMock() mock_tool_registry.get_all_tools_config.return_value = {} diff --git a/tests/multiagent/a2a/test_executor.py b/tests/multiagent/a2a/test_executor.py new file mode 100644 index 000000000..2ac9bed91 --- /dev/null +++ b/tests/multiagent/a2a/test_executor.py @@ -0,0 +1,118 @@ +"""Tests 
for the StrandsA2AExecutor class.""" + +from unittest.mock import MagicMock + +import pytest +from a2a.types import UnsupportedOperationError +from a2a.utils.errors import ServerError + +from strands.agent.agent_result import AgentResult as SAAgentResult +from strands.multiagent.a2a.executor import StrandsA2AExecutor + + +def test_executor_initialization(mock_strands_agent): + """Test that StrandsA2AExecutor initializes correctly.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + assert executor.agent == mock_strands_agent + + +@pytest.mark.asyncio +async def test_execute_with_text_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes text responses correctly.""" + # Setup mock agent response + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "Test response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify event was enqueued + mock_event_queue.enqueue_event.assert_called_once() + args, _ = mock_event_queue.enqueue_event.call_args + event = args[0] + assert event.parts[0].root.text == "Test response" + + +@pytest.mark.asyncio +async def test_execute_with_multiple_text_blocks(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute processes multiple text blocks correctly.""" + # Setup mock agent response with multiple text blocks + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": [{"text": "First response"}, {"text": "Second response"}]} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify events were enqueued + assert mock_event_queue.enqueue_event.call_count == 2 + + # Check first event + args1, _ = mock_event_queue.enqueue_event.call_args_list[0] + event1 = args1[0] + assert event1.parts[0].root.text == "First response" + + # Check second event + args2, _ = mock_event_queue.enqueue_event.call_args_list[1] + event2 = args2[0] + assert event2.parts[0].root.text == "Second response" + + +@pytest.mark.asyncio +async def test_execute_with_empty_response(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles empty responses correctly.""" + # Setup mock agent response with empty content + mock_result = MagicMock(spec=SAAgentResult) + mock_result.message = {"content": []} + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_with_no_message(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that execute handles responses with no message correctly.""" + # Setup mock agent response with no message + mock_result = MagicMock(spec=SAAgentResult) + 
mock_result.message = None + mock_strands_agent.return_value = mock_result + + # Create executor and call execute + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with correct input + mock_strands_agent.assert_called_once_with("Test input") + + # Verify no events were enqueued + mock_event_queue.enqueue_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an UnsupportedOperationError + assert isinstance(excinfo.value.error, UnsupportedOperationError) diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/multiagent/a2a/test_server.py similarity index 96% rename from tests/strands/multiagent/a2a/test_server.py rename to tests/multiagent/a2a/test_server.py index 74f470741..a851c8c7d 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/multiagent/a2a/test_server.py @@ -44,14 +44,6 @@ def test_a2a_agent_initialization_with_custom_values(mock_strands_agent): assert a2a_agent.port == 8080 assert a2a_agent.http_url == "http://127.0.0.1:8080/" assert a2a_agent.version == "1.0.0" - assert a2a_agent.capabilities.streaming is True - - -def test_a2a_agent_initialization_with_streaming_always_enabled(mock_strands_agent): - """Test that A2AAgent always initializes with streaming enabled.""" - a2a_agent = A2AServer(mock_strands_agent) - - assert a2a_agent.capabilities.streaming is True def test_a2a_agent_initialization_with_custom_skills(mock_strands_agent): @@ -479,16 +471,6 @@ def test_serve_with_custom_kwargs(mock_run, mock_strands_agent): assert kwargs["reload"] is True -def test_executor_created_correctly(mock_strands_agent): - """Test that the executor is created correctly.""" - from strands.multiagent.a2a.executor import StrandsA2AExecutor - - a2a_agent = A2AServer(mock_strands_agent) - - assert isinstance(a2a_agent.request_handler.agent_executor, StrandsA2AExecutor) - assert a2a_agent.request_handler.agent_executor.agent == mock_strands_agent - - @patch("uvicorn.run", side_effect=KeyboardInterrupt) def test_serve_handles_keyboard_interrupt(mock_run, mock_strands_agent, caplog): """Test that serve handles KeyboardInterrupt gracefully.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6460878b3..b49e294e2 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -180,7 +180,7 @@ def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_impor agent = Agent(tools=[tool_decorated, tool_module, tool_imported]) - tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs()) + tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"]) exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] assert tru_tool_names == exp_tool_names @@ -191,12 +191,25 @@ def test_agent__init__tool_loader_dict(tool_module, tool_registry): agent = Agent(tools=[{"name": "tool_module", "path": tool_module}]) - tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs()) + tru_tool_names = 
sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"]) exp_tool_names = ["tool_module"] assert tru_tool_names == exp_tool_names +def test_agent__init__invalid_max_parallel_tools(tool_registry): + _ = tool_registry + + with pytest.raises(ValueError): + Agent(max_parallel_tools=0) + + +def test_agent__init__one_max_parallel_tools_succeeds(tool_registry): + _ = tool_registry + + Agent(max_parallel_tools=1) + + def test_agent__init__with_default_model(): agent = Agent() @@ -759,24 +772,6 @@ def test_agent_tool(mock_randint, agent): conversation_manager_spy.apply_management.assert_called_with(agent) -@pytest.mark.asyncio -async def test_agent_tool_in_async_context(mock_randint, agent): - mock_randint.return_value = 123 - - tru_result = agent.tool.tool_decorated(random_string="abcdEfghI123") - exp_result = { - "content": [ - { - "text": "abcdEfghI123", - }, - ], - "status": "success", - "toolUseId": "tooluse_tool_decorated_123", - } - - assert tru_result == exp_result - - def test_agent_tool_user_message_override(agent): agent.tool.tool_decorated(random_string="abcdEfghI123", user_message_override="test override") @@ -843,8 +838,8 @@ def test_agent_init_with_no_model_or_model_id(): assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID -def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) +def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool): + mock_run_tool.return_value = iter([]) @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: @@ -857,18 +852,18 @@ def function(system_prompt: str) -> str: agent.tool.system_prompter(system_prompt="tool prompt") mock_run_tool.assert_called_with( - agent, - { + agent=agent, + tool={ "toolUseId": "tooluse_system_prompter_1", "name": "system_prompter", "input": {"system_prompt": "tool prompt"}, }, - {"system_prompt": "tool prompt"}, + kwargs={"system_prompt": "tool prompt"}, ) -def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool): + mock_run_tool.return_value = iter([]) tool_name = "system-prompter" @@ -884,15 +879,15 @@ def function(system_prompt: str) -> str: # Verify the correct tool was invoked assert mock_run_tool.call_count == 1 - tru_tool_use = mock_run_tool.call_args.args[1] - exp_tool_use = { + tool_call = mock_run_tool.call_args.kwargs.get("tool") + + assert tool_call == { # Note that the tool-use uses the "python safe" name "toolUseId": "tooluse_system_prompter_1", # But the name of the tool is the one in the registry "name": tool_name, "input": {"system_prompt": "tool prompt"}, } - assert tru_tool_use == exp_tool_use def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): @@ -1022,39 +1017,6 @@ async def test_event_loop(*args, **kwargs): mock_callback.assert_has_calls(exp_calls) -@pytest.mark.asyncio -async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, alist): - mock_model.mock_converse.return_value = agenerator( - [ - {"contentBlockDelta": {"delta": {"text": "I see text and an image"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn"}}, - ] - ) - - prompt = [ - {"text": "This is a description of the image:"}, - { - "image": { - "format": "png", - "source": { - 
"bytes": b"\x89PNG\r\n\x1a\n", - }, - } - }, - ] - - stream = agent.stream_async(prompt) - await alist(stream) - - tru_message = agent.messages - exp_message = [ - {"content": prompt, "role": "user"}, - {"content": [{"text": "I see text and an image"}], "role": "assistant"}, - ] - assert tru_message == exp_message - - @pytest.mark.asyncio async def test_stream_async_passes_kwargs(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_converse.side_effect = [ @@ -1188,12 +1150,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", agent_name="Strands Agents", - custom_trace_attributes=agent.trace_attributes, - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the result @@ -1222,12 +1184,12 @@ async def test_event_loop(*args, **kwargs): # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( - custom_trace_attributes=agent.trace_attributes, + prompt="test prompt", agent_name="Strands Agents", - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) expected_response = AgentResult( @@ -1260,12 +1222,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( - custom_trace_attributes=agent.trace_attributes, + prompt="test prompt", agent_name="Strands Agents", - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the exception @@ -1296,12 +1258,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", agent_name="Strands Agents", - custom_trace_attributes=agent.trace_attributes, - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the exception diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 22f261b15..2953d6ab6 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1,17 +1,12 @@ -from unittest.mock import ANY, Mock, call, patch +import unittest.mock +from unittest.mock import call import pytest from pydantic import BaseModel import strands from strands import Agent -from strands.experimental.hooks import ( - AfterToolInvocationEvent, - AgentInitializedEvent, - BeforeToolInvocationEvent, - EndRequestEvent, - StartRequestEvent, -) +from strands.experimental.hooks import AgentInitializedEvent, EndRequestEvent, StartRequestEvent from strands.types.content import Messages from tests.fixtures.mock_hook_provider import MockHookProvider from 
tests.fixtures.mocked_model_provider import MockedModelProvider @@ -19,9 +14,7 @@ @pytest.fixture def hook_provider(): - return MockHookProvider( - [AgentInitializedEvent, StartRequestEvent, EndRequestEvent, AfterToolInvocationEvent, BeforeToolInvocationEvent] - ) + return MockHookProvider([AgentInitializedEvent, StartRequestEvent, EndRequestEvent]) @pytest.fixture @@ -78,7 +71,7 @@ class User(BaseModel): return User(name="Jane Doe", age=30) -@patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks") +@unittest.mock.patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks") def test_agent__init__hooks(mock_invoke_callbacks): """Verify that the AgentInitializedEvent is emitted on Agent construction.""" agent = Agent() @@ -93,21 +86,11 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use): agent("test message") - length, events = hook_provider.get_events() + events = hook_provider.get_events() + assert len(events) == 2 - assert length == 4 - assert next(events) == StartRequestEvent(agent=agent) - assert next(events) == BeforeToolInvocationEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY - ) - assert next(events) == AfterToolInvocationEvent( - agent=agent, - selected_tool=agent_tool, - tool_use=tool_use, - kwargs=ANY, - result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, - ) - assert next(events) == EndRequestEvent(agent=agent) + assert events.popleft() == StartRequestEvent(agent=agent) + assert events.popleft() == EndRequestEvent(agent=agent) @pytest.mark.asyncio @@ -121,28 +104,17 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u async for _ in iterator: pass - length, events = hook_provider.get_events() + events = hook_provider.get_events() + assert len(events) == 2 - assert length == 4 - - assert next(events) == StartRequestEvent(agent=agent) - assert next(events) == BeforeToolInvocationEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY - ) - assert next(events) == AfterToolInvocationEvent( - agent=agent, - selected_tool=agent_tool, - tool_use=tool_use, - kwargs=ANY, - result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, - ) - assert next(events) == EndRequestEvent(agent=agent) + assert events.popleft() == StartRequestEvent(agent=agent) + assert events.popleft() == EndRequestEvent(agent=agent) def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output.""" - agent.model.structured_output = Mock(return_value=agenerator([{"output": user}])) + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) agent.structured_output(type(user), "example prompt") assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] @@ -152,7 +124,7 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output_async.""" - agent.model.structured_output = Mock(return_value=agenerator([{"output": user}])) + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) await agent.structured_output_async(type(user), "example prompt") assert hook_provider.events_received == 
[StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0d35fe28b..1b37fc106 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,17 +1,15 @@ import concurrent import unittest.mock -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import MagicMock, call, patch import pytest import strands import strands.telemetry from strands.event_loop.event_loop import run_tool -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookProvider, HookRegistry from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException -from tests.fixtures.mock_hook_provider import MockHookProvider @pytest.fixture @@ -35,6 +33,11 @@ def messages(): return [{"role": "user", "content": [{"text": "Hello"}]}] +@pytest.fixture +def tool_config(): + return {"tools": [{"toolSpec": {"name": "tool_for_testing"}}], "toolChoice": {"auto": {}}} + + @pytest.fixture def tool_registry(): return ToolRegistry() @@ -47,8 +50,8 @@ def thread_pool(): @pytest.fixture def tool(tool_registry): - @strands.tool - def tool_for_testing(random_string: str): + @strands.tools.tool + def tool_for_testing(random_string: str) -> str: return random_string tool_registry.register_tool(tool_for_testing) @@ -56,28 +59,6 @@ def tool_for_testing(random_string: str): return tool_for_testing -@pytest.fixture -def tool_times_2(tool_registry): - @strands.tools.tool - def multiply_by_2(x: int) -> int: - return x * 2 - - tool_registry.register_tool(multiply_by_2) - - return multiply_by_2 - - -@pytest.fixture -def tool_times_5(tool_registry): - @strands.tools.tool - def multiply_by_5(x: int) -> int: - return x * 5 - - tool_registry.register_tool(multiply_by_5) - - return multiply_by_5 - - @pytest.fixture def tool_stream(tool): return [ @@ -98,28 +79,16 @@ def tool_stream(tool): @pytest.fixture -def hook_registry(): - return HookRegistry() - - -@pytest.fixture -def hook_provider(hook_registry): - provider = MockHookProvider(event_types=[BeforeToolInvocationEvent, AfterToolInvocationEvent]) - hook_registry.add_hook(provider) - return provider - - -@pytest.fixture -def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry): +def agent(model, system_prompt, messages, tool_config, tool_registry, thread_pool): mock = unittest.mock.Mock(name="agent") mock.config.cache_points = [] mock.model = model mock.system_prompt = system_prompt mock.messages = messages + mock.tool_config = tool_config mock.tool_registry = tool_registry mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() - mock._hooks = hook_registry return mock @@ -291,7 +260,6 @@ async def test_event_loop_cycle_tool_result( system_prompt, messages, tool_stream, - tool_registry, agenerator, alist, ): @@ -347,7 +315,7 @@ async def test_event_loop_cycle_tool_result( }, {"role": "assistant", "content": [{"text": "test text"}]}, ], - tool_registry.get_all_tool_specs(), + [{"name": "tool_for_testing"}], "p1", ) @@ -756,255 +724,31 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a assert recursive_args["kwargs"]["event_loop_parent_cycle_id"] == recursive_args["kwargs"]["event_loop_cycle_id"] -@pytest.mark.asyncio -async def test_run_tool(agent, tool, 
alist): +def test_run_tool(agent, tool, generate): process = run_tool( - agent, - tool_use={"toolUseId": "tool_use_id", "name": tool.tool_name, "input": {"random_string": "a_string"}}, + agent=agent, + tool={"toolUseId": "tool_use_id", "name": tool.tool_name, "input": {"random_string": "a_string"}}, kwargs={}, ) - tru_result = (await alist(process))[-1] + _, tru_result = generate(process) exp_result = {"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]} assert tru_result == exp_result -@pytest.mark.asyncio -async def test_run_tool_missing_tool(agent, alist): - process = run_tool( - agent, - tool_use={"toolUseId": "missing", "name": "missing", "input": {}}, - kwargs={}, - ) - - tru_events = await alist(process) - exp_events = [ - { - "toolUseId": "missing", - "status": "error", - "content": [{"text": "Unknown tool: missing"}], - }, - ] - - assert tru_events == exp_events - - -@pytest.mark.asyncio -async def test_run_tool_hooks(agent, hook_provider, tool_times_2, alist): - """Test that the correct hooks are emitted.""" - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - kwargs={}, - ) - await alist(process) - - assert len(hook_provider.events_received) == 2 - - assert hook_provider.events_received[0] == BeforeToolInvocationEvent( - agent=agent, - selected_tool=tool_times_2, - tool_use={"input": {"x": 5}, "name": "multiply_by_2", "toolUseId": "test"}, - kwargs=ANY, - ) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=tool_times_2, - exception=None, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - result={"toolUseId": "test", "status": "success", "content": [{"text": "10"}]}, - kwargs=ANY, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hooks_on_missing_tool(agent, hook_provider, alist): - """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, - kwargs={}, - ) - await alist(process) - - assert len(hook_provider.events_received) == 2 - - assert hook_provider.events_received[0] == BeforeToolInvocationEvent( - agent=agent, - selected_tool=None, - tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, - kwargs=ANY, - ) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=None, - tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, - kwargs=ANY, - result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"}, - exception=None, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, hook_provider, alist): - """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" - error = ValueError("Tool failed") - - failing_tool = MagicMock() - failing_tool.tool_name = "failing_tool" - - failing_tool.stream.side_effect = error - - tool_registry.register_tool(failing_tool) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "failing_tool", "input": {"x": 5}}, - kwargs={}, - ) - await alist(process) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=failing_tool, - tool_use={"input": {"x": 5}, "name": "failing_tool", "toolUseId": "test"}, - kwargs=ANY, - result={"content": 
[{"text": "Error: Tool failed"}], "status": "error", "toolUseId": "test"}, - exception=error, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_before_tool_invocation_updates(agent, tool_times_5, hook_registry, hook_provider, alist): - """Test that modifying properties on BeforeToolInvocation takes effect.""" - - updated_tool_use = {"toolUseId": "modified", "name": "replacement_tool", "input": {"x": 3}} - - def modify_hook(event: BeforeToolInvocationEvent): - # Modify selected_tool to use replacement_tool - event.selected_tool = tool_times_5 - # Modify tool_use to change toolUseId - event.tool_use = updated_tool_use - - hook_registry.add_callback(BeforeToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "original", "name": "original_tool", "input": {"x": 1}}, - kwargs={}, - ) - result = (await alist(process))[-1] - - # Should use replacement_tool (5 * 3 = 15) instead of original_tool (1 * 2 = 2) - assert result == {"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]} - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=tool_times_5, - tool_use=updated_tool_use, - kwargs=ANY, - result={"content": [{"text": "15"}], "status": "success", "toolUseId": "modified"}, - exception=None, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_updates(agent, tool_times_2, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} - - def modify_hook(event: AfterToolInvocationEvent): - # Modify result to change the output - event.result = updated_result - - hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) - +def test_run_tool_missing_tool(agent, generate): process = run_tool( agent=agent, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, + tool={"toolUseId": "missing", "name": "missing", "input": {}}, kwargs={}, ) - result = (await alist(process))[-1] - assert result == updated_result - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} - - def modify_hook(event: AfterToolInvocationEvent): - # Modify result to change the output - event.result = updated_result - - hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, - kwargs={}, - ) - - result = (await alist(process))[-1] - assert result == updated_result - - -@pytest.mark.asyncio -async def test_run_tool_hook_update_result_with_missing_tool(agent, tool_registry, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - @strands.tool - def test_quota(): - return "9" - - tool_registry.register_tool(test_quota) - - class ExampleProvider(HookProvider): - def register_hooks(self, registry: "HookRegistry") -> None: - registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call) - registry.add_callback(AfterToolInvocationEvent, self.after_tool_call) - - def before_tool_call(self, event: BeforeToolInvocationEvent): - if event.tool_use.get("name") == "test_quota": - 
event.selected_tool = None - - def after_tool_call(self, event: AfterToolInvocationEvent): - if event.tool_use.get("name") == "test_quota": - event.result = { - "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has been used too many times!"}], - } - - hook_registry.add_hook(ExampleProvider()) - - with patch.object(strands.event_loop.event_loop, "logger") as mock_logger: - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, - kwargs={}, - ) - - result = (await alist(process))[-1] - - assert result == { + _, tru_result = generate(process) + exp_result = { + "toolUseId": "missing", "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has been used too many times!"}], + "content": [{"text": "Unknown tool: missing"}], } - assert mock_logger.debug.call_args_list == [ - call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), - call( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - "test_quota", - "test", - ), - ] + assert tru_result == exp_result diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 44c5b5a8e..7b64264e3 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -549,7 +549,7 @@ async def test_stream_messages(agenerator, alist): mock_model, system_prompt="test prompt", messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_specs=None, + tool_config=None, ) tru_events = await alist(stream) diff --git a/tests/strands/experimental/hooks/test_events.py b/tests/strands/experimental/hooks/test_events.py deleted file mode 100644 index c9c5ecdd7..000000000 --- a/tests/strands/experimental/hooks/test_events.py +++ /dev/null @@ -1,124 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from strands.experimental.hooks.events import ( - AfterToolInvocationEvent, - AgentInitializedEvent, - BeforeToolInvocationEvent, - EndRequestEvent, - StartRequestEvent, -) -from strands.types.tools import ToolResult, ToolUse - - -@pytest.fixture -def agent(): - return Mock() - - -@pytest.fixture -def tool(): - tool = Mock() - tool.tool_name = "test_tool" - return tool - - -@pytest.fixture -def tool_use(): - return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"}) - - -@pytest.fixture -def tool_kwargs(): - return {"param": "value"} - - -@pytest.fixture -def tool_result(): - return ToolResult(content=[{"text": "result"}], status="success", toolUseId="123") - - -@pytest.fixture -def initialized_event(agent): - return AgentInitializedEvent(agent=agent) - - -@pytest.fixture -def start_request_event(agent): - return StartRequestEvent(agent=agent) - - -@pytest.fixture -def end_request_event(agent): - return EndRequestEvent(agent=agent) - - -@pytest.fixture -def before_tool_event(agent, tool, tool_use, tool_kwargs): - return BeforeToolInvocationEvent( - agent=agent, - selected_tool=tool, - tool_use=tool_use, - kwargs=tool_kwargs, - ) - - -@pytest.fixture -def after_tool_event(agent, tool, tool_use, tool_kwargs, tool_result): - return AfterToolInvocationEvent( - agent=agent, - selected_tool=tool, - tool_use=tool_use, - kwargs=tool_kwargs, - result=tool_result, - ) - - -def test_event_should_reverse_callbacks( - initialized_event, - start_request_event, - end_request_event, - before_tool_event, - after_tool_event, -): - # note that we ignore E712 (explicit booleans) for 
consistency/readability purposes - - assert initialized_event.should_reverse_callbacks == False # noqa: E712 - - assert start_request_event.should_reverse_callbacks == False # noqa: E712 - assert end_request_event.should_reverse_callbacks == True # noqa: E712 - - assert before_tool_event.should_reverse_callbacks == False # noqa: E712 - assert after_tool_event.should_reverse_callbacks == True # noqa: E712 - - -def test_before_tool_invocation_event_can_write_properties(before_tool_event): - new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) - before_tool_event.selected_tool = None # Should not raise - before_tool_event.tool_use = new_tool_use # Should not raise - - -def test_before_tool_invocation_event_cannot_write_properties(before_tool_event): - with pytest.raises(AttributeError, match="Property agent is not writable"): - before_tool_event.agent = Mock() - with pytest.raises(AttributeError, match="Property kwargs is not writable"): - before_tool_event.kwargs = {} - - -def test_after_tool_invocation_event_can_write_properties(after_tool_event): - new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456") - after_tool_event.result = new_result # Should not raise - - -def test_after_tool_invocation_event_cannot_write_properties(after_tool_event): - with pytest.raises(AttributeError, match="Property agent is not writable"): - after_tool_event.agent = Mock() - with pytest.raises(AttributeError, match="Property selected_tool is not writable"): - after_tool_event.selected_tool = None - with pytest.raises(AttributeError, match="Property tool_use is not writable"): - after_tool_event.tool_use = ToolUse(name="new", toolUseId="456", input={}) - with pytest.raises(AttributeError, match="Property kwargs is not writable"): - after_tool_event.kwargs = {} - with pytest.raises(AttributeError, match="Property exception is not writable"): - after_tool_event.exception = Exception("test") diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index fa1eb8616..66046b7a8 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -11,7 +11,7 @@ @pytest.fixture def anthropic_client(): - with unittest.mock.patch.object(strands.models.anthropic.anthropic, "AsyncAnthropic") as mock_client_cls: + with unittest.mock.patch.object(strands.models.anthropic.anthropic, "Anthropic") as mock_client_cls: yield mock_client_cls.return_value @@ -625,7 +625,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(anthropic_client, model, agenerator, alist): +async def test_stream(anthropic_client, model, alist): mock_event_1 = unittest.mock.Mock( type="message_start", dict=lambda: {"type": "message_start"}, @@ -646,9 +646,9 @@ async def test_stream(anthropic_client, model, agenerator, alist): ), ) - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3]) - anthropic_client.messages.stream.return_value = mock_context + mock_stream = unittest.mock.MagicMock() + mock_stream.__iter__.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) + anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream request = {"model": "m1"} response = model.stream(request) @@ -705,7 +705,7 @@ async def test_stream_bad_request_error(anthropic_client, model): @pytest.mark.asyncio -async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist): 
+async def test_structured_output(anthropic_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] events = [ @@ -749,9 +749,9 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, ), ] - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator(events) - anthropic_client.messages.stream.return_value = mock_context + mock_stream = unittest.mock.MagicMock() + mock_stream.__iter__.return_value = iter(events) + anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream stream = model.structured_output(test_output_model_cls, messages) events = await alist(stream) diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index a93e77593..786ba25b3 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -10,7 +10,7 @@ @pytest.fixture def mistral_client(): - with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls: + with unittest.mock.patch.object(strands.models.mistral, "Mistral") as mock_client_cls: yield mock_client_cls.return_value @@ -436,42 +436,9 @@ def test_format_chunk_unknown(model): model.format_chunk(event) -@pytest.mark.asyncio -async def test_stream(mistral_client, model, agenerator, alist): - mock_event = unittest.mock.Mock( - data=unittest.mock.Mock( - choices=[ - unittest.mock.Mock( - delta=unittest.mock.Mock(content="test stream", tool_calls=None), - finish_reason="end_turn", - ) - ] - ), - usage="usage", - ) - - mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) - - request = {"model": "m1"} - response = model.stream(request) - - tru_events = await alist(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "test stream"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "end_turn"}, - {"chunk_type": "metadata", "data": "usage"}, - ] - assert tru_events == exp_events - - mistral_client.chat.stream_async.assert_called_once_with(**request) - - @pytest.mark.asyncio async def test_stream_rate_limit_error(mistral_client, model, alist): - mistral_client.chat.stream_async.side_effect = Exception("rate limit exceeded (429)") + mistral_client.chat.stream.side_effect = Exception("rate limit exceeded (429)") with pytest.raises(ModelThrottledException, match="rate limit exceeded"): await alist(model.stream({})) @@ -479,7 +446,7 @@ async def test_stream_rate_limit_error(mistral_client, model, alist): @pytest.mark.asyncio async def test_stream_other_error(mistral_client, model, alist): - mistral_client.chat.stream_async.side_effect = Exception("some other error") + mistral_client.chat.stream.side_effect = Exception("some other error") with pytest.raises(Exception, match="some other error"): await alist(model.stream({})) @@ -494,7 +461,7 @@ async def test_structured_output_success(mistral_client, model, test_output_mode mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()] mock_response.choices[0].message.tool_calls[0].function.arguments = '{"name": "John", "age": 30}' - mistral_client.chat.complete_async = unittest.mock.AsyncMock(return_value=mock_response) + mistral_client.chat.complete.return_value = mock_response stream = model.structured_output(test_output_model_cls, messages) events = await 
alist(stream) @@ -510,7 +477,7 @@ async def test_structured_output_no_tool_calls(mistral_client, model, test_outpu mock_response.choices = [unittest.mock.Mock()] mock_response.choices[0].message.tool_calls = None - mistral_client.chat.complete_async = unittest.mock.AsyncMock(return_value=mock_response) + mistral_client.chat.complete.return_value = mock_response prompt = [{"role": "user", "content": [{"text": "Extract data"}]}] @@ -526,7 +493,7 @@ async def test_structured_output_invalid_json(mistral_client, model, test_output mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()] mock_response.choices[0].message.tool_calls[0].function.arguments = "invalid json" - mistral_client.chat.complete_async = unittest.mock.AsyncMock(return_value=mock_response) + mistral_client.chat.complete.return_value = mock_response prompt = [{"role": "user", "content": [{"text": "Extract data"}]}] diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index aeba644a6..c718a602c 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -11,7 +11,7 @@ @pytest.fixture def ollama_client(): - with unittest.mock.patch.object(strands.models.ollama.ollama, "AsyncClient") as mock_client_cls: + with unittest.mock.patch.object(strands.models.ollama, "OllamaClient") as mock_client_cls: yield mock_client_cls.return_value @@ -416,13 +416,13 @@ def test_format_chunk_other(model): @pytest.mark.asyncio -async def test_stream(ollama_client, model, agenerator, alist): +async def test_stream(ollama_client, model, alist): mock_event = unittest.mock.Mock() mock_event.message.tool_calls = None mock_event.message.content = "Hello" mock_event.done_reason = "stop" - ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + ollama_client.chat.return_value = [mock_event] request = {"model": "m1", "messages": [{"role": "user", "content": "Hello"}]} response = model.stream(request) @@ -442,14 +442,14 @@ async def test_stream(ollama_client, model, agenerator, alist): @pytest.mark.asyncio -async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): +async def test_stream_with_tool_calls(ollama_client, model, alist): mock_event = unittest.mock.Mock() mock_tool_call = unittest.mock.Mock() mock_event.message.tool_calls = [mock_tool_call] mock_event.message.content = "I'll calculate that for you" mock_event.done_reason = "stop" - ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + ollama_client.chat.return_value = [mock_event] request = {"model": "m1", "messages": [{"role": "user", "content": "Calculate 2+2"}]} response = model.stream(request) @@ -478,7 +478,7 @@ async def test_structured_output(ollama_client, model, test_output_model_cls, al mock_response = unittest.mock.Mock() mock_response.message.content = '{"name": "John", "age": 30}' - ollama_client.chat = unittest.mock.AsyncMock(return_value=mock_response) + ollama_client.chat.return_value = mock_response stream = model.structured_output(test_output_model_cls, messages) events = await alist(stream) diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py deleted file mode 100644 index 09aa033c5..000000000 --- a/tests/strands/models/test_writer.py +++ /dev/null @@ -1,396 +0,0 @@ -import unittest.mock -from typing import Any, List - -import pytest - -import strands -from strands.models.writer import WriterModel - - -@pytest.fixture -def writer_client_cls(): - with 
unittest.mock.patch.object(strands.models.writer.writerai, "AsyncClient") as mock_client_cls: - yield mock_client_cls - - -@pytest.fixture -def writer_client(writer_client_cls): - return writer_client_cls.return_value - - -@pytest.fixture -def client_args(): - return {"api_key": "writer_api_key"} - - -@pytest.fixture -def model_id(): - return "palmyra-x5" - - -@pytest.fixture -def stream_options(): - return {"include_usage": True} - - -@pytest.fixture -def model(writer_client, model_id, stream_options, client_args): - _ = writer_client - - return WriterModel(client_args, model_id=model_id, stream_options=stream_options) - - -@pytest.fixture -def messages(): - return [{"role": "user", "content": [{"text": "test"}]}] - - -@pytest.fixture -def system_prompt(): - return "System prompt" - - -def test__init__(writer_client_cls, model_id, stream_options, client_args): - model = WriterModel(client_args=client_args, model_id=model_id, stream_options=stream_options) - - config = model.get_config() - exp_config = {"stream_options": stream_options, "model_id": model_id} - - assert config == exp_config - - writer_client_cls.assert_called_once_with(api_key=client_args.get("api_key", "")) - - -def test_update_config(model): - model.update_config(model_id="palmyra-x4") - - model_id = model.get_config().get("model_id") - - assert model_id == "palmyra-x4" - - -def test_format_request_basic(model, messages, model_id, stream_options): - request = model.format_request(messages) - - exp_request = { - "stream": True, - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream_options": stream_options, - } - - assert request == exp_request - - -def test_format_request_with_params(model, messages, model_id, stream_options): - model.update_config(temperature=0.19) - - request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream_options": stream_options, - "temperature": 0.19, - "stream": True, - } - - assert request == exp_request - - -def test_format_request_with_system_prompt(model, messages, model_id, stream_options, system_prompt): - request = model.format_request(messages, system_prompt=system_prompt) - - exp_request = { - "messages": [ - {"content": "System prompt", "role": "system"}, - {"content": [{"text": "test", "type": "text"}], "role": "user"}, - ], - "model": model_id, - "stream_options": stream_options, - "stream": True, - } - - assert request == exp_request - - -def test_format_request_with_tool_use(model, model_id, stream_options): - messages = [ - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "c1", - "name": "calculator", - "input": {"expression": "2+2"}, - }, - }, - ], - }, - ] - - request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, - "id": "c1", - "type": "function", - } - ], - }, - ], - "model": model_id, - "stream_options": stream_options, - "stream": True, - } - - assert request == exp_request - - -def test_format_request_with_tool_results(model, model_id, stream_options): - messages = [ - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "c1", - "status": "success", - "content": [ - {"text": "answer is 4"}, - ], - } - } - ], - } - ] - - request = model.format_request(messages) - exp_request = { - "messages": [ - { - 
"role": "tool", - "content": [{"text": "answer is 4", "type": "text"}], - "tool_call_id": "c1", - }, - ], - "model": model_id, - "stream_options": stream_options, - "stream": True, - } - - assert request == exp_request - - -def test_format_request_with_image(model, model_id, stream_options): - messages = [ - { - "role": "user", - "content": [ - { - "image": { - "format": "png", - "source": {"bytes": b"lovely sunny day"}, - }, - }, - ], - }, - ] - - request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "image_url": { - "url": "data:image/png;base64,bG92ZWx5IHN1bm55IGRheQ==", - }, - "type": "image_url", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": stream_options, - } - - assert request == exp_request - - -def test_format_request_with_empty_content(model, model_id, stream_options): - messages = [ - { - "role": "user", - "content": [], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [], - "model": model_id, - "stream_options": stream_options, - "stream": True, - } - - assert tru_request == exp_request - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - ({"video": {}}, "video"), - ({"document": {}}, "document"), - ({"reasoningContent": {}}, "reasoningContent"), - ({"other": {}}, "other"), - ], -) -def test_format_request_with_unsupported_type(model, content, content_type): - messages = [ - { - "role": "user", - "content": [content], - }, - ] - - with pytest.raises(TypeError, match=f"content_type=<{content_type}> | unsupported type"): - model.format_request(messages) - - -class AsyncStreamWrapper: - def __init__(self, items: List[Any]): - self.items = items - - def __aiter__(self): - return self._generator() - - async def _generator(self): - for item in self.items: - yield item - - -async def mock_streaming_response(items: List[Any]): - return AsyncStreamWrapper(items) - - -@pytest.mark.asyncio -async def test_stream(writer_client, model, model_id): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) - mock_delta_1 = unittest.mock.Mock( - content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] - ) - - mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_2 = unittest.mock.Mock( - content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] - ) - - mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) - mock_event_4 = unittest.mock.Mock() - - writer_client.chat.chat.return_value = mock_streaming_response( - [mock_event_1, mock_event_2, mock_event_3, mock_event_4] - ) - - request = { - "model": model_id, - "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}], - } - response = model.stream(request) - - events = [event async for event in response] - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_block_start", "data_type": "text"}, - {"chunk_type": "content_block_delta", "data_type": "text", "data": "I'll calculate"}, - {"chunk_type": "content_block_delta", "data_type": 
"text", "data": "that for you"}, - {"chunk_type": "content_block_stop", "data_type": "text"}, - {"chunk_type": "content_block_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_block_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, - {"chunk_type": "content_block_stop", "data_type": "tool"}, - {"chunk_type": "content_block_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_block_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, - {"chunk_type": "content_block_stop", "data_type": "tool"}, - {"chunk_type": "message_stop", "data": "tool_calls"}, - {"chunk_type": "metadata", "data": mock_event_4.usage}, - ] - - assert events == exp_events - writer_client.chat.chat(**request) - - -@pytest.mark.asyncio -async def test_stream_empty(writer_client, model, model_id): - mock_delta = unittest.mock.Mock(content=None, tool_calls=None) - mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) - mock_event_3 = unittest.mock.Mock() - mock_event_4 = unittest.mock.Mock(usage=mock_usage) - - writer_client.chat.chat.return_value = mock_streaming_response( - [mock_event_1, mock_event_2, mock_event_3, mock_event_4] - ) - - request = {"model": model_id, "messages": [{"role": "user", "content": []}]} - response = model.stream(request) - - events = [event async for event in response] - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_block_start", "data_type": "text"}, - {"chunk_type": "content_block_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - {"chunk_type": "metadata", "data": mock_usage}, - ] - - assert events == exp_events - writer_client.chat.chat.assert_called_once_with(**request) - - -@pytest.mark.asyncio -async def test_stream_with_empty_choices(writer_client, model, model_id): - mock_delta = unittest.mock.Mock(content="content", tool_calls=None) - mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) - - mock_event_1 = unittest.mock.Mock(spec=[]) - mock_event_2 = unittest.mock.Mock(choices=[]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) - mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) - mock_event_5 = unittest.mock.Mock(usage=mock_usage) - - writer_client.chat.chat.return_value = mock_streaming_response( - [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] - ) - - request = {"model": model_id, "messages": [{"role": "user", "content": ["test"]}]} - response = model.stream(request) - - events = [event async for event in response] - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_block_start", "data_type": "text"}, - {"chunk_type": "content_block_delta", "data_type": "text", "data": "content"}, - {"chunk_type": "content_block_delta", "data_type": "text", "data": "content"}, - {"chunk_type": "content_block_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - {"chunk_type": "metadata", "data": mock_usage}, - ] - - assert events == exp_events - writer_client.chat.chat.assert_called_once_with(**request) diff --git a/tests/strands/multiagent/a2a/test_executor.py 
b/tests/strands/multiagent/a2a/test_executor.py deleted file mode 100644 index a956cb769..000000000 --- a/tests/strands/multiagent/a2a/test_executor.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Tests for the StrandsA2AExecutor class.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from a2a.types import UnsupportedOperationError -from a2a.utils.errors import ServerError - -from strands.agent.agent_result import AgentResult as SAAgentResult -from strands.multiagent.a2a.executor import StrandsA2AExecutor - - -def test_executor_initialization(mock_strands_agent): - """Test that StrandsA2AExecutor initializes correctly.""" - executor = StrandsA2AExecutor(mock_strands_agent) - - assert executor.agent == mock_strands_agent - - -@pytest.mark.asyncio -async def test_execute_streaming_mode_with_data_events(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that execute processes data events correctly in streaming mode.""" - - async def mock_stream(user_input): - """Mock streaming function that yields data events.""" - yield {"data": "First chunk"} - yield {"data": "Second chunk"} - yield {"result": MagicMock(spec=SAAgentResult)} - - # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) - - # Create executor - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - mock_request_context.current_task = mock_task - - await executor.execute(mock_request_context, mock_event_queue) - - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") - - # Verify events were enqueued - mock_event_queue.enqueue_event.assert_called() - - -@pytest.mark.asyncio -async def test_execute_streaming_mode_with_result_event(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that execute processes result events correctly in streaming mode.""" - - async def mock_stream(user_input): - """Mock streaming function that yields only result event.""" - yield {"result": MagicMock(spec=SAAgentResult)} - - # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) - - # Create executor - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - mock_request_context.current_task = mock_task - - await executor.execute(mock_request_context, mock_event_queue) - - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") - - # Verify events were enqueued - mock_event_queue.enqueue_event.assert_called() - - -@pytest.mark.asyncio -async def test_execute_streaming_mode_with_empty_data(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that execute handles empty data events correctly in streaming mode.""" - - async def mock_stream(user_input): - """Mock streaming function that yields empty data.""" - yield {"data": ""} - yield {"result": MagicMock(spec=SAAgentResult)} - - # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) - - # Create executor - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - 
mock_request_context.current_task = mock_task - - await executor.execute(mock_request_context, mock_event_queue) - - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") - - # Verify events were enqueued - mock_event_queue.enqueue_event.assert_called() - - -@pytest.mark.asyncio -async def test_execute_streaming_mode_with_unexpected_event(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that execute handles unexpected events correctly in streaming mode.""" - - async def mock_stream(user_input): - """Mock streaming function that yields unexpected event.""" - yield {"unexpected": "event"} - yield {"result": MagicMock(spec=SAAgentResult)} - - # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) - - # Create executor - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - mock_request_context.current_task = mock_task - - await executor.execute(mock_request_context, mock_event_queue) - - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") - - # Verify events were enqueued - mock_event_queue.enqueue_event.assert_called() - - -@pytest.mark.asyncio -async def test_execute_creates_task_when_none_exists(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that execute creates a new task when none exists.""" - - async def mock_stream(user_input): - """Mock streaming function that yields data events.""" - yield {"data": "Test chunk"} - yield {"result": MagicMock(spec=SAAgentResult)} - - # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) - - # Create executor - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock no existing task - mock_request_context.current_task = None - - with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: - mock_new_task.return_value = MagicMock(id="new-task-id", contextId="new-context-id") - - await executor.execute(mock_request_context, mock_event_queue) - - # Verify task creation and completion events were enqueued - assert mock_event_queue.enqueue_event.call_count >= 1 - mock_new_task.assert_called_once() - - -@pytest.mark.asyncio -async def test_execute_streaming_mode_handles_agent_exception( - mock_strands_agent, mock_request_context, mock_event_queue -): - """Test that execute handles agent exceptions correctly in streaming mode.""" - - # Setup mock agent to raise exception when stream_async is called - mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) - - # Create executor - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - mock_request_context.current_task = mock_task - - with pytest.raises(ServerError): - await executor.execute(mock_request_context, mock_event_queue) - - # Verify agent was called - mock_strands_agent.stream_async.assert_called_once_with("Test input") - - -@pytest.mark.asyncio -async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that cancel raises UnsupportedOperationError.""" - executor = StrandsA2AExecutor(mock_strands_agent) - - with pytest.raises(ServerError) as excinfo: - 
await executor.cancel(mock_request_context, mock_event_queue) - - # Verify the error is a ServerError containing an UnsupportedOperationError - assert isinstance(excinfo.value.error, UnsupportedOperationError) - - -@pytest.mark.asyncio -async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that _handle_agent_result handles None result correctly.""" - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - mock_request_context.current_task = mock_task - - # Mock TaskUpdater - mock_updater = MagicMock() - mock_updater.complete = AsyncMock() - mock_updater.add_artifact = AsyncMock() - - # Call _handle_agent_result with None - await executor._handle_agent_result(None, mock_updater) - - # Verify completion was called - mock_updater.complete.assert_called_once() - - -@pytest.mark.asyncio -async def test_handle_agent_result_with_result_but_no_message( - mock_strands_agent, mock_request_context, mock_event_queue -): - """Test that _handle_agent_result handles result with no message correctly.""" - executor = StrandsA2AExecutor(mock_strands_agent) - - # Mock the task creation - mock_task = MagicMock() - mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" - mock_request_context.current_task = mock_task - - # Mock TaskUpdater - mock_updater = MagicMock() - mock_updater.complete = AsyncMock() - mock_updater.add_artifact = AsyncMock() - - # Create result with no message - mock_result = MagicMock(spec=SAAgentResult) - mock_result.message = None - - # Call _handle_agent_result - await executor._handle_agent_result(mock_result, mock_updater) - - # Verify completion was called - mock_updater.complete.assert_called_once() diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 7623085f2..2fcd98c39 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -276,17 +276,17 @@ def test_start_agent_span(mock_tracer): mock_span = mock.MagicMock() mock_tracer.start_span.return_value = mock_span - content = [{"text": "test prompt"}] + prompt = "What's the weather today?" 
model_id = "test-model" tools = [{"name": "weather_tool"}] custom_attrs = {"custom_attr": "value"} span = tracer.start_agent_span( - custom_trace_attributes=custom_attrs, + prompt=prompt, agent_name="WeatherAgent", - message={"content": content, "role": "user"}, model_id=model_id, tools=tools, + custom_trace_attributes=custom_attrs, ) mock_tracer.start_span.assert_called_once() @@ -295,7 +295,7 @@ def test_start_agent_span(mock_tracer): mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) mock_span.set_attribute.assert_any_call("custom_attr", "value") - mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) + mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": prompt}) assert span is not None diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index b00bf4cc9..eba4ad6c2 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -57,14 +57,12 @@ def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client): assert tool_spec["description"] == "Tool which performs test_tool" -@pytest.mark.asyncio -async def test_stream(mcp_agent_tool, mock_mcp_client, alist): +def test_invoke(mcp_agent_tool, mock_mcp_client): tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} - tru_events = await alist(mcp_agent_tool.stream(tool_use, {})) - exp_events = [mock_mcp_client.call_tool_sync.return_value] + result = mcp_agent_tool.invoke(tool_use) - assert tru_events == exp_events mock_mcp_client.call_tool_sync.assert_called_once_with( tool_use_id="test-123", name="test_tool", arguments={"param": "value"} ) + assert result == mock_mcp_client.call_tool_sync.return_value diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 52a9282e0..50333474c 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -5,139 +5,14 @@ from typing import Any, Dict, Optional, Union from unittest.mock import MagicMock -import pytest - -import strands +from strands.tools.decorator import tool from strands.types.tools import ToolUse -@pytest.fixture(scope="module") -def identity_invoke(): - @strands.tool - def identity(a: int): - return a - - return identity - - -@pytest.fixture(scope="module") -def identity_invoke_async(): - @strands.tool - async def identity(a: int): - return a - - return identity - - -@pytest.fixture -def identity_tool(request): - return request.getfixturevalue(request.param) - - -def test__init__invalid_name(): - with pytest.raises(ValueError, match="Tool name must be a string"): - - @strands.tool(name=0) - def identity(a): - return a - - -def test_tool_func_not_decorated(): - def identity(a: int): - return a - - tool = strands.tool(func=identity, name="identity") - - tru_name = tool._tool_func.__name__ - exp_name = "identity" - - assert tru_name == exp_name - - -@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) -def test_tool_name(identity_tool): - tru_name = identity_tool.tool_name - exp_name = "identity" - - assert tru_name == exp_name - - -@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) -def test_tool_spec(identity_tool): - tru_spec = identity_tool.tool_spec - exp_spec = { - "name": "identity", - 
"description": "identity", - "inputSchema": { - "json": { - "type": "object", - "properties": { - "a": { - "description": "Parameter a", - "type": "integer", - }, - }, - "required": ["a"], - } - }, - } - assert tru_spec == exp_spec - - -@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) -def test_tool_type(identity_tool): - tru_type = identity_tool.tool_type - exp_type = "function" - - assert tru_type == exp_type - - -@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) -def test_supports_hot_reload(identity_tool): - assert identity_tool.supports_hot_reload - - -@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) -def test_get_display_properties(identity_tool): - tru_properties = identity_tool.get_display_properties() - exp_properties = { - "Function": "identity", - "Name": "identity", - "Type": "function", - } - - assert tru_properties == exp_properties - - -@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) -@pytest.mark.asyncio -async def test_stream(identity_tool, alist): - stream = identity_tool.stream({"toolUseId": "t1", "input": {"a": 2}}, {}) - - tru_events = await alist(stream) - exp_events = [{"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]}] - - assert tru_events == exp_events - - -@pytest.mark.asyncio -async def test_stream_with_agent(alist): - @strands.tool - def identity(a: int, agent: dict = None): - return a, agent - - stream = identity.stream({"input": {"a": 2}}, {"agent": {"state": 1}}) - - tru_events = await alist(stream) - exp_events = [{"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}] - assert tru_events == exp_events - - -@pytest.mark.asyncio -async def test_basic_tool_creation(alist): +def test_basic_tool_creation(): """Test basic tool decorator functionality.""" - @strands.tool + @tool def test_tool(param1: str, param2: int) -> str: """Test tool function. @@ -175,21 +50,20 @@ def test_tool(param1: str, param2: int) -> str: # Test actual usage tool_use = {"toolUseId": "test-id", "input": {"param1": "hello", "param2": 42}} - stream = test_tool.stream(tool_use, {}) - - tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] - assert tru_events == exp_events + result = test_tool.invoke(tool_use) + assert result["toolUseId"] == "test-id" + assert result["status"] == "success" + assert result["content"][0]["text"] == "Result: hello 42" # Make sure these are set properly assert test_tool.__wrapped__ is not None - assert test_tool.__doc__ == test_tool._tool_func.__doc__ + assert test_tool.__doc__ == test_tool.original_function.__doc__ def test_tool_with_custom_name_description(): """Test tool decorator with custom name and description.""" - @strands.tool(name="custom_name", description="Custom description") + @tool(name="custom_name", description="Custom description") def test_tool(param: str) -> str: return f"Result: {param}" @@ -199,11 +73,10 @@ def test_tool(param: str) -> str: assert spec["description"] == "Custom description" -@pytest.mark.asyncio -async def test_tool_with_optional_params(alist): +def test_tool_with_optional_params(): """Test tool decorator with optional parameters.""" - @strands.tool + @tool def test_tool(required: str, optional: Optional[int] = None) -> str: """Test with optional param. 
@@ -224,25 +97,23 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: # Test with only required param tool_use = {"toolUseId": "test-id", "input": {"required": "hello"}} - stream = test_tool.stream(tool_use, {}) - tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}] - assert tru_events == exp_events + result = test_tool.invoke(tool_use) + assert result["status"] == "success" + assert result["content"][0]["text"] == "Result: hello" # Test with both params tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "optional": 42}} - stream = test_tool.stream(tool_use, {}) - tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + result = test_tool.invoke(tool_use) + assert result["status"] == "success" + assert result["content"][0]["text"] == "Result: hello 42" -@pytest.mark.asyncio -async def test_tool_error_handling(alist): +def test_tool_error_handling(): """Test error handling in tool decorator.""" - @strands.tool + @tool def test_tool(required: str) -> str: """Test tool function.""" if required == "error": @@ -251,9 +122,8 @@ def test_tool(required: str) -> str: # Test with missing required param tool_use = {"toolUseId": "test-id", "input": {}} - stream = test_tool.stream(tool_use, {}) - result = (await alist(stream))[-1] + result = test_tool.invoke(tool_use) assert result["status"] == "error" assert "validation error for test_tooltool\nrequired\n" in result["content"][0]["text"].lower(), ( "Validation error should indicate which argument is missing" @@ -261,9 +131,8 @@ def test_tool(required: str) -> str: # Test with exception in tool function tool_use = {"toolUseId": "test-id", "input": {"required": "error"}} - stream = test_tool.stream(tool_use, {}) - result = (await alist(stream))[-1] + result = test_tool.invoke(tool_use) assert result["status"] == "error" assert "test error" in result["content"][0]["text"].lower(), ( "Runtime error should contain the original error message" @@ -273,7 +142,7 @@ def test_tool(required: str) -> str: def test_type_handling(): """Test handling of basic parameter types.""" - @strands.tool + @tool def test_tool( str_param: str, int_param: int, @@ -293,12 +162,11 @@ def test_tool( assert props["bool_param"]["type"] == "boolean" -@pytest.mark.asyncio -async def test_agent_parameter_passing(alist): +def test_agent_parameter_passing(): """Test passing agent parameter to tool function.""" mock_agent = MagicMock() - @strands.tool + @tool def test_tool(param: str, agent=None) -> str: """Test tool with agent parameter.""" if agent: @@ -308,74 +176,85 @@ def test_tool(param: str, agent=None) -> str: tool_use = {"toolUseId": "test-id", "input": {"param": "test"}} # Test without agent - stream = test_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = test_tool.invoke(tool_use) assert result["content"][0]["text"] == "Param: test" # Test with agent - stream = test_tool.stream(tool_use, {"agent": mock_agent}) + result = test_tool.invoke(tool_use, agent=mock_agent) + assert "Agent:" in result["content"][0]["text"] + assert "test" in result["content"][0]["text"] + + +def test_agent_backwards_compatability_parameter_passing(): + """Test passing agent parameter to tool function.""" + mock_agent = MagicMock() + + @tool + def test_tool(param: str, agent=None) -> str: + """Test tool with agent parameter.""" + if agent: + return f"Agent: {agent}, Param: 
{param}" + return f"Param: {param}" - result = (await alist(stream))[-1] + tool_use = {"toolUseId": "test-id", "input": {"param": "test"}} + + # Test without agent + result = test_tool(tool_use) + assert result["content"][0]["text"] == "Param: test" + + # Test with agent + result = test_tool(tool_use, agent=mock_agent) assert "Agent:" in result["content"][0]["text"] assert "test" in result["content"][0]["text"] -@pytest.mark.asyncio -async def test_tool_decorator_with_different_return_values(alist): +def test_tool_decorator_with_different_return_values(): """Test tool decorator with different return value types.""" # Test with dict return that follows ToolResult format - @strands.tool + @tool def dict_return_tool(param: str) -> dict: """Test tool that returns a dict in ToolResult format.""" return {"status": "success", "content": [{"text": f"Result: {param}"}]} # Test with non-dict return - @strands.tool + @tool def string_return_tool(param: str) -> str: """Test tool that returns a string.""" return f"Result: {param}" # Test with None return - @strands.tool + @tool def none_return_tool(param: str) -> None: """Test tool that returns None.""" pass # Test the dict return - should preserve dict format but add toolUseId tool_use: ToolUse = {"toolUseId": "test-id", "input": {"param": "test"}} - stream = dict_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = dict_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "Result: test" assert result["toolUseId"] == "test-id" # Test the string return - should wrap in standard format - stream = string_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = string_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "Result: test" # Test None return - should still create valid ToolResult with "None" text - stream = none_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = none_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "None" -@pytest.mark.asyncio -async def test_class_method_handling(alist): +def test_class_method_handling(): """Test handling of class methods with tool decorator.""" class TestClass: def __init__(self, prefix): self.prefix = prefix - @strands.tool + @tool def test_method(self, param: str) -> str: """Test method. @@ -398,15 +277,12 @@ def test_method(self, param: str) -> str: # Test tool-style call tool_use = {"toolUseId": "test-id", "input": {"param": "tool-value"}} - stream = instance.test_method.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = instance.test_method.invoke(tool_use) assert "Test: tool-value" in result["content"][0]["text"] -@pytest.mark.asyncio -async def test_tool_as_adhoc_field(alist): - @strands.tool +def test_tool_as_adhoc_field(): + @tool def test_method(param: str) -> str: return f"param: {param}" @@ -418,18 +294,16 @@ class MyThing: ... 
result = instance.field("example") assert result == "param: example" - stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) - result2 = (await alist(stream))[-1] + result2 = instance.field.invoke({"toolUseId": "test-id", "input": {"param": "example"}}) assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} -@pytest.mark.asyncio -async def test_tool_as_instance_field(alist): +def test_tool_as_instance_field(): """Make sure that class instance properties operate correctly.""" class MyThing: def __init__(self): - @strands.tool + @tool def test_method(param: str) -> str: return f"param: {param}" @@ -440,16 +314,14 @@ def test_method(param: str) -> str: result = instance.field("example") assert result == "param: example" - stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) - result2 = (await alist(stream))[-1] + result2 = instance.field.invoke({"toolUseId": "test-id", "input": {"param": "example"}}) assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} -@pytest.mark.asyncio -async def test_default_parameter_handling(alist): +def test_default_parameter_handling(): """Test handling of parameters with default values.""" - @strands.tool + @tool def tool_with_defaults(required: str, optional: str = "default", number: int = 42) -> str: """Test tool with multiple default parameters. @@ -469,46 +341,38 @@ def tool_with_defaults(required: str, optional: str = "default", number: int = 4 # Call with just required parameter tool_use = {"toolUseId": "test-id", "input": {"required": "hello"}} - stream = tool_with_defaults.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = tool_with_defaults.invoke(tool_use) assert result["content"][0]["text"] == "hello default 42" # Call with some but not all optional parameters tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "number": 100}} - stream = tool_with_defaults.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = tool_with_defaults.invoke(tool_use) assert result["content"][0]["text"] == "hello default 100" -@pytest.mark.asyncio -async def test_empty_tool_use_handling(alist): +def test_empty_tool_use_handling(): """Test handling of empty tool use dictionaries.""" - @strands.tool + @tool def test_tool(required: str) -> str: """Test with a required parameter.""" return f"Got: {required}" # Test with completely empty tool use - stream = test_tool.stream({}, {}) - result = (await alist(stream))[-1] + result = test_tool.invoke({}) assert result["status"] == "error" assert "unknown" in result["toolUseId"] # Test with missing input - stream = test_tool.stream({"toolUseId": "test-id"}, {}) - result = (await alist(stream))[-1] + result = test_tool.invoke({"toolUseId": "test-id"}) assert result["status"] == "error" assert "test-id" in result["toolUseId"] -@pytest.mark.asyncio -async def test_traditional_function_call(alist): +def test_traditional_function_call(): """Test that decorated functions can still be called normally.""" - @strands.tool + @tool def add_numbers(a: int, b: int) -> int: """Add two numbers. 
@@ -524,18 +388,15 @@ def add_numbers(a: int, b: int) -> int: # Call through tool interface tool_use = {"toolUseId": "test-id", "input": {"a": 2, "b": 3}} - stream = add_numbers.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = add_numbers.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "5" -@pytest.mark.asyncio -async def test_multiple_default_parameters(alist): +def test_multiple_default_parameters(): """Test handling of multiple parameters with default values.""" - @strands.tool + @tool def multi_default_tool( required_param: str, optional_str: str = "default_str", @@ -560,9 +421,7 @@ def multi_default_tool( # Test calling with only required parameter tool_use = {"toolUseId": "test-id", "input": {"required_param": "hello"}} - stream = multi_default_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = multi_default_tool.invoke(tool_use) assert result["status"] == "success" assert "hello, default_str, 42, True, 3.14" in result["content"][0]["text"] @@ -571,18 +430,15 @@ def multi_default_tool( "toolUseId": "test-id", "input": {"required_param": "hello", "optional_int": 100, "optional_float": 2.718}, } - stream = multi_default_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = multi_default_tool.invoke(tool_use) assert "hello, default_str, 100, True, 2.718" in result["content"][0]["text"] -@pytest.mark.asyncio -async def test_return_type_validation(alist): +def test_return_type_validation(): """Test that return types are properly handled and validated.""" # Define tool with explicitly typed return - @strands.tool + @tool def int_return_tool(param: str) -> int: """Tool that returns an integer. @@ -598,9 +454,7 @@ def int_return_tool(param: str) -> int: # Test with return that matches declared type tool_use = {"toolUseId": "test-id", "input": {"param": "valid"}} - stream = int_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = int_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "42" @@ -608,22 +462,18 @@ def int_return_tool(param: str) -> int: # Note: This should still work because Python doesn't enforce return types at runtime # but the function will return a string instead of an int tool_use = {"toolUseId": "test-id", "input": {"param": "invalid_type"}} - stream = int_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = int_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "not an int" # Test with None return from a non-None return type tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} - stream = int_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = int_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "None" # Define tool with Union return type - @strands.tool + @tool def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: """Tool with Union return type. 
@@ -639,32 +489,25 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: # Test with each possible return type in the Union tool_use = {"toolUseId": "test-id", "input": {"param": "dict"}} - stream = union_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = union_return_tool.invoke(tool_use) assert result["status"] == "success" assert "{'key': 'value'}" in result["content"][0]["text"] or '{"key": "value"}' in result["content"][0]["text"] tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} - stream = union_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = union_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "string result" tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} - stream = union_return_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = union_return_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "None" -@pytest.mark.asyncio -async def test_tool_with_no_parameters(alist): +def test_tool_with_no_parameters(): """Test a tool that doesn't require any parameters.""" - @strands.tool + @tool def no_params_tool() -> str: """A tool that doesn't need any parameters.""" return "Success - no parameters needed" @@ -677,9 +520,7 @@ def no_params_tool() -> str: # Test tool use call tool_use = {"toolUseId": "test-id", "input": {}} - stream = no_params_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = no_params_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "Success - no parameters needed" @@ -688,11 +529,10 @@ def no_params_tool() -> str: assert direct_result == "Success - no parameters needed" -@pytest.mark.asyncio -async def test_complex_parameter_types(alist): +def test_complex_parameter_types(): """Test handling of complex parameter types like nested dictionaries.""" - @strands.tool + @tool def complex_type_tool(config: Dict[str, Any]) -> str: """Tool with complex parameter type. @@ -706,9 +546,7 @@ def complex_type_tool(config: Dict[str, Any]) -> str: # Call via tool use tool_use = {"toolUseId": "test-id", "input": {"config": nested_dict}} - stream = complex_type_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = complex_type_tool.invoke(tool_use) assert result["status"] == "success" assert "Got config with 3 keys" in result["content"][0]["text"] @@ -717,11 +555,10 @@ def complex_type_tool(config: Dict[str, Any]) -> str: assert direct_result == "Got config with 3 keys" -@pytest.mark.asyncio -async def test_custom_tool_result_handling(alist): +def test_custom_tool_result_handling(): """Test that a function returning a properly formatted tool result dictionary is handled correctly.""" - @strands.tool + @tool def custom_result_tool(param: str) -> Dict[str, Any]: """Tool that returns a custom tool result dictionary. 
@@ -736,10 +573,9 @@ def custom_result_tool(param: str) -> Dict[str, Any]: # Test via tool use tool_use = {"toolUseId": "custom-id", "input": {"param": "test"}} - stream = custom_result_tool.stream(tool_use, {}) + result = custom_result_tool.invoke(tool_use) # The wrapper should preserve our format and just add the toolUseId - result = (await alist(stream))[-1] assert result["status"] == "success" assert result["toolUseId"] == "custom-id" assert len(result["content"]) == 2 @@ -751,7 +587,7 @@ def custom_result_tool(param: str) -> Dict[str, Any]: def test_docstring_parsing(): """Test that function docstring is correctly parsed into tool spec.""" - @strands.tool + @tool def documented_tool(param1: str, param2: int = 10) -> str: """This is the summary line. @@ -787,11 +623,10 @@ def documented_tool(param1: str, param2: int = 10) -> str: assert "param2" not in schema["required"] -@pytest.mark.asyncio -async def test_detailed_validation_errors(alist): +def test_detailed_validation_errors(): """Test detailed error messages for various validation failures.""" - @strands.tool + @tool def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: """Tool with various parameter types for validation testing. @@ -811,9 +646,7 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: "bool_param": True, }, } - stream = validation_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = validation_tool.invoke(tool_use) assert result["status"] == "error" assert "int_param" in result["content"][0]["text"] @@ -826,20 +659,17 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: "bool_param": True, }, } - stream = validation_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = validation_tool.invoke(tool_use) assert result["status"] == "error" assert "int_param" in result["content"][0]["text"] -@pytest.mark.asyncio -async def test_tool_complex_validation_edge_cases(alist): +def test_tool_complex_validation_edge_cases(): """Test validation of complex schema edge cases.""" from typing import Any, Dict, Union # Define a tool with a complex anyOf type that could trigger edge case handling - @strands.tool + @tool def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: """Tool with complex anyOf structure. 
@@ -850,38 +680,31 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: # Test with None value tool_use = {"toolUseId": "test-id", "input": {"param": None}} - stream = edge_case_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = edge_case_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "None" # Test with empty dict tool_use = {"toolUseId": "test-id", "input": {"param": {}}} - stream = edge_case_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = edge_case_tool.invoke(tool_use) assert result["status"] == "success" assert result["content"][0]["text"] == "{}" # Test with a complex nested dictionary nested_dict = {"key1": {"nested": [1, 2, 3]}, "key2": None} tool_use = {"toolUseId": "test-id", "input": {"param": nested_dict}} - stream = edge_case_tool.stream(tool_use, {}) - - result = (await alist(stream))[-1] + result = edge_case_tool.invoke(tool_use) assert result["status"] == "success" assert "key1" in result["content"][0]["text"] assert "nested" in result["content"][0]["text"] -@pytest.mark.asyncio -async def test_tool_method_detection_errors(alist): +def test_tool_method_detection_errors(): """Test edge cases in method detection logic.""" # Define a class with a decorated method to test exception handling in method detection class TestClass: - @strands.tool + @tool def test_method(self, param: str) -> str: """Test method that should be called properly despite errors. @@ -917,14 +740,12 @@ def test_method(self): assert instance.test_method("test") == "Method Got: test" # Test direct function call - stream = instance.test_method.stream({"toolUseId": "test-id", "input": {"param": "direct"}}, {}) - - direct_result = (await alist(stream))[-1] + direct_result = instance.test_method.invoke({"toolUseId": "test-id", "input": {"param": "direct"}}) assert direct_result["status"] == "success" assert direct_result["content"][0]["text"] == "Method Got: direct" # Create a standalone function to test regular function calls - @strands.tool + @tool def standalone_tool(p1: str, p2: str = "default") -> str: """Standalone tool for testing. @@ -939,18 +760,15 @@ def standalone_tool(p1: str, p2: str = "default") -> str: assert result == "Standalone: param1, param2" # And that it works with tool use call too - stream = standalone_tool.stream({"toolUseId": "test-id", "input": {"p1": "value1"}}, {}) - - tool_use_result = (await alist(stream))[-1] + tool_use_result = standalone_tool.invoke({"toolUseId": "test-id", "input": {"p1": "value1"}}) assert tool_use_result["status"] == "success" assert tool_use_result["content"][0]["text"] == "Standalone: value1, default" -@pytest.mark.asyncio -async def test_tool_general_exception_handling(alist): +def test_tool_general_exception_handling(): """Test handling of arbitrary exceptions in tool execution.""" - @strands.tool + @tool def failing_tool(param: str) -> str: """Tool that raises different exception types. 
@@ -971,9 +789,7 @@ def failing_tool(param: str) -> str:
     error_types = ["value_error", "type_error", "attribute_error", "key_error"]
     for error_type in error_types:
         tool_use = {"toolUseId": "test-id", "input": {"param": error_type}}
-        stream = failing_tool.stream(tool_use, {})
-
-        result = (await alist(stream))[-1]
+        result = failing_tool.invoke(tool_use)
 
         assert result["status"] == "error"
         error_message = result["content"][0]["text"]
@@ -990,12 +806,11 @@ def failing_tool(param: str) -> str:
     assert "key_name" in error_message
 
 
-@pytest.mark.asyncio
-async def test_tool_with_complex_anyof_schema(alist):
+def test_tool_with_complex_anyof_schema():
     """Test handling of complex anyOf structures in the schema."""
     from typing import Any, Dict, List, Union
 
-    @strands.tool
+    @tool
     def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]) -> str:
         """Tool with a complex Union type that creates anyOf in schema.
 
@@ -1006,33 +821,25 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
 
     # Test with a list
     tool_use = {"toolUseId": "test-id", "input": {"union_param": [1, 2, 3]}}
-    stream = complex_schema_tool.stream(tool_use, {})
-
-    result = (await alist(stream))[-1]
+    result = complex_schema_tool.invoke(tool_use)
     assert result["status"] == "success"
     assert "list: [1, 2, 3]" in result["content"][0]["text"]
 
     # Test with a dict
     tool_use = {"toolUseId": "test-id", "input": {"union_param": {"key": "value"}}}
-    stream = complex_schema_tool.stream(tool_use, {})
-
-    result = (await alist(stream))[-1]
+    result = complex_schema_tool.invoke(tool_use)
     assert result["status"] == "success"
     assert "dict:" in result["content"][0]["text"]
     assert "key" in result["content"][0]["text"]
 
     # Test with a string
     tool_use = {"toolUseId": "test-id", "input": {"union_param": "test_string"}}
-    stream = complex_schema_tool.stream(tool_use, {})
-
-    result = (await alist(stream))[-1]
+    result = complex_schema_tool.invoke(tool_use)
     assert result["status"] == "success"
     assert "str: test_string" in result["content"][0]["text"]
 
     # Test with None
     tool_use = {"toolUseId": "test-id", "input": {"union_param": None}}
-    stream = complex_schema_tool.stream(tool_use, {})
-
-    result = (await alist(stream))[-1]
+    result = complex_schema_tool.invoke(tool_use)
     assert result["status"] == "success"
     assert "NoneType: None" in result["content"][0]["text"]
diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py
index 04d4ea657..d3e934acc 100644
--- a/tests/strands/tools/test_executor.py
+++ b/tests/strands/tools/test_executor.py
@@ -1,3 +1,4 @@
+import concurrent
 import unittest.mock
 import uuid
 
@@ -15,9 +16,9 @@ def moto_autouse(moto_env):
 
 @pytest.fixture
 def tool_handler(request):
-    async def handler(tool_use):
+    def handler(tool_use):
         yield {"event": "abc"}
-        yield {
+        return {
             **params,
             "toolUseId": tool_use["toolUseId"],
         }
@@ -64,14 +65,18 @@ def cycle_trace():
     return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name")
 
 
-@pytest.mark.asyncio
-async def test_run_tools(
+@pytest.fixture
+def thread_pool(request):
+    return concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+
+def test_run_tools(
     tool_handler,
     tool_uses,
     event_loop_metrics,
     invalid_tool_use_ids,
     cycle_trace,
-    alist,
+    thread_pool,
 ):
     tool_results = []
 
@@ -82,11 +87,14 @@ async def test_run_tools(
         invalid_tool_use_ids,
         tool_results,
         cycle_trace,
+        thread_pool,
     )
 
-    tru_events = await alist(stream)
-    exp_events = [
-        {"event": "abc"},
+    tru_events = list(stream)
+    exp_events = [{"event": "abc"}]
"abc"}] + + tru_results = tool_results + exp_results = [ { "content": [ { @@ -98,21 +106,17 @@ async def test_run_tools( }, ] - tru_results = tool_results - exp_results = [exp_events[-1]] - assert tru_events == exp_events and tru_results == exp_results @pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_invalid_tool( +def test_run_tools_invalid_tool( tool_handler, tool_uses, event_loop_metrics, invalid_tool_use_ids, cycle_trace, - alist, + thread_pool, ): tool_results = [] @@ -123,8 +127,9 @@ async def test_run_tools_invalid_tool( invalid_tool_use_ids, tool_results, cycle_trace, + thread_pool, ) - await alist(stream) + list(stream) tru_results = tool_results exp_results = [] @@ -133,14 +138,13 @@ async def test_run_tools_invalid_tool( @pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_failed_tool( +def test_run_tools_failed_tool( tool_handler, tool_uses, event_loop_metrics, invalid_tool_use_ids, cycle_trace, - alist, + thread_pool, ): tool_results = [] @@ -151,8 +155,9 @@ async def test_run_tools_failed_tool( invalid_tool_use_ids, tool_results, cycle_trace, + thread_pool, ) - await alist(stream) + list(stream) tru_results = tool_results exp_results = [ @@ -191,14 +196,12 @@ async def test_run_tools_failed_tool( ], indirect=True, ) -@pytest.mark.asyncio -async def test_run_tools_sequential( +def test_run_tools_sequential( tool_handler, tool_uses, event_loop_metrics, invalid_tool_use_ids, cycle_trace, - alist, ): tool_results = [] @@ -211,7 +214,7 @@ async def test_run_tools_sequential( cycle_trace, None, # tool_pool ) - await alist(stream) + list(stream) tru_results = tool_results exp_results = [ @@ -278,8 +281,7 @@ def test_validate_and_prepare_tools(): @unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.asyncio -async def test_run_tools_creates_and_ends_span_on_success( +def test_run_tools_creates_and_ends_span_on_success( mock_get_tracer, tool_handler, tool_uses, @@ -287,7 +289,7 @@ async def test_run_tools_creates_and_ends_span_on_success( event_loop_metrics, invalid_tool_use_ids, cycle_trace, - alist, + thread_pool, ): """Test that run_tools creates and ends a span on successful execution.""" # Setup mock tracer and span @@ -310,8 +312,9 @@ async def test_run_tools_creates_and_ends_span_on_success( tool_results, cycle_trace, parent_span, + thread_pool, ) - await alist(stream) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -326,15 +329,14 @@ async def test_run_tools_creates_and_ends_span_on_success( @unittest.mock.patch("strands.tools.executor.get_tracer") @pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_creates_and_ends_span_on_failure( +def test_run_tools_creates_and_ends_span_on_failure( mock_get_tracer, tool_handler, tool_uses, event_loop_metrics, invalid_tool_use_ids, cycle_trace, - alist, + thread_pool, ): """Test that run_tools creates and ends a span on tool failure.""" # Setup mock tracer and span @@ -357,8 +359,9 @@ async def test_run_tools_creates_and_ends_span_on_failure( tool_results, cycle_trace, parent_span, + thread_pool, ) - await alist(stream) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -392,16 +395,16 @@ async def 
     ],
     indirect=True,
 )
-@pytest.mark.asyncio
-async def test_run_tools_concurrent_execution_with_spans(
+def test_run_tools_parallel_execution_with_spans(
     mock_get_tracer,
     tool_handler,
     tool_uses,
     event_loop_metrics,
     invalid_tool_use_ids,
     cycle_trace,
-    alist,
+    thread_pool,
 ):
+    """Test that spans are created and ended for each tool in parallel execution."""
     # Setup mock tracer and spans
     mock_tracer = unittest.mock.MagicMock()
     mock_span1 = unittest.mock.MagicMock()
@@ -423,8 +426,9 @@ async def test_run_tools_concurrent_execution_with_spans(
         tool_results,
         cycle_trace,
         parent_span,
+        thread_pool,
     )
-    await alist(stream)
+    list(stream)
 
     # Verify spans were created for both tools
     assert mock_tracer.start_tool_call_span.call_count == 2
diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py
index ebcba3fb1..bfdc2a47d 100644
--- a/tests/strands/tools/test_registry.py
+++ b/tests/strands/tools/test_registry.py
@@ -6,7 +6,6 @@
 
 import pytest
 
-import strands
 from strands.tools import PythonAgentTool
 from strands.tools.decorator import DecoratedFunctionTool, tool
 from strands.tools.registry import ToolRegistry
@@ -31,8 +30,8 @@ def test_process_tools_with_invalid_path():
 
 
 def test_register_tool_with_similar_name_raises():
-    tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), tool_func=lambda: None)
-    tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), tool_func=lambda: None)
+    tool_1 = PythonAgentTool(tool_name="tool-like-this", tool_spec=MagicMock(), callback=lambda: None)
+    tool_2 = PythonAgentTool(tool_name="tool_like_this", tool_spec=MagicMock(), callback=lambda: None)
 
     tool_registry = ToolRegistry()
 
@@ -47,23 +46,6 @@ def test_register_tool_with_similar_name_raises():
     )
 
 
-def test_get_all_tool_specs_returns_right_tool_specs():
-    tool_1 = strands.tool(lambda a: a, name="tool_1")
-    tool_2 = strands.tool(lambda b: b, name="tool_2")
-
-    tool_registry = ToolRegistry()
-
-    tool_registry.register_tool(tool_1)
-    tool_registry.register_tool(tool_2)
-
-    tool_specs = tool_registry.get_all_tool_specs()
-
-    assert tool_specs == [
-        tool_1.tool_spec,
-        tool_2.tool_spec,
-    ]
-
-
 def test_scan_module_for_tools():
     @tool
     def tool_function_1(a):
diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py
index cec4825dc..cc3150209 100644
--- a/tests/strands/tools/test_tools.py
+++ b/tests/strands/tools/test_tools.py
@@ -12,44 +12,6 @@
 from strands.types.tools import ToolUse
 
 
-@pytest.fixture(scope="module")
-def identity_invoke():
-    def identity(tool_use, a):
-        return tool_use, a
-
-    return identity
-
-
-@pytest.fixture(scope="module")
-def identity_invoke_async():
-    async def identity(tool_use, a):
-        return tool_use, a
-
-    return identity
-
-
-@pytest.fixture
-def identity_tool(request):
-    identity = request.getfixturevalue(request.param)
-
-    return PythonAgentTool(
-        tool_name="identity",
-        tool_spec={
-            "name": "identity",
-            "description": "identity",
-            "inputSchema": {
-                "type": "object",
-                "properties": {
-                    "a": {
-                        "type": "integer",
-                    },
-                },
-            },
-        },
-        tool_func=identity,
-    )
-
-
 def test_validate_tool_use_name_valid():
     tool1 = {"name": "valid_tool_name", "toolUseId": "123"}
     # Should not raise an exception
@@ -436,62 +398,174 @@ def test_validate_tool_use_invalid(tool_use, expected_error):
         strands.tools.tools.validate_tool_use(tool_use)
 
 
-@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True)
-def test_tool_name(identity_tool):
-    tru_name = identity_tool.tool_name
+@pytest.fixture
+def function():
+    def identity(a: int) -> int:
+        return a
+
+    return identity
+
+
+@pytest.fixture
+def tool(function):
+    return strands.tools.tool(function)
+
+
+def test__init__invalid_name():
+    with pytest.raises(ValueError, match="Tool name must be a string"):
+
+        @strands.tool(name=0)
+        def identity(a):
+            return a
+
+
+def test_tool_name(tool):
+    tru_name = tool.tool_name
     exp_name = "identity"
 
     assert tru_name == exp_name
 
 
-@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True)
-def test_tool_spec(identity_tool):
-    tru_spec = identity_tool.tool_spec
+def test_tool_spec(tool):
     exp_spec = {
         "name": "identity",
         "description": "identity",
         "inputSchema": {
-            "type": "object",
-            "properties": {
-                "a": {
-                    "type": "integer",
+            "json": {
+                "type": "object",
+                "properties": {
+                    "a": {
+                        "description": "Parameter a",
+                        "type": "integer",
+                    },
                 },
-            },
+                "required": ["a"],
+            }
         },
     }
 
+    tru_spec = tool.tool_spec
     assert tru_spec == exp_spec
 
 
-@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True)
-def test_tool_type(identity_tool):
-    tru_type = identity_tool.tool_type
-    exp_type = "python"
+def test_tool_type(tool):
+    tru_type = tool.tool_type
+    exp_type = "function"
 
     assert tru_type == exp_type
 
 
-@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True)
-def test_supports_hot_reload(identity_tool):
-    assert not identity_tool.supports_hot_reload
+def test_supports_hot_reload(tool):
+    assert tool.supports_hot_reload
+
+
+def test_original_function(tool, function):
+    tru_name = tool.original_function.__name__
+    exp_name = function.__name__
+
+    assert tru_name == exp_name
+
+
+def test_original_function_not_decorated():
+    def identity(a: int):
+        return a
 
+    tool = strands.tool(func=identity, name="identity")
 
-@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True)
-def test_get_display_properties(identity_tool):
-    tru_properties = identity_tool.get_display_properties()
+    tru_name = tool.original_function.__name__
+    exp_name = "identity"
+
+    assert tru_name == exp_name
+
+
+def test_get_display_properties(tool):
+    tru_properties = tool.get_display_properties()
     exp_properties = {
+        "Function": "identity",
         "Name": "identity",
-        "Type": "python",
+        "Type": "function",
     }
 
     assert tru_properties == exp_properties
 
 
-@pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True)
-@pytest.mark.asyncio
-async def test_stream(identity_tool, alist):
-    stream = identity_tool.stream({"tool_use": 1}, {"a": 2})
+def test_invoke(tool):
+    tru_output = tool.invoke({"input": {"a": 2}})
+    exp_output = {"toolUseId": "unknown", "status": "success", "content": [{"text": "2"}]}
+
+    assert tru_output == exp_output
+
+
+def test_invoke_with_agent():
+    @strands.tools.tool
+    def identity(a: int, agent: dict = None):
+        return a, agent
+
+    exp_output = {"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}
+
+    tru_output = identity.invoke({"input": {"a": 2}}, agent={"state": 1})
+
+    assert tru_output == exp_output
+
+
+# Tests from test_python_agent_tool.py
+@pytest.fixture
+def python_tool():
+    def identity(tool_use, a):
+        return tool_use, a
+
+    return PythonAgentTool(
+        tool_name="identity",
+        tool_spec={
+            "name": "identity",
+            "description": "identity",
+            "inputSchema": {
+                "type": "object",
"properties": { + "a": { + "type": "integer", + }, + }, + }, + }, + callback=identity, + ) + + +def test_python_tool_name(python_tool): + tru_name = python_tool.tool_name + exp_name = "identity" + + assert tru_name == exp_name + + +def test_python_tool_spec(python_tool): + tru_spec = python_tool.tool_spec + exp_spec = { + "name": "identity", + "description": "identity", + "inputSchema": { + "type": "object", + "properties": { + "a": { + "type": "integer", + }, + }, + }, + } + + assert tru_spec == exp_spec + + +def test_python_tool_type(python_tool): + tru_type = python_tool.tool_type + exp_type = "python" + + assert tru_type == exp_type + + +def test_python_invoke(python_tool): + tru_output = python_tool.invoke({"tool_use": 1}, a=2) + exp_output = ({"tool_use": 1}, 2) - tru_events = await alist(stream) - exp_events = [({"tool_use": 1}, 2)] - assert tru_events == exp_events + assert tru_output == exp_output diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 5baa7e709..dc43b3fcd 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -1,3 +1,4 @@ +import base64 import unittest.mock import pytest @@ -95,6 +96,23 @@ def system_prompt(): "type": "image_url", }, ), + # Image - base64 encoded + ( + { + "image": { + "format": "jpg", + "source": {"bytes": base64.b64encode(b"image")}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,aW1hZ2U=", + }, + "type": "image_url", + }, + ), # Text ( {"text": "hello"}, @@ -349,3 +367,15 @@ def test_format_chunk_unknown_type(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +@pytest.mark.parametrize( + ("data", "exp_result"), + [ + (b"image", b"aW1hZ2U="), + (b"aW1hZ2U=", b"aW1hZ2U="), + ], +) +def test_b64encode(data, exp_result): + tru_result = SAOpenAIModel.b64encode(data) + assert tru_result == exp_result