diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..5de0b79c2 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + target: 90% # overall coverage threshold + patch: + default: + target: 90% # patch coverage threshold + base: auto + # Only post patch coverage on decreases + only_pulls: true \ No newline at end of file diff --git a/.github/workflows/pr-size-labeler.yml b/.github/workflows/pr-size-labeler.yml new file mode 100644 index 000000000..bc4d52c6d --- /dev/null +++ b/.github/workflows/pr-size-labeler.yml @@ -0,0 +1,58 @@ +name: PR Size Labeler + +on: + pull_request_target: + branches: main + +jobs: + label-size: + runs-on: ubuntu-latest + permissions: + pull-requests: write + issues: write + steps: + - name: Calculate PR size and apply label + uses: actions/github-script@v8 + with: + script: | + const pr = context.payload.pull_request; + const totalChanges = pr.additions + pr.deletions; + + // Remove existing size labels + const labels = await github.rest.issues.listLabelsOnIssue({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number + }); + + for (const label of labels.data) { + if (label.name.startsWith('size/')) { + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: label.name + }); + } + } + + // Determine and apply new size label + let sizeLabel; + if (totalChanges <= 20) sizeLabel = 'size/xs'; + else if (totalChanges <= 100) sizeLabel = 'size/s'; + else if (totalChanges <= 500) sizeLabel = 'size/m'; + else if (totalChanges <= 1000) sizeLabel = 'size/l'; + else { + sizeLabel = 'size/xl'; + } + + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + labels: [sizeLabel] + }); + + if (sizeLabel === 'size/xl') { + core.setFailed(`PR is too large (${totalChanges} lines). Please split into smaller PRs.`); + } diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 291874dce..e38942b2c 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -66,6 +66,11 @@ jobs: id: tests run: hatch test tests --cover continue-on-error: false + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} lint: name: Lint runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 888a96bbc..e92a233f8 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ __pycache__* .vscode dist repl_state -.kiro \ No newline at end of file +.kiro +uv.lock diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d107b1fa8..be83ff85b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,6 +36,18 @@ Before starting work on any issue: 3. Wait for maintainer confirmation before beginning significant work +## Development Tenets +Our team follows these core principles when designing and implementing features. These tenets help us make consistent decisions, resolve trade-offs, and maintain the quality and coherence of the SDK. When contributing, please consider how your changes align with these principles: + +1. **Simple at any scale:** We believe that simple things should be simple. The same clean abstractions that power a weekend prototype should scale effortlessly to production workloads. 
We reject the notion that enterprise-grade means enterprise-complicated: Strands remains approachable whether it's your first agent or your millionth. +2. **Extensible by design:** We allow for as much configuration as possible, from hooks to model providers, session managers, tools, etc. We meet customers where they are with flexible extension points that are simple to integrate with. +3. **Composability:** Primitives are building blocks that compose with one another. Each feature of Strands is developed with all other features in mind; they are consistent and complement one another. +4. **The obvious path is the happy path:** Through intuitive naming, helpful error messages, and thoughtful API design, we guide developers toward correct patterns and away from common pitfalls. +5. **We are accessible to humans and agents:** Strands is designed for both humans and AI to understand equally well. We don't take shortcuts on curated DX for humans and we go the extra mile to make sure coding assistants can help you use those interfaces the right way. +6. **Embrace common standards:** We respect what came before, and do not want to reinvent something that is already widely adopted or done better. + +When proposing solutions or reviewing code, we reference these principles to guide our decisions. If two approaches seem equally valid, we choose the one that best aligns with our tenets. + ## Development Environment This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management. diff --git a/pyproject.toml b/pyproject.toml index af8e45ffc..b542c7481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", + "jsonschema>=4.0.0,<5.0.0", "mcp>=1.11.0,<2.0.0", "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", diff --git a/src/strands/__init__.py b/src/strands/__init__.py index ae784a58f..3718a29c5 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -5,4 +5,12 @@ from .tools.decorator import tool from .types.tools import ToolContext -__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"] +__all__ = [ + "Agent", + "agent", + "models", + "tool", + "ToolContext", + "types", + "telemetry", +] diff --git a/src/strands/_async.py b/src/strands/_async.py new file mode 100644 index 000000000..976487c37 --- /dev/null +++ b/src/strands/_async.py @@ -0,0 +1,31 @@ +"""Private async execution utilities.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +def run_async(async_func: Callable[[], Awaitable[T]]) -> T: + """Run an async function in a separate thread to avoid event loop conflicts. + + This utility handles the common pattern of running async code from sync contexts + by using ThreadPoolExecutor to isolate the async execution.
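As a quick illustration of the pattern this helper encapsulates (the coroutine below is hypothetical, not part of this change):

```python
from strands._async import run_async

async def fetch_greeting() -> str:
    # stand-in for any awaitable SDK call
    return "hello"

# Runs the coroutine on a fresh event loop inside a worker thread, so it is safe
# to call even when the caller is already running inside an event loop.
result = run_async(fetch_greeting)
assert result == "hello"
```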
+ + Args: + async_func: A callable that returns an awaitable + + Returns: + The result of the async function + """ + + async def execute_async() -> T: + return await async_func() + + def execute() -> T: + return asyncio.run(execute_async()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/src/strands/_exception_notes.py b/src/strands/_exception_notes.py new file mode 100644 index 000000000..019b9cde4 --- /dev/null +++ b/src/strands/_exception_notes.py @@ -0,0 +1,21 @@ +"""Exception note utilities for Python 3.10+ compatibility.""" + +# add_note was added in 3.11 - we hoist to a constant to facilitate testing +supports_add_note = hasattr(Exception, "add_note") + + +def add_exception_note(exception: Exception, note: str) -> None: + """Add a note to an exception, compatible with Python 3.10+. + + Uses add_note() if it's available (Python 3.11+) or modifies the exception message if it is not. + """ + if supports_add_note: + # we ignore the mypy error because the version-check for add_note is extracted into a constant up above and + # mypy doesn't detect that + exception.add_note(note) # type: ignore + else: + # For Python 3.10, append note to the exception message + if hasattr(exception, "args") and exception.args: + exception.args = (f"{exception.args[0]}\n{note}",) + exception.args[1:] + else: + exception.args = (note,) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..9de33fbfc 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,12 +9,12 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import asyncio import json import logging import random -from concurrent.futures import ThreadPoolExecutor +import warnings from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, @@ -31,7 +31,11 @@ from pydantic import BaseModel from .. import _identifier +from .._async import run_async from ..event_loop.event_loop import event_loop_cycle + +if TYPE_CHECKING: + from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -49,11 +53,13 @@ from ..tools.executors import ConcurrentToolExecutor from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry +from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher -from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent +from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException +from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -61,6 +67,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -142,6 +149,9 @@ def caller( Raises: AttributeError: If the tool doesn't exist. 
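For reference, the method-style direct call that this caller implements looks roughly like the following; the `calculator` tool name is illustrative:

```python
# Direct tool access; the tool must already be registered on the agent.
result = agent.tool.calculator(expression="2 + 2")

# New in this change: a direct call made while the agent is paused on an interrupt
# raises RuntimeError("cannot directly call tool during interrupt").
```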
""" + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + normalized_name = self._find_normalized_tool_name(name) # Create unique tool ID and set up the tool request @@ -156,16 +166,13 @@ def caller( async def acall() -> ToolResult: async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - _ = event + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") return tool_results[0] - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() + tool_result = run_async(acall) if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -208,8 +215,9 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None, system_prompt: Optional[str] = None, + structured_output_model: Optional[Type[BaseModel]] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] ] = _DEFAULT_CALLBACK_HANDLER, @@ -240,11 +248,16 @@ def __init__( - File paths (e.g., "/path/to/tool.py") - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) + - ToolProvider instances for managed tool collections - Functions decorated with `@strands.tool` decorator. If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. If None, the model will behave according to its default settings. + structured_output_model: Pydantic model type(s) for structured output. + When specified, all agent calls will attempt to return structured output of this type. + This can be overridden on the agent invocation. + Defaults to None (no structured output). callback_handler: Callback for processing events as they happen during agent execution. If not provided (using the default), a new PrintingCallbackHandler instance is created. If explicitly set to None, null_callback_handler is used. 
@@ -274,8 +287,8 @@ def __init__( """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] - self.system_prompt = system_prompt + self._default_structured_output_model = structured_output_model self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description @@ -337,6 +350,8 @@ def __init__( self.hooks = HookRegistry() + self._interrupt_state = InterruptState() + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -374,7 +389,14 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + def __call__( + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -389,7 +411,9 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result object containing: @@ -398,16 +422,22 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - message: The final message from the model - metrics: Performance metrics from the event loop - state: The final state of the event loop + - structured_output: Parsed structured output when structured_output_model was specified """ + return run_async( + lambda: self.invoke_async( + prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + ) + ) - def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, **kwargs)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() - - async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + async def invoke_async( + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -422,7 +452,9 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + structured_output_model: Pydantic model type(s) for structured output (overrides agent default). 
+ **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result: object containing: @@ -432,7 +464,9 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async(prompt, **kwargs) + events = self.stream_async( + prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + ) async for event in events: _ = event @@ -459,13 +493,15 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> Raises: ValueError: If no conversation history or prompt is provided. """ + warnings.warn( + "Agent.structured_output method is deprecated." + " You should pass in `structured_output_model` directly into the agent invocation." + " see: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/structured-output/", + category=DeprecationWarning, + stacklevel=2, + ) - def execute() -> T: - return asyncio.run(self.structured_output_async(output_model, prompt)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.structured_output_async(output_model, prompt)) async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. @@ -483,7 +519,18 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu Raises: ValueError: If no conversation history or prompt is provided. + - """ + if self._interrupt_state.activated: + raise RuntimeError("cannot call structured output during interrupt") + + warnings.warn( + "Agent.structured_output_async method is deprecated." + " You should pass in `structured_output_model` directly into the agent invocation." + " see: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/structured-output/", + category=DeprecationWarning, + stacklevel=2, + ) self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT @@ -527,9 +574,31 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + def cleanup(self) -> None: + """Clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through finalizers as a fallback, but explicit cleanup is recommended. + """ + self.tool_registry.cleanup() + + def __del__(self) -> None: + """Clean up resources when agent is garbage collected.""" + # __del__ is called even when an exception is thrown in the constructor, + # so there is no guarantee tool_registry was set.. + if hasattr(self, "tool_registry"): + self.tool_registry.cleanup() + async def stream_async( self, prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. 
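Since both `structured_output*` methods now emit `DeprecationWarning`, a migration sketch may help; `UserProfile` is a hypothetical Pydantic model, and the attribute access follows the new `AgentResult.structured_output` field added in this diff:

```python
from pydantic import BaseModel
from strands import Agent

class UserProfile(BaseModel):
    name: str
    email: str

# Before (still works, but warns):
agent = Agent()
profile = agent.structured_output(UserProfile, "Extract the user profile from the notes above")

# After: set a default on the agent and/or override it per invocation.
agent = Agent(structured_output_model=UserProfile)
result = agent("Extract the user profile from the notes above")
profile = result.structured_output  # UserProfile instance, or None if no structured output was produced
```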
@@ -546,7 +615,9 @@ async def stream_async( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass to the event loop. + invocation_state: Additional parameters to pass through the event loop. + structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: An async iterator that yields events. Each event is a dictionary containing @@ -567,7 +638,21 @@ async def stream_async( yield event["data"] ``` """ - callback_handler = kwargs.get("callback_handler", self.callback_handler) + self._resume_interrupt(prompt) + + merged_state = {} + if kwargs: + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + merged_state.update(kwargs) + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state + else: + if invocation_state is not None: + merged_state = invocation_state + + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) messages = self._convert_prompt_to_messages(prompt) @@ -576,10 +661,10 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, invocation_state=kwargs) + events = self._run_loop(messages, merged_state, structured_output_model) async for event in events: - event.prepare(invocation_state=kwargs) + event.prepare(invocation_state=merged_state) if event.is_callback_event: as_dict = event.as_dict() @@ -596,12 +681,50 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + def _resume_interrupt(self, prompt: AgentInput) -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self._interrupt_state.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + for content in cast(list[InterruptResponseContent], prompt): + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self._interrupt_state.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self._interrupt_state.interrupts[interrupt_id].response = interrupt_response + + async def _run_loop( + self, + messages: Messages, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, + ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. Args: messages: The input messages to add to the conversation. invocation_state: Additional parameters to pass to the event loop. + structured_output_model: Optional Pydantic model type for structured output. 
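Based on the `_resume_interrupt` validation above, resuming an interrupted agent appears to expect a list of `interruptResponse` content blocks, roughly as follows (only the dictionary shape is shown in this hunk; how the interrupt id is retrieved from `result.interrupts` is left elided):

```python
result = agent("Delete the staging table")   # a hook raises an interrupt, e.g. to request approval
if result.stop_reason == "interrupt":
    interrupt_id = ...                        # taken from result.interrupts
    result = agent([
        {"interruptResponse": {"interruptId": interrupt_id, "response": {"approved": True}}}
    ])
```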
Yields: Events from the event loop cycle. @@ -614,8 +737,12 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) for message in messages: self._append_message(message) + structured_output_context = StructuredOutputContext( + structured_output_model or self._default_structured_output_model + ) + # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state) + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) async for event in events: # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. @@ -636,24 +763,33 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) self.conversation_manager.apply_management(self) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + async def _execute_event_loop_cycle( + self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None + ) -> AsyncGenerator[TypedEvent, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements retry logic for handling context window overflow exceptions by reducing the conversation context and retrying. + Args: + invocation_state: Additional parameters to pass to the event loop. + structured_output_context: Optional structured output context for this invocation. + Yields: Events of the loop cycle. """ # Add `Agent` to invocation_state to keep backwards-compatibility invocation_state["agent"] = self + if structured_output_context: + structured_output_context.register_tool(self.tool_registry) + try: - # Execute the main event loop cycle events = event_loop_cycle( agent=self, invocation_state=invocation_state, + structured_output_context=structured_output_context, ) async for event in events: yield event @@ -666,11 +802,18 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A if self._session_manager: self._session_manager.sync_agent(self) - events = self._execute_event_loop_cycle(invocation_state) + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) async for event in events: yield event + finally: + if structured_output_context: + structured_output_context.cleanup(self.tool_registry) + def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + if self._interrupt_state.activated: + return [] + messages: Messages | None = None if prompt is not None: if isinstance(prompt, str): diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f3758c8d2..076a94d7a 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,8 +4,11 @@ """ from dataclasses import dataclass -from typing import Any +from typing import Any, Sequence, cast +from pydantic import BaseModel + +from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message from ..types.streaming import StopReason @@ -20,12 +23,16 @@ class AgentResult: message: The last message generated by the agent. metrics: Performance metrics collected during processing. state: Additional state information from the event loop. + interrupts: List of interrupts if raised by user. 
+ structured_output: Parsed structured output when structured_output_model was specified. """ stop_reason: StopReason message: Message metrics: EventLoopMetrics state: Any + interrupts: Sequence[Interrupt] | None = None + structured_output: BaseModel | None = None def __str__(self) -> str: """Get the agent's last message as a string. @@ -43,3 +50,34 @@ def __str__(self) -> str: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentResult": + """Rehydrate an AgentResult from persisted JSON. + + Args: + data: Dictionary containing the serialized AgentResult data + Returns: + AgentResult instance + Raises: + TypeError: If the data format is invalid. + """ + if data.get("type") != "agent_result": + raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") + + message = cast(Message, data.get("message")) + stop_reason = cast(StopReason, data.get("stop_reason")) + + return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + + def to_dict(self) -> dict[str, Any]: + """Convert this AgentResult to a JSON-serializable dictionary. + + Returns: + Dictionary containing serialized AgentResult data + """ + return { + "type": "agent_result", + "message": self.message, + "stop_reason": self.stop_reason, + } diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index b08b6853e..12185c286 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -5,8 +5,11 @@ from typing_extensions import override +from ...tools._tool_helpers import noop_tool +from ...tools.registry import ToolRegistry from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException +from ...types.tools import AgentTool from .conversation_manager import ConversationManager if TYPE_CHECKING: @@ -23,6 +26,10 @@ - You MUST create a structured and concise summary in bullet-point format. - You MUST NOT respond conversationally. - You MUST NOT address the user directly. +- You MUST NOT comment on tool availability. + +Assumptions: +- You MUST NOT assume tool executions failed unless otherwise stated.
Task: Your task is to create a structured summary document: @@ -182,9 +189,10 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: # Choose which agent to use for summarization summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent - # Save original system prompt and messages to restore later + # Save original system prompt, messages, and tool registry to restore later original_system_prompt = summarization_agent.system_prompt original_messages = summarization_agent.messages.copy() + original_tool_registry = summarization_agent.tool_registry try: # Only override system prompt if no agent was provided during initialization @@ -197,6 +205,13 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: ) # Temporarily set the system prompt for summarization summarization_agent.system_prompt = system_prompt + + # Add no-op tool if agent has no tools to satisfy tool spec requirement + if not summarization_agent.tool_names: + tool_registry = ToolRegistry() + tool_registry.register_tool(cast(AgentTool, noop_tool)) + summarization_agent.tool_registry = tool_registry + summarization_agent.messages = messages # Use the agent to generate summary with rich content (can use tools if needed) @@ -207,6 +222,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: # Restore original agent state summarization_agent.system_prompt = original_system_prompt summarization_agent.messages = original_messages + summarization_agent.tool_registry = original_tool_registry def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py new file mode 100644 index 000000000..3cec1541b --- /dev/null +++ b/src/strands/agent/interrupt.py @@ -0,0 +1,59 @@ +"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" + +from dataclasses import asdict, dataclass, field +from typing import Any + +from ..interrupt import Interrupt + + +@dataclass +class InterruptState: + """Track the state of interrupt events raised by the user. + + Note: interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. + """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deactivate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InterruptState": + """Initialize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method.
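A small round-trip sketch of the serialization contract described here (values are illustrative):

```python
state = InterruptState()
state.activate(context={"reason": "awaiting approval"})

snapshot = state.to_dict()                     # plain dict, suitable for session storage
restored = InterruptState.from_dict(snapshot)
assert restored.activated and restored.context["reason"] == "awaiting approval"
```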
+ """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index d6367e9d9..3ea0097d8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -19,6 +19,7 @@ from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools +from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..types._events import ( EventLoopStopEvent, EventLoopThrottleEvent, @@ -27,15 +28,18 @@ ModelStopReason, StartEvent, StartEventLoopEvent, + StructuredOutputEvent, + ToolInterruptEvent, ToolResultMessageEvent, TypedEvent, ) -from ..types.content import Message +from ..types.content import Message, Messages from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, ModelThrottledException, + StructuredOutputException, ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse @@ -52,7 +56,31 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: +def _has_tool_use_in_latest_message(messages: "Messages") -> bool: + """Check if the latest message contains any ToolUse content blocks. + + Args: + messages: List of messages in the conversation. + + Returns: + True if the latest message contains at least one ToolUse content block, False otherwise. + """ + if len(messages) > 0: + latest_message = messages[-1] + content_blocks = latest_message.get("content", []) + + for content_block in content_blocks: + if "toolUse" in content_block: + return True + + return False + + +async def event_loop_cycle( + agent: "Agent", + invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, +) -> AsyncGenerator[TypedEvent, None]: """Execute a single cycle of the event loop. This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -73,6 +101,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> - request_state: State maintained across cycles - event_loop_cycle_id: Unique ID for this cycle - event_loop_cycle_span: Current tracing Span for this cycle + structured_output_context: Optional context for structured output management. Yields: Model and tool stream events. 
The last event is a tuple containing: @@ -86,6 +115,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> EventLoopException: If an error occurs during execution ContextWindowOverflowException: If the input is too large for the model """ + structured_output_context = structured_output_context or StructuredOutputContext() + # Initialize cycle state invocation_state["event_loop_cycle_id"] = uuid.uuid4() @@ -106,13 +137,24 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span - model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. + if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] + # Skip model invocation if the latest message contains ToolUse + elif _has_tool_use_in_latest_message(agent.messages): + stop_reason = "tool_use" + message = agent.messages[-1] + else: + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) try: if stop_reason == "max_tokens": @@ -131,7 +173,6 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) ) - # If the model is requesting to use tools if stop_reason == "tool_use": # Handle tool execution tool_events = _handle_tool_execution( @@ -142,6 +183,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_span=cycle_span, cycle_start_time=cycle_start_time, invocation_state=invocation_state, + tracer=tracer, + structured_output_context=structured_output_context, ) async for tool_event in tool_events: yield tool_event @@ -176,10 +219,33 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." 
+ ) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + agent._append_message( + {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} + ) + + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) + async for typed_event in events: + yield typed_event + return + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: +async def recurse_event_loop( + agent: "Agent", + invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, +) -> AsyncGenerator[TypedEvent, None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. @@ -187,7 +253,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - Args: agent: Agent for which the recursive call is being made. invocation_state: Arguments to pass through event_loop_cycle - + structured_output_context: Optional context for structured output management. Yields: Results from event_loop_cycle where the last result contains: @@ -205,7 +271,9 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - yield StartEvent() - events = event_loop_cycle(agent=agent, invocation_state=invocation_state) + events = event_loop_cycle( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) async for event in events: yield event @@ -218,6 +286,7 @@ async def _handle_model_execution( cycle_trace: Trace, invocation_state: dict[str, Any], tracer: Tracer, + structured_output_context: StructuredOutputContext, ) -> AsyncGenerator[TypedEvent, None]: """Handle model execution with retry logic for throttling exceptions. @@ -230,6 +299,7 @@ async def _handle_model_execution( cycle_trace: Trace object for the current event loop cycle. invocation_state: State maintained across cycles. tracer: Tracer instance for span management. + structured_output_context: Context for structured output management. Yields: Model stream events and throttle events during retries. @@ -258,10 +328,15 @@ async def _handle_model_execution( ) ) - tool_specs = agent.tool_registry.get_all_tool_specs() - + if structured_output_context.forced_mode: + tool_spec = structured_output_context.get_tool_spec() + tool_specs = [tool_spec] if tool_spec else [] + else: + tool_specs = agent.tool_registry.get_all_tool_specs() try: - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): + async for event in stream_messages( + agent.model, agent.system_prompt, agent.messages, tool_specs, structured_output_context.tool_choice + ): yield event stop_reason, message, usage, metrics = event["stop"] @@ -281,7 +356,7 @@ async def _handle_model_execution( message = recover_message_on_max_tokens_reached(message) if model_invoke_span: - tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! 
Break out of retry loop except Exception as e: @@ -345,6 +420,8 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], + tracer: Tracer, + structured_output_context: StructuredOutputContext, ) -> AsyncGenerator[TypedEvent, None]: """Handles the execution of tools requested by the model during an event loop cycle. @@ -356,6 +433,8 @@ async def _handle_tool_execution( cycle_span: Span object for tracing the cycle (type may vary). cycle_start_time: Start time of the current cycle. invocation_state: Additional keyword arguments, including request state. + tracer: Tracer instance for span management. + structured_output_context: Optional context for structured output management. Yields: Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple @@ -371,19 +450,52 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - if not tool_uses: - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - return + if agent._interrupt_state.activated: + tool_results.extend(agent._interrupt_state.context["tool_results"]) + + # Filter to only the interrupted tools when resuming from interrupt (tool uses without results) + tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results} + tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] + + interrupts = [] tool_events = agent.tool_executor._execute( - agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) async for tool_event in tool_events: + if isinstance(tool_event, ToolInterruptEvent): + interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"]) + yield tool_event - # Store parent cycle ID for the next cycle + structured_output_result = None + if structured_output_context.is_enabled: + if structured_output_result := structured_output_context.extract_result(tool_uses): + yield StructuredOutputEvent(structured_output=structured_output_result) + structured_output_context.stop_loop = True + invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] + if interrupts: + # Session state stored on AfterInvocationEvent. 
+ agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results}) + + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "interrupt", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + interrupts, + structured_output=structured_output_result, + ) + if cycle_span: + tracer.end_event_loop_cycle_span(span=cycle_span, message=message) + + return + + agent._interrupt_state.deactivate() + tool_result_message: Message = { "role": "user", "content": [{"toolResult": result} for result in tool_results], @@ -391,17 +503,25 @@ async def _handle_tool_execution( agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) + yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: - tracer = get_tracer() tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) - if invocation_state["request_state"].get("stop_event_loop", False): + if invocation_state["request_state"].get("stop_event_loop", False) or structured_output_context.stop_loop: agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + yield EventLoopStopEvent( + stop_reason, + message, + agent.event_loop_metrics, + invocation_state["request_state"], + structured_output=structured_output_result, + ) return - events = recurse_event_loop(agent=agent, invocation_state=invocation_state) + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) async for event in events: yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index f24bd2a76..6d847f8af 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,6 +2,7 @@ import json import logging +import time from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model @@ -267,31 +268,38 @@ def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> N state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | None = None) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: event: metadata. + time_to_first_byte_ms: time to get the first byte from the model in milliseconds Returns: The extracted usage metrics and latency. """ usage = Usage(**event["usage"]) metrics = Metrics(**event["metrics"]) + if time_to_first_byte_ms: + metrics["timeToFirstByteMs"] = time_to_first_byte_ms return usage, metrics -async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]: +async def process_stream( + chunks: AsyncIterable[StreamEvent], start_time: float | None = None +) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. + start_time: Time when the model request is initiated Yields: The reason for stopping, the constructed message, and the usage metrics. 
""" stop_reason: StopReason = "end_turn" + first_byte_time = None state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, @@ -303,10 +311,14 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T state["content"] = state["message"]["content"] usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics: Metrics = Metrics(latencyMs=0) + metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0) async for chunk in chunks: + # Track first byte time when we get first content + if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): + first_byte_time = time.time() yield ModelStreamChunkEvent(chunk=chunk) + if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: @@ -319,7 +331,10 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T elif "messageStop" in chunk: stop_reason = handle_message_stop(chunk["messageStop"]) elif "metadata" in chunk: - usage, metrics = extract_usage_metrics(chunk["metadata"]) + time_to_first_byte_ms = ( + int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None + ) + usage, metrics = extract_usage_metrics(chunk["metadata"], time_to_first_byte_ms) elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], state) @@ -331,6 +346,7 @@ async def stream_messages( system_prompt: Optional[str], messages: Messages, tool_specs: list[ToolSpec], + tool_choice: Optional[Any] = None, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -339,6 +355,7 @@ async def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_specs: The list of tool specs. + tool_choice: Optional tool choice constraint for forcing specific tool usage. Yields: The reason for stopping, the final message, and the usage metrics @@ -346,7 +363,8 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = remove_blank_messages_content_text(messages) - chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt) + start_time = time.time() + chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice) - async for event in process_stream(chunks): + async for event in process_stream(chunks, start_time): yield event diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index c40d0fcec..188c80c69 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -2,3 +2,8 @@ This module implements experimental features that are subject to change in future revisions without notice. """ + +from . import tools +from .agent_config import config_to_agent + +__all__ = ["config_to_agent", "tools"] diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py new file mode 100644 index 000000000..f65afb57d --- /dev/null +++ b/src/strands/experimental/agent_config.py @@ -0,0 +1,139 @@ +"""Experimental agent configuration utilities. + +This module provides utilities for creating agents from configuration files or dictionaries. + +Note: Configuration-based agent setup only works for tools that don't require code-based +instantiation. 
For tools that need constructor arguments or complex setup, use the +programmatic approach after creating the agent: + + agent = config_to_agent("config.json") + # Add tools that need code-based instantiation + agent.tool_registry.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) +""" + +import json +from pathlib import Path +from typing import Any + +import jsonschema +from jsonschema import ValidationError + +# JSON Schema for agent configuration +AGENT_CONFIG_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Agent Configuration", + "description": "Configuration schema for creating agents", + "type": "object", + "properties": { + "name": {"description": "Name of the agent", "type": ["string", "null"], "default": None}, + "model": { + "description": "The model ID to use for this agent. If not specified, uses the default model.", + "type": ["string", "null"], + "default": None, + }, + "prompt": { + "description": "The system prompt for the agent. Provides high level context to the agent.", + "type": ["string", "null"], + "default": None, + }, + "tools": { + "description": "List of tools the agent can use. Can be file paths, " + "Python module names, or @tool annotated functions in files.", + "type": "array", + "items": {"type": "string"}, + "default": [], + }, + }, + "additionalProperties": False, +} + +# Pre-compile validator for better performance +_VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) + + +def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: + """Create an Agent from a configuration file or dictionary. + + This function supports tools that can be loaded declaratively (file paths, module names, + or @tool annotated functions). For tools requiring code-based instantiation with constructor + arguments, add them programmatically after creating the agent: + + agent = config_to_agent("config.json") + agent.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) + + Args: + config: Either a file path (with optional file:// prefix) or a configuration dictionary + **kwargs: Additional keyword arguments to pass to the Agent constructor + + Returns: + Agent: A configured Agent instance + + Raises: + FileNotFoundError: If the configuration file doesn't exist + json.JSONDecodeError: If the configuration file contains invalid JSON + ValueError: If the configuration is invalid or tools cannot be loaded + + Examples: + Create agent from file: + >>> agent = config_to_agent("/path/to/config.json") + + Create agent from file with file:// prefix: + >>> agent = config_to_agent("file:///path/to/config.json") + + Create agent from dictionary: + >>> config = {"model": "anthropic.claude-3-5-sonnet-20241022-v2:0", "tools": ["calculator"]} + >>> agent = config_to_agent(config) + """ + # Parse configuration + if isinstance(config, str): + # Handle file path + file_path = config + + # Remove file:// prefix if present + if file_path.startswith("file://"): + file_path = file_path[7:] + + # Load JSON from file + config_path = Path(file_path) + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {file_path}") + + with open(config_path, "r") as f: + config_dict = json.load(f) + elif isinstance(config, dict): + config_dict = config.copy() + else: + raise ValueError("Config must be a file path string or dictionary") + + # Validate configuration against schema + try: + _VALIDATOR.validate(config_dict) + except ValidationError as e: + # Provide more detailed error message + error_path 
= " -> ".join(str(p) for p in e.absolute_path) if e.absolute_path else "root" + raise ValueError(f"Configuration validation error at {error_path}: {e.message}") from e + + # Prepare Agent constructor arguments + agent_kwargs = {} + + # Map configuration keys to Agent constructor parameters + config_mapping = { + "model": "model", + "prompt": "system_prompt", + "tools": "tools", + "name": "name", + } + + # Only include non-None values from config + for config_key, agent_param in config_mapping.items(): + if config_key in config_dict and config_dict[config_key] is not None: + agent_kwargs[agent_param] = config_dict[config_key] + + # Override with any additional kwargs provided + agent_kwargs.update(kwargs) + + # Import Agent at runtime to avoid circular imports + from ..agent import Agent + + # Create and return Agent + return Agent(**agent_kwargs) diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..d059d0da5 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -0,0 +1,20 @@ +"""Multi-agent hook events and utilities. + +Provides event classes for hooking into multi-agent orchestrator lifecycle. +""" + +from .events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py new file mode 100644 index 000000000..9e54296a4 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -0,0 +1,93 @@ +"""Multi-agent execution lifecycle events for hook system integration. + +These events are fired by orchestrators (Graph/Swarm) at key points so +hooks can persist, monitor, or debug execution. No intermediate state model +is used—hooks read from the orchestrator directly. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from ....hooks import BaseHookEvent + +if TYPE_CHECKING: + from ....multiagent.base import MultiAgentBase + + +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator initialized. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent): + """Event triggered before individual node execution starts. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node about to execute + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. 
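A hedged sketch of how these events might be consumed, assuming orchestrators accept hook providers the same way `Agent` does (how hooks are attached to a Graph/Swarm is not shown in this diff):

```python
from strands.experimental.hooks.multiagent import AfterNodeCallEvent, BeforeNodeCallEvent
from strands.hooks import HookProvider, HookRegistry

class NodeLogger(HookProvider):
    """Logs node lifecycle events emitted by a multi-agent orchestrator."""

    def register_hooks(self, registry: HookRegistry) -> None:
        registry.add_callback(BeforeNodeCallEvent, lambda event: print(f"starting node {event.node_id}"))
        registry.add_callback(AfterNodeCallEvent, lambda event: print(f"finished node {event.node_id}"))
```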
+ + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..ad693f8ac --- /dev/null +++ b/src/strands/experimental/tools/__init__.py @@ -0,0 +1,5 @@ +"""Experimental tools package.""" + +from .tool_provider import ToolProvider + +__all__ = ["ToolProvider"] diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py new file mode 100644 index 000000000..2c79ceafc --- /dev/null +++ b/src/strands/experimental/tools/tool_provider.py @@ -0,0 +1,52 @@ +"""Tool provider interface.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Sequence + +if TYPE_CHECKING: + from ...types.tools import AgentTool + + +class ToolProvider(ABC): + """Interface for providing tools with lifecycle management. + + Provides a way to load a collection of tools and clean them up + when done, with lifecycle managed by the agent. + """ + + @abstractmethod + async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: + """Load and return the tools in this provider. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of tools that are ready to use. + """ + pass + + @abstractmethod + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Args: + consumer_id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass + + @abstractmethod + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + This method must be idempotent - calling it multiple times with the same ID + should have no additional effect after the first call. + + Provider may clean up resources when no consumers remain. + + Args: + consumer_id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8f611e4e2..05be255f6 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -3,10 +3,14 @@ This module defines the events that are emitted as Agents run through the lifecycle of a request. 
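The `ToolProvider` ABC above only defines the contract, so here is a sketch of a concrete provider. The provider class, its single tool, and the consumer-tracking strategy are illustrative, assuming the `strands.experimental.tools` export added in this change.

```python
from typing import Any, Sequence

from strands import tool
from strands.experimental.tools import ToolProvider
from strands.types.tools import AgentTool


@tool
def ping(host: str) -> str:
    """Pretend to check connectivity to a host."""
    return f"pong from {host}"


class StaticToolProvider(ToolProvider):
    """Serves a fixed tool set and tracks consumers (illustrative)."""

    def __init__(self) -> None:
        self._consumers: set[Any] = set()

    async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]:
        return [ping]

    def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None:
        self._consumers.add(consumer_id)

    def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None:
        # Idempotent: discarding an unknown id has no effect.
        self._consumers.discard(consumer_id)
        if not self._consumers:
            self._release_resources()  # hypothetical cleanup hook

    def _release_resources(self) -> None:
        """Nothing to release for a static tool set."""
```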
""" +import uuid from dataclasses import dataclass from typing import Any, Optional +from typing_extensions import override + from ..types.content import Message +from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -84,7 +88,7 @@ class MessageAddedEvent(HookEvent): @dataclass -class BeforeToolCallEvent(HookEvent): +class BeforeToolCallEvent(HookEvent, _Interruptible): """Event triggered before a tool is invoked. This event is fired just before the agent executes a tool, allowing hook @@ -110,6 +114,18 @@ class BeforeToolCallEvent(HookEvent): def _can_write(self, name: str) -> bool: return name in ["cancel_tool", "selected_tool", "tool_use"] + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:before_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + @dataclass class AfterToolCallEvent(HookEvent): diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index b8e7f82ab..564be85cb 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,12 +7,17 @@ via hook provider objects. """ +import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from ..interrupt import Interrupt, InterruptException + if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) + @dataclass class BaseHookEvent: @@ -184,7 +189,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: + def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: """Invoke all registered callbacks for the given event. This method finds all callbacks registered for the event's type and @@ -192,11 +197,16 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: callbacks are invoked in reverse registration order. Any exceptions raised by callback functions will propagate to the caller. + Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows. + Args: event: The event to dispatch to registered callbacks. Returns: - The event dispatched to registered callbacks. + The event dispatched to registered callbacks and any interrupts raised by the user. + + Raises: + ValueError: If interrupt name is used more than once. Example: ```python @@ -204,10 +214,22 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: registry.invoke_callbacks(event) ``` """ + interrupts: dict[str, Interrupt] = {} + for callback in self.get_callbacks_for(event): - callback(event) + try: + callback(event) + except InterruptException as exception: + interrupt = exception.interrupt + if interrupt.name in interrupts: + message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + logger.error(message) + raise ValueError(message) from exception + + # Each callback is allowed to raise their own interrupt. + interrupts[interrupt.name] = interrupt - return event + return event, list(interrupts.values()) def has_callbacks(self) -> bool: """Check if the registry has any registered callbacks. 
diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py new file mode 100644 index 000000000..f0ed52389 --- /dev/null +++ b/src/strands/interrupt.py @@ -0,0 +1,33 @@ +"""Human-in-the-loop interrupt system for agent workflows.""" + +from dataclasses import asdict, dataclass +from typing import Any + + +@dataclass +class Interrupt: + """Represents an interrupt that can pause agent execution for human-in-the-loop workflows. + + Attributes: + id: Unique identifier. + name: User defined name. + reason: User provided reason for raising the interrupt. + response: Human response provided when resuming the agent after an interrupt. + """ + + id: str + name: str + reason: Any = None + response: Any = None + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + +class InterruptException(Exception): + """Exception raised when human input is required.""" + + def __init__(self, interrupt: Interrupt) -> None: + """Set the interrupt.""" + self.interrupt = interrupt diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index a95b0d027..48351da19 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -14,7 +14,7 @@ from typing_extensions import Required, Unpack, override from ..event_loop.streaming import process_stream -from ..tools import convert_pydantic_to_tool_spec +from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c6a500597..576f7c43e 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -16,8 +16,10 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from .._exception_notes import add_exception_note from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec +from ..tools._tool_helpers import noop_tool from ..types.content import ContentBlock, Messages from ..types.exceptions import ( ContextWindowOverflowException, @@ -203,6 +205,12 @@ def format_request( Returns: A Bedrock converse stream request. 
""" + if not tool_specs: + has_tool_content = any( + any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages + ) + if has_tool_content: + tool_specs = [noop_tool.tool_spec] return { "modelId": self.config["model_id"], "messages": self._format_bedrock_messages(messages), @@ -707,7 +715,10 @@ def _stream( except ClientError as e: error_message = str(e) - if e.response["Error"]["Code"] == "ThrottlingException": + if ( + e.response["Error"]["Code"] == "ThrottlingException" + or e.response["Error"]["Code"] == "throttlingException" + ): raise ModelThrottledException(error_message) from e if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): @@ -716,29 +727,29 @@ def _stream( region = self.client.meta.region_name - # add_note added in Python 3.11 - if hasattr(e, "add_note"): - # Aid in debugging by adding more information - e.add_note(f"└ Bedrock region: {region}") - e.add_note(f"└ Model id: {self.config.get('model_id')}") - - if ( - e.response["Error"]["Code"] == "AccessDeniedException" - and "You don't have access to the model" in error_message - ): - e.add_note( - "└ For more information see " - "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue" - ) - - if ( - e.response["Error"]["Code"] == "ValidationException" - and "with on-demand throughput isn’t supported" in error_message - ): - e.add_note( - "└ For more information see " - "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported" - ) + # Aid in debugging by adding more information + add_exception_note(e, f"└ Bedrock region: {region}") + add_exception_note(e, f"└ Model id: {self.config.get('model_id')}") + + if ( + e.response["Error"]["Code"] == "AccessDeniedException" + and "You don't have access to the model" in error_message + ): + add_exception_note( + e, + "└ For more information see " + "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue", + ) + + if ( + e.response["Error"]["Code"] == "ValidationException" + and "with on-demand throughput isn’t supported" in error_message + ): + add_exception_note( + e, + "└ For more information see " + "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported", + ) raise e diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 005eed3df..f1cbf01a2 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -8,11 +8,14 @@ from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast import litellm +from litellm.exceptions import ContextWindowExceededError from litellm.utils import supports_response_schema from pydantic import BaseModel from typing_extensions import Unpack, override +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys @@ -108,6 +111,26 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] return super().format_request_message_content(content) + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: + """Handle switching to a new content stream. 
+ + Args: + data_type: The next content data type. + prev_data_type: The previous content data type. + + Returns: + Tuple containing: + - Stop block for previous content and the start block for the next content. + - Next content data type. + """ + chunks = [] + if data_type != prev_data_type: + if prev_data_type is not None: + chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type})) + chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": data_type})) + + return chunks, data_type + @override async def stream( self, @@ -135,13 +158,17 @@ async def stream( logger.debug("request=<%s>", request) logger.debug("invoking model") - response = await litellm.acompletion(**self.client_args, **request) + try: + response = await litellm.acompletion(**self.client_args, **request) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow") + raise ContextWindowOverflowException(e) from e logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) - yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None async for event in response: # Defensive: skip events with empty or missing choices @@ -149,28 +176,36 @@ async def stream( continue choice = event.choices[0] - if choice.delta.content: - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - ) - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield chunk + yield self.format_chunk( { "chunk_type": "content_delta", - "data_type": "reasoning_content", + "data_type": data_type, "data": choice.delta.reasoning_content, } ) + if choice.delta.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield chunk + + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content} + ) + for tool_call in choice.delta.tool_calls or []: tool_calls.setdefault(tool_call.index, []).append(tool_call) if choice.finish_reason: + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) break - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - for tool_deltas in tool_calls.values(): yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) @@ -196,6 +231,10 @@ async def structured_output( ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. + Some models do not support native structured output via response_format. + In cases of proxies, we may not have a way to determine support, so we + fallback to using tool calling to achieve structured output. + Args: output_model: The output model to use for the agent. prompt: The prompt messages to use for the agent. @@ -205,9 +244,19 @@ async def structured_output( Yields: Model events with the last being the structured output. 
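With the fallback below, callers no longer need to check `supports_response_schema` themselves; the same call works whether the provider honors `response_format` natively or only supports tool calling. A usage sketch with an illustrative model id and output model:

```python
import asyncio

from pydantic import BaseModel

from strands.models.litellm import LiteLLMModel


class WeatherReport(BaseModel):
    city: str
    temperature_c: float


async def main() -> None:
    model = LiteLLMModel(model_id="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")  # illustrative
    prompt = [{"role": "user", "content": [{"text": "It is 12C in Seattle. Extract the weather."}]}]

    async for event in model.structured_output(WeatherReport, prompt):
        if "output" in event:
            print(event["output"].city, event["output"].temperature_c)


asyncio.run(main())
```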
""" - if not supports_response_schema(self.get_config()["model_id"]): - raise ValueError("Model does not support response_format") - + if supports_response_schema(self.get_config()["model_id"]): + logger.debug("structuring output using response schema") + result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt) + else: + logger.debug("model does not support response schema, structuring output using tool approach") + result = await self._structured_output_using_tool(output_model, prompt, system_prompt) + + yield {"output": result} + + async def _structured_output_using_response_schema( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + ) -> T: + """Get structured output using native response_format support.""" response = await litellm.acompletion( **self.client_args, model=self.get_config()["model_id"], @@ -217,21 +266,47 @@ async def structured_output( if len(response.choices) > 1: raise ValueError("Multiple choices found in the response.") + if not response.choices or response.choices[0].finish_reason != "tool_calls": + raise ValueError("No tool_calls found in response") + + choice = response.choices[0] + try: + # Parse the message content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow in structured_output") + raise ContextWindowOverflowException(e) from e + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + async def _structured_output_using_tool( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + ) -> T: + """Get structured output using tool calling fallback.""" + tool_spec = convert_pydantic_to_tool_spec(output_model) + request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}})) + args = {**self.client_args, **request, "stream": False} + response = await litellm.acompletion(**args) - # Find the first choice with tool_calls - for choice in response.choices: - if choice.finish_reason == "tool_calls": - try: - # Parse the tool call content as JSON - tool_call_data = json.loads(choice.message.content) - # Instantiate the output model with the parsed data - yield {"output": output_model(**tool_call_data)} - return - except (json.JSONDecodeError, TypeError, ValueError) as e: - raise ValueError(f"Failed to parse or load content into model: {e}") from e - - # If no tool_calls found, raise an error - raise ValueError("No tool_calls found in response") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + if not response.choices or response.choices[0].finish_reason != "tool_calls": + raise ValueError("No tool_calls found in response") + + choice = response.choices[0] + try: + # Parse the tool call content as JSON + tool_call = choice.message.tool_calls[0] + tool_call_data = json.loads(tool_call.function.arguments) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow in structured_output") + raise ContextWindowOverflowException(e) from e + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: 
{e}") from e def _apply_proxy_prefix(self) -> None: """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True. diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index d1447732e..25b3ca7ce 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -4,7 +4,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union import boto3 from botocore.config import Config as BotocoreConfig @@ -151,8 +151,8 @@ def __init__( validate_config_keys(payload_config, self.SageMakerAIPayloadSchema) payload_config.setdefault("stream", True) payload_config.setdefault("tool_results_as_user_messages", False) - self.endpoint_config = dict(endpoint_config) - self.payload_config = dict(payload_config) + self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config) + self.payload_config = self.SageMakerAIPayloadSchema(**payload_config) logger.debug( "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config ) @@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i Returns: The Amazon SageMaker model configuration. """ - return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) + return self.endpoint_config @override def format_request( @@ -238,6 +238,10 @@ def format_request( }, } + payload_additional_args = self.payload_config.get("additional_args") + if payload_additional_args: + payload.update(payload_additional_args) + # Remove tools and tool_choice if tools = [] if not payload["tools"]: payload.pop("tools") @@ -273,16 +277,20 @@ def format_request( } # Add optional SageMaker parameters if provided - if self.endpoint_config.get("inference_component_name"): - request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] - if self.endpoint_config.get("target_model"): - request["TargetModel"] = self.endpoint_config["target_model"] - if self.endpoint_config.get("target_variant"): - request["TargetVariant"] = self.endpoint_config["target_variant"] - - # Add additional args if provided - if self.endpoint_config.get("additional_args"): - request.update(self.endpoint_config["additional_args"].__dict__) + inf_component_name = self.endpoint_config.get("inference_component_name") + if inf_component_name: + request["InferenceComponentName"] = inf_component_name + target_model = self.endpoint_config.get("target_model") + if target_model: + request["TargetModel"] = target_model + target_variant = self.endpoint_config.get("target_variant") + if target_variant: + request["TargetVariant"] = target_variant + + # Add additional request args if provided + additional_args = self.endpoint_config.get("additional_args") + if additional_args: + request.update(additional_args) return request diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 03d7de9b4..1628a8a9d 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,17 +3,20 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). 
""" -import asyncio +import logging +import warnings from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union +from .._async import run_async from ..agent import AgentResult from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage +logger = logging.getLogger(__name__) + class Status(Enum): """Execution status for both graphs and nodes.""" @@ -58,6 +61,54 @@ def get_agent_results(self) -> list[AgentResult]: flattened.extend(nested_node_result.get_agent_results()) return flattened + def to_dict(self) -> dict[str, Any]: + """Convert NodeResult to JSON-serializable dict, ignoring state field.""" + if isinstance(self.result, Exception): + result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} + elif isinstance(self.result, AgentResult): + result_data = self.result.to_dict() + else: + # MultiAgentResult case + result_data = self.result.to_dict() + + return { + "result": result_data, + "execution_time": self.execution_time, + "status": self.status.value, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeResult": + """Rehydrate a NodeResult from persisted JSON.""" + if "result" not in data: + raise TypeError("NodeResult.from_dict: missing 'result'") + raw = data["result"] + + result: Union[AgentResult, "MultiAgentResult", Exception] + if isinstance(raw, dict) and raw.get("type") == "agent_result": + result = AgentResult.from_dict(raw) + elif isinstance(raw, dict) and raw.get("type") == "exception": + result = Exception(str(raw.get("message", "node failed"))) + elif isinstance(raw, dict) and raw.get("type") == "multiagent_result": + result = MultiAgentResult.from_dict(raw) + else: + raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") + + usage = _parse_usage(data.get("accumulated_usage", {})) + metrics = _parse_metrics(data.get("accumulated_metrics", {})) + + return cls( + result=result, + execution_time=int(data.get("execution_time", 0)), + status=Status(data.get("status", "pending")), + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + ) + @dataclass class MultiAgentResult: @@ -75,6 +126,38 @@ class MultiAgentResult: execution_count: int = 0 execution_time: int = 0 + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": + """Rehydrate a MultiAgentResult from persisted JSON.""" + if data.get("type") != "multiagent_result": + raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}") + + results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} + usage = _parse_usage(data.get("accumulated_usage", {})) + metrics = _parse_metrics(data.get("accumulated_metrics", {})) + + multiagent_result = cls( + status=Status(data.get("status", Status.PENDING.value)), + results=results, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + execution_time=int(data.get("execution_time", 0)), + ) + return multiagent_result + + def to_dict(self) -> dict[str, Any]: + """Convert MultiAgentResult to JSON-serializable dict.""" + return { + "type": "multiagent_result", + "status": self.status.value, + "results": {k: v.to_dict() for k, v in self.results.items()}, + 
"accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + "execution_time": self.execution_time, + } + class MultiAgentBase(ABC): """Base class for multi-agent helpers. @@ -111,9 +194,39 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) + if kwargs: + invocation_state.update(kwargs) + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + + return run_async(lambda: self.invoke_async(task, invocation_state)) + + def serialize_state(self) -> dict[str, Any]: + """Return a JSON-serializable snapshot of the orchestrator state.""" + raise NotImplementedError + + def deserialize_state(self, payload: dict[str, Any]) -> None: + """Restore orchestrator state from a session dict.""" + raise NotImplementedError + + +# Private helper function to avoid duplicate code + + +def _parse_usage(usage_data: dict[str, Any]) -> Usage: + """Parse Usage from dict data.""" + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + return usage + - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() +def _parse_metrics(metrics_data: dict[str, Any]) -> Metrics: + """Parse Metrics from dict data.""" + return Metrics(latencyMs=metrics_data.get("latencyMs", 0)) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 738dc4d4c..0aaa6c7a3 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,12 +18,12 @@ import copy import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api +from .._async import run_async from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer @@ -399,12 +399,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -572,11 +567,19 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) elif isinstance(node.executor, Agent): if self.node_timeout is not None: agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, **invocation_state), + node.executor.invoke_async(node_input, invocation_state=invocation_state), timeout=self.node_timeout, ) else: - agent_response = await node.executor.invoke_async(node_input, **invocation_state) + agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + + if agent_response.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted 
tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError( + "user raised interrupt from agent | interrupts are not yet supported in graphs" + ) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 620fa5e24..3d9dc00c8 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,13 +17,14 @@ import json import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Tuple from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from .._async import run_async +from ..agent import Agent +from ..agent.agent_result import AgentResult from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool @@ -254,12 +255,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -635,8 +631,13 @@ async def _execute_node( # Execute node result = None node.reset_executor_state() - # Unpacking since this is the agent class. Other executors should not unpack - result = await node.executor.invoke_async(node_input, **invocation_state) + result = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + + if result.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in swarms") execution_time = round((time.time() - start_time) * 1000) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 93adeb7f2..491f7ad60 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -1,6 +1,5 @@ """File-based session manager for local filesystem storage.""" -import asyncio import json import logging import os @@ -232,20 +231,11 @@ def list_messages( else: message_files = message_files[offset:] - return asyncio.run(self._load_messages_concurrently(messages_dir, message_files)) - - async def _load_messages_concurrently(self, messages_dir: str, message_files: list[str]) -> list[SessionMessage]: - """Load multiple message files concurrently using async.""" - if not message_files: - return [] - - async def load_message(filename: str) -> SessionMessage: + # Load only the message files + messages: list[SessionMessage] = [] + for filename in message_files: file_path = os.path.join(messages_dir, filename) - loop = asyncio.get_event_loop() - message_data = await loop.run_in_executor(None, self._read_file, file_path) - return SessionMessage.from_dict(message_data) - - tasks = [load_message(filename) for filename in message_files] - messages = await asyncio.gather(*tasks) + message_data = self._read_file(file_path) + messages.append(SessionMessage.from_dict(message_data)) return messages diff --git a/src/strands/session/repository_session_manager.py 
b/src/strands/session/repository_session_manager.py index 75058b251..e5075de93 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -132,6 +132,8 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: ) agent.state = AgentState(session_agent.state) + session_agent.initialize_internal_state(agent) + # Restore the conversation manager to its previous state, and get the optional prepend messages prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 1f6ffe7f1..c6ce28d80 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -1,6 +1,5 @@ """S3-based session manager for cloud storage.""" -import asyncio import json import logging from typing import Any, Dict, List, Optional, cast @@ -284,23 +283,14 @@ def list_messages( else: message_keys = message_keys[offset:] - # Load message objects concurrently using async - return asyncio.run(self._load_messages_concurrently(message_keys)) + # Load only the required message objects + messages: List[SessionMessage] = [] + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + + return messages except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e - - async def _load_messages_concurrently(self, message_keys: List[str]) -> List[SessionMessage]: - """Load multiple message objects concurrently using async.""" - if not message_keys: - return [] - - async def load_message(key: str) -> Optional[SessionMessage]: - loop = asyncio.get_event_loop() - message_data = await loop.run_in_executor(None, self._read_s3_object, key) - return SessionMessage.from_dict(message_data) if message_data else None - - tasks = [load_message(key) for key in message_keys] - loaded_messages = await asyncio.gather(*tasks) - - return [msg for msg in loaded_messages if msg is not None] diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 883273f64..abfbbffae 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -286,6 +286,8 @@ def update_metrics(self, metrics: Metrics) -> None: metrics: The metrics data to add to the accumulated totals. 
""" self._metrics_client.event_loop_latency.record(metrics["latencyMs"]) + if metrics.get("timeToFirstByteMs") is not None: + self._metrics_client.model_time_to_first_token.record(metrics["timeToFirstByteMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] def get_summary(self) -> Dict[str, Any]: @@ -448,7 +450,7 @@ class MetricsClient: event_loop_output_tokens: Histogram event_loop_cache_read_input_tokens: Histogram event_loop_cache_write_input_tokens: Histogram - + model_time_to_first_token: Histogram tool_call_count: Counter tool_success_count: Counter tool_error_count: Counter @@ -507,3 +509,6 @@ def create_instruments(self) -> None: self.event_loop_cache_write_input_tokens = self.meter.create_histogram( name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token" ) + self.model_time_to_first_token = self.meter.create_histogram( + name=constants.STRANDS_MODEL_TIME_TO_FIRST_TOKEN, unit="ms" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py index f8fac34da..2e1047581 100644 --- a/src/strands/telemetry/metrics_constants.py +++ b/src/strands/telemetry/metrics_constants.py @@ -15,3 +15,4 @@ STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens" STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens" +STRANDS_MODEL_TIME_TO_FIRST_TOKEN = "strands.model.time_to_first_token" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 7cd2d0e7b..9cefc6911 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -16,7 +16,7 @@ from ..agent.agent_result import AgentResult from ..types.content import ContentBlock, Message, Messages -from ..types.streaming import StopReason, Usage +from ..types.streaming import Metrics, StopReason, Usage from ..types.tools import ToolResult, ToolUse from ..types.traces import Attributes, AttributeValue @@ -153,6 +153,28 @@ def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> for key, value in attributes.items(): span.set_attribute(key, value) + def _add_optional_usage_and_metrics_attributes( + self, attributes: Dict[str, AttributeValue], usage: Usage, metrics: Metrics + ) -> None: + """Add optional usage and metrics attributes if they have values. 
+ + Args: + attributes: Dictionary to add attributes to + usage: Token usage information from the model call + metrics: Metrics from the model call + """ + if "cacheReadInputTokens" in usage: + attributes["gen_ai.usage.cache_read_input_tokens"] = usage["cacheReadInputTokens"] + + if "cacheWriteInputTokens" in usage: + attributes["gen_ai.usage.cache_write_input_tokens"] = usage["cacheWriteInputTokens"] + + if metrics.get("timeToFirstByteMs", 0) > 0: + attributes["gen_ai.server.time_to_first_token"] = metrics["timeToFirstByteMs"] + + if metrics.get("latencyMs", 0) > 0: + attributes["gen_ai.server.request.duration"] = metrics["latencyMs"] + def _end_span( self, span: Span, @@ -271,13 +293,19 @@ def start_model_invoke_span( # Add additional kwargs as attributes attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) - span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) + span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL) self._add_event_messages(span, messages) return span def end_model_invoke_span( - self, span: Span, message: Message, usage: Usage, stop_reason: StopReason, error: Optional[Exception] = None + self, + span: Span, + message: Message, + usage: Usage, + metrics: Metrics, + stop_reason: StopReason, + error: Optional[Exception] = None, ) -> None: """End a model invocation span with results and metrics. @@ -285,6 +313,7 @@ def end_model_invoke_span( span: The span to end. message: The message response from the model. usage: Token usage information from the model call. + metrics: Metrics from the model call. stop_reason (StopReason): The reason the model stopped generating. error: Optional exception if the model call failed. 
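Call sites of `end_model_invoke_span` now pass the model `Metrics` alongside `Usage`, which feeds the new `gen_ai.server.*` attributes. A minimal sketch of the updated signature; the span here is created directly only for illustration (the event loop obtains it from `start_model_invoke_span`), and the token counts are made up.

```python
from opentelemetry import trace as trace_api

from strands.telemetry.tracer import get_tracer
from strands.types.streaming import Metrics, Usage

tracer = get_tracer()
span = trace_api.get_tracer(__name__).start_span("chat")  # stand-in span for illustration

tracer.end_model_invoke_span(
    span,
    message={"role": "assistant", "content": [{"text": "done"}]},
    usage=Usage(inputTokens=120, outputTokens=48, totalTokens=168),
    metrics=Metrics(latencyMs=950, timeToFirstByteMs=180),  # timeToFirstByteMs is optional
    stop_reason="end_turn",
)
```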
""" @@ -294,10 +323,11 @@ def end_model_invoke_span( "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.output_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], - "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0), - "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), } + # Add optional attributes if they have values + self._add_optional_usage_and_metrics_attributes(attributes, usage, metrics) + if self.use_latest_genai_conventions: self._add_event( span, @@ -307,7 +337,7 @@ def end_model_invoke_span( [ { "role": message["role"], - "parts": [{"type": "text", "content": message["content"]}], + "parts": self._map_content_blocks_to_otel_parts(message["content"]), "finish_reason": str(stop_reason), } ] @@ -362,7 +392,7 @@ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None "type": "tool_call", "name": tool["name"], "id": tool["toolUseId"], - "arguments": [{"content": tool["input"]}], + "arguments": tool["input"], } ], } @@ -417,7 +447,7 @@ def end_tool_call_span( { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "result": tool_result.get("content"), + "response": tool_result.get("content"), } ], } @@ -504,7 +534,7 @@ def end_event_loop_cycle_span( [ { "role": tool_result_message["role"], - "parts": [{"type": "text", "content": tool_result_message["content"]}], + "parts": self._map_content_blocks_to_otel_parts(tool_result_message["content"]), } ] ) @@ -558,7 +588,7 @@ def start_agent_span( attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) span = self._start_span( - f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT + f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL ) self._add_event_messages(span, messages) @@ -634,19 +664,23 @@ def start_multiagent_span( ) span = self._start_span(operation, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) - content = serialize(task) if isinstance(task, list) else task if self.use_latest_genai_conventions: + parts: list[dict[str, Any]] = [] + if isinstance(task, list): + parts = self._map_content_blocks_to_otel_parts(task) + else: + parts = [{"type": "text", "content": task}] self._add_event( span, "gen_ai.client.inference.operation.details", - {"gen_ai.input.messages": serialize([{"role": "user", "parts": [{"type": "text", "content": task}]}])}, + {"gen_ai.input.messages": serialize([{"role": "user", "parts": parts}])}, ) else: self._add_event( span, "gen_ai.user.message", - event_attributes={"content": content}, + event_attributes={"content": serialize(task) if isinstance(task, list) else task}, ) return span @@ -718,7 +752,7 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: input_messages: list = [] for message in messages: input_messages.append( - {"role": message["role"], "parts": [{"type": "text", "content": message["content"]}]} + {"role": message["role"], "parts": self._map_content_blocks_to_otel_parts(message["content"])} ) self._add_event( span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize(input_messages)} @@ -731,6 +765,41 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: {"content": serialize(message["content"])}, ) + def _map_content_blocks_to_otel_parts(self, content_blocks: list[ContentBlock]) -> list[dict[str, Any]]: + """Map ContentBlock objects to OpenTelemetry parts 
format.""" + parts: list[dict[str, Any]] = [] + + for block in content_blocks: + if "text" in block: + # Standard TextPart + parts.append({"type": "text", "content": block["text"]}) + elif "toolUse" in block: + # Standard ToolCallRequestPart + tool_use = block["toolUse"] + parts.append( + { + "type": "tool_call", + "name": tool_use["name"], + "id": tool_use["toolUseId"], + "arguments": tool_use["input"], + } + ) + elif "toolResult" in block: + # Standard ToolCallResponsePart + tool_result = block["toolResult"] + parts.append( + { + "type": "tool_call_response", + "id": tool_result["toolUseId"], + "response": tool_result["content"], + } + ) + else: + # For all other ContentBlock types, use the key as type and value as content + for key, value in block.items(): + parts.append({"type": key, "content": value}) + return parts + # Singleton instance for global access _tracer_instance = None diff --git a/src/strands/tools/_tool_helpers.py b/src/strands/tools/_tool_helpers.py new file mode 100644 index 000000000..d640f23b8 --- /dev/null +++ b/src/strands/tools/_tool_helpers.py @@ -0,0 +1,15 @@ +"""Helpers for tools.""" + +from strands.tools.decorator import tool + + +# https://github.com/strands-agents/sdk-python/issues/998 +@tool(name="noop", description="This is a fake tool that MUST be completely ignored.") +def noop_tool() -> None: + """No-op tool to satisfy tool spec requirement when tool messages are present. + + Some model providers (e.g., Bedrock) will return an error response if tool uses and tool results are present in + messages without any tool specs configured. Consequently, if the summarization agent has no registered tools, + summarization will fail. As a workaround, we register the no-op tool. + """ + pass diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 99aa7e372..5c49f4b58 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -62,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types._events import ToolResultEvent, ToolStreamEvent +from ..interrupt import InterruptException +from ..types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -99,6 +100,8 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - self.type_hints = get_type_hints(func) self._context_param = context_param + self._validate_signature() + # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) @@ -111,6 +114,20 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _validate_signature(self) -> None: + """Verify that ToolContext is used correctly in the function signature.""" + for param in self.signature.parameters.values(): + if param.annotation is ToolContext: + if self._context_param is None: + raise ValueError("@tool(context) must be set if passing in ToolContext param") + + if param.name != self._context_param: + raise ValueError( + f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'" + ) + # Found the parameter, no need to check further + break + def _create_input_model(self) -> Type[BaseModel]: """Create a Pydantic model 
from function signature for input validation. @@ -477,6 +494,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore yield self._wrap_tool_result(tool_use_id, result) + except InterruptException as e: + yield ToolInterruptEvent(tool_use, [e.interrupt]) + return + except ValueError as e: # Special handling for validation errors error_msg = str(e) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index f78861f81..81a594488 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -13,10 +13,11 @@ from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace -from ...telemetry.tracer import get_tracer -from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent +from ...telemetry.tracer import get_tracer, serialize +from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse +from ..structured_output._structured_output_context import StructuredOutputContext if TYPE_CHECKING: # pragma: no cover from ...agent import Agent @@ -33,6 +34,7 @@ async def _stream( tool_use: ToolUse, tool_results: list[ToolResult], invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Stream tool events. @@ -43,12 +45,14 @@ async def _stream( - Before/after hook execution - Tracing and metrics collection - Error handling and recovery + - Interrupt handling for human-in-the-loop workflows Args: agent: The agent for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output management. **kwargs: Additional keyword arguments for future extensibility. 
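With the `InterruptException` handling added to the decorated tool's `stream` method above, a `@tool` function can also raise the interrupt itself and have it surface as a `ToolInterruptEvent`. A sketch; the tool, its threshold, and the interrupt id are illustrative.

```python
from strands import tool
from strands.interrupt import Interrupt, InterruptException


@tool
def wire_transfer(amount: float, account: str) -> str:
    """Illustrative tool that defers large transfers to a human."""
    if amount > 10_000:
        raise InterruptException(
            Interrupt(id="v1:wire_transfer:approval", name="approval", reason={"amount": amount})
        )
    return f"Transferred {amount} to {account}"
```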
Yields: @@ -56,9 +60,18 @@ async def _stream( """ logger.debug("tool_use=<%s> | streaming", tool_use) tool_name = tool_use["name"] + structured_output_context = structured_output_context or StructuredOutputContext() tool_info = agent.tool_registry.dynamic_tools.get(tool_name) tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) + tool_spec = tool_func.tool_spec if tool_func is not None else None + + current_span = trace_api.get_current_span() + if current_span and tool_spec is not None: + current_span.set_attribute("gen_ai.tool.description", tool_spec["description"]) + input_schema = tool_spec["inputSchema"] + if "json" in input_schema: + current_span.set_attribute("gen_ai.tool.json_schema", serialize(input_schema["json"])) invocation_state.update( { @@ -72,7 +85,7 @@ async def _stream( } ) - before_event = agent.hooks.invoke_callbacks( + before_event, interrupts = agent.hooks.invoke_callbacks( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, @@ -81,6 +94,10 @@ async def _stream( ) ) + if interrupts: + yield ToolInterruptEvent(tool_use, interrupts) + return + if before_event.cancel_tool: cancel_message = ( before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" @@ -92,7 +109,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, tool_use=tool_use, @@ -130,7 +147,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -142,7 +159,8 @@ async def _stream( yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return - + if structured_output_context.is_enabled: + kwargs["structured_output_context"] = structured_output_context async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. @@ -150,18 +168,23 @@ async def _stream( # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in # ToolStreamEvent and the last event is just the result. + if isinstance(event, ToolInterruptEvent): + yield event + return + if isinstance(event, ToolResultEvent): # below the last "event" must point to the tool_result event = event.tool_result break - elif isinstance(event, ToolStreamEvent): + + if isinstance(event, ToolStreamEvent): yield event else: yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -181,7 +204,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -202,6 +225,7 @@ async def _stream_with_trace( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Execute tool with tracing and metrics collection. 
@@ -213,12 +237,14 @@ async def _stream_with_trace( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output management. **kwargs: Additional keyword arguments for future extensibility. Yields: Tool events with the last being the tool result. """ tool_name = tool_use["name"] + structured_output_context = structured_output_context or StructuredOutputContext() tracer = get_tracer() @@ -227,9 +253,15 @@ async def _stream_with_trace( tool_start_time = time.time() with trace_api.use_span(tool_call_span): - async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): + async for event in ToolExecutor._stream( + agent, tool_use, tool_results, invocation_state, structured_output_context, **kwargs + ): yield event + if isinstance(event, ToolInterruptEvent): + tracer.end_tool_call_span(tool_call_span, tool_result=None) + return + result_event = cast(ToolResultEvent, event) result = result_event.tool_result @@ -251,6 +283,7 @@ def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: "StructuredOutputContext", ) -> AsyncGenerator[TypedEvent, None]: """Execute the given tools according to this executor's strategy. @@ -261,6 +294,7 @@ def _execute( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output management. Yields: Events from the tool execution stream. diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 8ef8a8b65..bf78d6f6a 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ..structured_output._structured_output_context import StructuredOutputContext class ConcurrentToolExecutor(ToolExecutor): @@ -26,6 +27,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: "StructuredOutputContext", ) -> AsyncGenerator[TypedEvent, None]: """Execute tools concurrently. @@ -36,6 +38,7 @@ async def _execute( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output handling. Yields: Events from the tool execution stream. @@ -57,6 +60,7 @@ async def _execute( task_queue, task_events[task_id], stop_event, + structured_output_context, ) ) for task_id, tool_use in enumerate(tool_uses) @@ -84,6 +88,7 @@ async def _task( task_queue: asyncio.Queue, task_event: asyncio.Event, stop_event: object, + structured_output_context: "StructuredOutputContext", ) -> None: """Execute a single tool and put results in the task queue. @@ -98,10 +103,11 @@ async def _task( task_queue: Queue to put tool events into. task_event: Event to signal when task can continue. stop_event: Sentinel object to signal task completion. + structured_output_context: Context for structured output handling. 
""" try: events = ToolExecutor._stream_with_trace( - agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) async for event in events: task_queue.put_nowait((task_id, event)) diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 60e5c7fa7..74024455a 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -5,12 +5,13 @@ from typing_extensions import override from ...telemetry.metrics import Trace -from ...types._events import TypedEvent +from ...types._events import ToolInterruptEvent, TypedEvent from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ..structured_output._structured_output_context import StructuredOutputContext class SequentialToolExecutor(ToolExecutor): @@ -25,9 +26,12 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: "StructuredOutputContext", ) -> AsyncGenerator[TypedEvent, None]: """Execute tools sequentially. + Breaks early if an interrupt is raised by the user. + Args: agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. @@ -35,13 +39,22 @@ async def _execute( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output handling. Yields: Events from the tool execution stream. """ + interrupted = False + for tool_use in tool_uses: events = ToolExecutor._stream_with_trace( - agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) async for event in events: + if isinstance(event, ToolInterruptEvent): + interrupted = True + yield event + + if interrupted: + break diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 5935077db..31e8dc788 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -5,7 +5,10 @@ import os import sys import warnings +from importlib.machinery import ModuleSpec from pathlib import Path +from posixpath import expanduser +from types import ModuleType from typing import List, cast from ..types.tools import AgentTool @@ -15,16 +18,151 @@ logger = logging.getLogger(__name__) +def load_tool_from_string(tool_string: str) -> List[AgentTool]: + """Load tools follows strands supported input string formats. + + This function can load a tool based on a string in the following ways: + 1. Local file path to a module based tool: `./path/to/module/tool.py` + 2. Module import path + 2.1. Path to a module based tool: `strands_tools.file_read` + 2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool` + 2.3. 
Path to a module and a specific function: `tests.fixtures.say_tool:say` + """ + # Case 1: Local file path to a tool + # Ex: ./path/to/my_cool_tool.py + tool_path = expanduser(tool_string) + if os.path.exists(tool_path): + return load_tools_from_file_path(tool_path) + + # Case 2: Module import path + # Ex: test.fixtures.say_tool:say (Load specific @tool decorated function) + # Ex: strands_tools.file_read (Load all @tool decorated functions, or module tool) + return load_tools_from_module_path(tool_string) + + +def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: + """Load module from specified path, and then load tools from that module. + + This function attempts to load the passed in path as a python module, and if it succeeds, + then it tries to import strands tool(s) from that module. + """ + abs_path = str(Path(tool_path).resolve()) + logger.debug("tool_path=<%s> | loading python tool from path", abs_path) + + # Load the module by spec + + # Using this to determine the module name + # ./path/to/my_cool_tool.py -> my_cool_tool + module_name = os.path.basename(tool_path).split(".")[0] + + # This function imports a module based on its path, and gives it the provided name + + spec: ModuleSpec = cast(ModuleSpec, importlib.util.spec_from_file_location(module_name, abs_path)) + if not spec: + raise ImportError(f"Could not create spec for {module_name}") + if not spec.loader: + raise ImportError(f"No loader available for {module_name}") + + module = importlib.util.module_from_spec(spec) + # Load, or re-load, the module + sys.modules[module_name] = module + # Execute the module to run any top level code + spec.loader.exec_module(module) + + return load_tools_from_module(module, module_name) + + +def load_tools_from_module_path(module_tool_path: str) -> list[AgentTool]: + """Load strands tool from a module path. + + Example module paths: + my.module.path + my.module.path:tool_name + """ + if ":" in module_tool_path: + module_path, tool_func_name = module_tool_path.split(":") + else: + module_path, tool_func_name = (module_tool_path, None) + + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + raise AttributeError(f'Tool string: "{module_tool_path}" is not a valid tool string.') from e + + # If a ':' is present in the string, then its a targeted function in a module + if tool_func_name: + if hasattr(module, tool_func_name): + target_tool = getattr(module, tool_func_name) + if isinstance(target_tool, DecoratedFunctionTool): + return [target_tool] + + raise AttributeError(f"Tool {tool_func_name} not found in module {module_path}") + + # Else, try to import all of the @tool decorated tools, or the module based tool + module_name = module_path.split(".")[-1] + return load_tools_from_module(module, module_name) + + +def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTool]: + """Load tools from a module. + + First checks if the passed in module has instances of DecoratedToolFunction classes as atributes to the module. + If so, then it returns them as a list of tools. If not, then it attempts to load the module as a module based tool. 
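For reference, a minimal sketch of the three string formats the new loaders accept (the module path mirrors the tests/fixtures/say_tool.py fixture added later in this diff; the local file path is hypothetical):

```python
from strands.tools.loader import load_tool_from_string

# 1. Local file path to a tool module (hypothetical path)
file_tools = load_tool_from_string("./my_cool_tool.py")

# 2. Module import path: loads every @tool-decorated function (or the module-based tool)
module_tools = load_tool_from_string("tests.fixtures.say_tool")

# 3. Module path with a targeted function: loads only that @tool-decorated function
say_only = load_tool_from_string("tests.fixtures.say_tool:say")
print([t.tool_name for t in say_only])  # ["say"]
```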
+ """ + logger.debug("tool_name=<%s>, module=<%s> | loading tools from module", module_name, module_name) + + # Try and see if any of the attributes in the module are function-based tools decorated with @tool + # This means that there may be more than one tool available in this module, so we load them all + + function_tools: List[AgentTool] = [] + # Function tools will appear as attributes in the module + for attr_name in dir(module): + attr = getattr(module, attr_name) + # Check if the module attribute is a DecoratedFunctiontool + if isinstance(attr, DecoratedFunctionTool): + logger.debug("tool_name=<%s>, module=<%s> | found function-based tool in module", attr_name, module_name) + function_tools.append(cast(AgentTool, attr)) + + if function_tools: + return function_tools + + # Finally, if no DecoratedFunctionTools are found in the module, fall back + # to module based tools, and search for TOOL_SPEC + function + module_tool_name = module_name + tool_spec = getattr(module, "TOOL_SPEC", None) + if not tool_spec: + raise AttributeError( + f"The module {module_tool_name} is not a valid module for loading tools." + "This module must contain @tool decorated function(s), or must be a module based tool." + ) + + # If this is a module based tool, the module should have a function with the same name as the module itself + if not hasattr(module, module_tool_name): + raise AttributeError(f"Module-based tool {module_tool_name} missing function {module_tool_name}") + + tool_func = getattr(module, module_tool_name) + if not callable(tool_func): + raise TypeError(f"Tool {module_tool_name} function is not callable") + + return [PythonAgentTool(module_tool_name, tool_spec, tool_func)] + + class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: - """Load a Python tool module and return all discovered function-based tools as a list. + """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). It is the canonical API for retrieving multiple tools from a single Python file. """ + warnings.warn( + "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", + DeprecationWarning, + stacklevel=2, + ) try: # Support module:function style (e.g. package.module:function) if not os.path.exists(tool_path) and ":" in tool_path: @@ -108,7 +246,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: """ warnings.warn( "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " - "Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.", + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", DeprecationWarning, stacklevel=2, ) @@ -127,7 +265,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: """ warnings.warn( "ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. " - "Use ToolLoader.load_tools(...) 
which always returns a list of AgentTool.", + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", DeprecationWarning, stacklevel=2, ) @@ -140,7 +278,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: @classmethod def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: - """Load tools from a file based on its file extension. + """DEPRECATED: Load tools from a file based on its file extension. Args: tool_path: Path to the tool file. @@ -154,6 +292,12 @@ def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: ValueError: If the tool file has an unsupported extension. Exception: For other errors during tool loading. """ + warnings.warn( + "ToolLoader.load_tools is deprecated and will be removed in Strands SDK 2.0. " + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", + DeprecationWarning, + stacklevel=2, + ) ext = Path(tool_path).suffix.lower() abs_path = str(Path(tool_path).resolve()) diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index d95c54fed..cfa841c46 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -7,7 +7,7 @@ """ from .mcp_agent_tool import MCPAgentTool -from .mcp_client import MCPClient +from .mcp_client import MCPClient, ToolFilters from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index acc48443c..af0c069a1 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,26 +28,29 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation + name_override: Optional name to use for the agent tool (for disambiguation) + If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client + self._agent_tool_name = name_override or mcp_tool.name @property def tool_name(self) -> str: """Get the name of the tool. 
Returns: - str: The name of the MCP tool + str: The agent-facing name of the tool (may be disambiguated) """ - return self.mcp_tool.name + return self._agent_tool_name @property def tool_spec(self) -> ToolSpec: @@ -63,7 +66,7 @@ def tool_spec(self) -> ToolSpec: spec: ToolSpec = { "inputSchema": {"json": self.mcp_tool.inputSchema}, - "name": self.mcp_tool.name, + "name": self.tool_name, # Use agent-facing name in spec "description": description, } @@ -100,7 +103,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await self.mcp_client.call_tool_async( tool_use_id=tool_use["toolUseId"], - name=self.tool_name, + name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], ) yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index dec8ec313..61f3d9185 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,19 +16,22 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast import anyio from mcp import ClientSession, ListToolsResult +from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents from mcp.types import CallToolResult as MCPCallToolResult -from mcp.types import GetPromptResult, ListPromptsResult +from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from typing_extensions import Protocol, TypedDict +from ...experimental.tools import ToolProvider from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError +from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat -from ...types.tools import ToolResultContent, ToolResultStatus +from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport @@ -37,6 +40,26 @@ T = TypeVar("T") + +class _ToolFilterCallback(Protocol): + def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... + + +_ToolMatcher = str | Pattern[str] | _ToolFilterCallback + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + """ + + allowed: list[_ToolMatcher] + rejected: list[_ToolMatcher] + + MIME_TO_FORMAT: Dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -52,7 +75,7 @@ ) -class MCPClient: +class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. This class implements a context manager pattern for efficient connection management, @@ -62,17 +85,32 @@ class MCPClient: The connection runs in a background thread to avoid blocking the main application thread while maintaining communication with the MCP service. When structured content is available from MCP tools, it will be returned as the last item in the content array of the ToolResult. 
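A short sketch of the ToolFilters shape defined above; the tool names and patterns here are hypothetical. Exact names and compiled regexes are matched against the original MCP tool name, while callables receive the MCPAgentTool itself:

```python
import re

from strands.tools.mcp import ToolFilters

filters: ToolFilters = {
    # Only expose tools whose MCP name starts with "file_", plus "list_tables"
    "allowed": [re.compile(r"^file_.*"), "list_tables"],
    # Then drop anything whose agent-facing name mentions "delete"
    "rejected": [lambda tool, **kwargs: "delete" in tool.tool_name],
}
```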
+ + Warning: + This class implements the experimental ToolProvider interface and its methods + are subject to change. """ - def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30): + def __init__( + self, + transport_callable: Callable[[], MCPTransport], + *, + startup_timeout: int = 30, + tool_filters: ToolFilters | None = None, + prefix: str | None = None, + ): """Initialize a new MCP Server connection. Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple startup_timeout: Timeout after which MCP server initialization should be cancelled Defaults to 30. + tool_filters: Optional filters to apply to tools. + prefix: Optional prefix for tool names. """ self._startup_timeout = startup_timeout + self._tool_filters = tool_filters + self._prefix = prefix mcp_instrumentation() self._session_id = uuid.uuid4() @@ -86,6 +124,9 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._background_thread: threading.Thread | None = None self._background_thread_session: ClientSession | None = None self._background_thread_event_loop: AbstractEventLoop | None = None + self._loaded_tools: list[MCPAgentTool] | None = None + self._tool_provider_started = False + self._consumers: set[Any] = set() def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -136,6 +177,101 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client initialization failed") from e return self + # ToolProvider interface methods (experimental, as ToolProvider is experimental) + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: + """Load and return tools from the MCP server. + + This method implements the ToolProvider interface by loading tools + from the MCP server and caching them for reuse. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of AgentTool instances from the MCP server. 
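A sketch of using the client as a tool provider, under the assumption that the agent's tool registry accepts ToolProvider instances as shown later in this diff; the stdio command and tool names are hypothetical:

```python
from mcp import StdioServerParameters
from mcp.client.stdio import stdio_client

from strands import Agent
from strands.tools.mcp import MCPClient

client = MCPClient(
    lambda: stdio_client(StdioServerParameters(command="my-mcp-server")),
    prefix="fs",                                # tools surface as e.g. "fs_read_file"
    tool_filters={"rejected": ["delete_file"]},
)

# As a ToolProvider the client is started and its tools loaded when the registry
# processes it, and it is cleaned up once the last consumer is removed.
agent = Agent(tools=[client])
```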
+ """ + logger.debug( + "started=<%s>, cached_tools=<%s> | loading tools", + self._tool_provider_started, + self._loaded_tools is not None, + ) + + if not self._tool_provider_started: + try: + logger.debug("starting MCP client") + self.start() + self._tool_provider_started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._loaded_tools is None: + logger.debug("loading tools from MCP server") + self._loaded_tools = [] + pagination_token = None + page_count = 0 + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + # Use constructor defaults for prefix and filters in load_tools + paginated_tools = self.list_tools_sync( + pagination_token, prefix=self._prefix, tool_filters=self._tool_filters + ) + + # Tools are already filtered by list_tools_sync, so add them all + for tool in paginated_tools: + self._loaded_tools.append(tool) + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._loaded_tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) + + return self._loaded_tools + + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + """ + self._consumers.add(consumer_id) + logger.debug("added provider consumer, count=%d", len(self._consumers)) + + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + This method is idempotent - calling it multiple times with the same ID + has no additional effect after the first call. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + Uses existing synchronous stop() method for safe cleanup. + """ + self._consumers.discard(consumer_id) + logger.debug("removed provider consumer, count=%d", len(self._consumers)) + + if not self._consumers and self._tool_provider_started: + logger.debug("no consumers remaining, cleaning up") + try: + self.stop(None, None, None) # Existing sync method - safe for finalizers + self._tool_provider_started = False + self._loaded_tools = None + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # MCP-specific methods + def stop( self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: @@ -186,13 +322,28 @@ async def _set_close_event() -> None: self._background_thread_session = None self._background_thread_event_loop = None self._session_id = uuid.uuid4() + self._loaded_tools = None + self._tool_provider_started = False + self._consumers = set() - def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: + def list_tools_sync( + self, + pagination_token: str | None = None, + prefix: str | None = None, + tool_filters: ToolFilters | None = None, + ) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. 
This method calls the asynchronous list_tools method on the MCP session and adapts the returned tools to the AgentTool interface. + Args: + pagination_token: Optional token for pagination + prefix: Optional prefix to apply to tool names. If None, uses constructor default. + If explicitly provided (including empty string), overrides constructor default. + tool_filters: Optional filters to apply to tools. If None, uses constructor default. + If explicitly provided (including empty dict), overrides constructor default. + Returns: List[AgentTool]: A list of available tools adapted to the AgentTool interface """ @@ -200,13 +351,29 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + effective_prefix = self._prefix if prefix is None else prefix + effective_filters = self._tool_filters if tool_filters is None else tool_filters + async def _list_tools_async() -> ListToolsResult: return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) - mcp_tools = [MCPAgentTool(tool, self) for tool in list_tools_response.tools] + mcp_tools = [] + for tool in list_tools_response.tools: + # Apply prefix if specified + if effective_prefix: + prefixed_name = f"{effective_prefix}_{tool.name}" + mcp_tool = MCPAgentTool(tool, self, name_override=prefixed_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", tool.name, prefixed_name) + else: + mcp_tool = MCPAgentTool(tool, self) + + # Apply filters if specified + if self._should_include_tool_with_filters(mcp_tool, effective_filters): + mcp_tools.append(mcp_tool) + self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) @@ -358,8 +525,7 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) - # Build a typed list of ToolResultContent. Use a clearer local name to avoid shadowing - # and annotate the result for mypy so it knows the intended element type. + # Build a typed list of ToolResultContent. mapped_contents: list[ToolResultContent] = [ mc for content in call_tool_result.content @@ -438,7 +604,7 @@ def _background_task(self) -> None: def _map_mcp_content_to_tool_result_content( self, - content: MCPTextContent | MCPImageContent | Any, + content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, ) -> Union[ToolResultContent, None]: """Maps MCP content types to tool result content types. @@ -462,6 +628,58 @@ def _map_mcp_content_to_tool_result_content( "source": {"bytes": base64.b64decode(content.data)}, } } + elif isinstance(content, MCPEmbeddedResource): + """ + TODO: Include URI information in results. + Models may find it useful to be aware not only of the information, + but the location of the information too. + + This may be difficult without taking an opinionated position. For example, + a content block may need to indicate that the following Image content block + is of particular URI. 
+ """ + + self._log_debug_with_thread("mapping MCP embedded resource content") + + resource = content.resource + if isinstance(resource, TextResourceContents): + return {"text": resource.text} + elif isinstance(resource, BlobResourceContents): + try: + raw_bytes = base64.b64decode(resource.blob) + except Exception: + self._log_debug_with_thread("embedded resource blob could not be decoded - dropping") + return None + + if resource.mimeType and ( + resource.mimeType.startswith("text/") + or resource.mimeType + in ( + "application/json", + "application/xml", + "application/javascript", + "application/yaml", + "application/x-yaml", + ) + or resource.mimeType.endswith(("+json", "+xml")) + ): + try: + return {"text": raw_bytes.decode("utf-8", errors="replace")} + except Exception: + pass + + if resource.mimeType in MIME_TO_FORMAT: + return { + "image": { + "format": MIME_TO_FORMAT[resource.mimeType], + "source": {"bytes": raw_bytes}, + } + } + + self._log_debug_with_thread("embedded resource blob with non-textual/unknown mimeType - dropping") + return None + + return None # type: ignore[unreachable] # Defensive: future MCP resource types else: self._log_debug_with_thread("unhandled content type: %s - dropping content", content.__class__.__name__) return None @@ -478,5 +696,40 @@ def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures. raise MCPClientInitializationError("the client session was not initialized") return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on constructor filters.""" + return self._should_include_tool_with_filters(tool, self._tool_filters) + + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + """Check if a tool should be included based on provided filters.""" + if not filters: + return True + + # Apply allowed filter + if "allowed" in filters: + if not self._matches_patterns(tool, filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in filters: + if self._matches_patterns(tool, filters["rejected"]): + return False + + return True + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif isinstance(pattern, Pattern): + if pattern.match(tool.mcp_tool.name): + return True + elif isinstance(pattern, str): + if pattern == tool.mcp_tool.name: + return True + return False + def _is_session_active(self) -> bool: return self._background_thread is not None and self._background_thread.is_alive() diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 0660337a2..c80b80f64 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,16 +8,21 @@ import logging import os import sys +import uuid +import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast from strands.tools.decorator import DecoratedFunctionTool +from .._async import run_async +from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec +from .loader import load_tool_from_string, 
load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -34,20 +39,27 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None + self._tool_providers: List[ToolProvider] = [] + self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: - """Process tools list that can contain tool names, paths, imported modules, or functions. + """Process tools list. + + Process list of tools that can contain local file path string, module import path string, + imported modules, @tool decorated functions, or instances of AgentTool. Args: tools: List of tool specifications. Can be: + 1. Local file path to a module based tool: `./path/to/module/tool.py` + 2. Module import path + 2.1. Path to a module based tool: `strands_tools.file_read` + 2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool` + 2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say` + 3. A module for a module based tool + 4. Instances of AgentTool (@tool decorated functions) + 5. Dictionaries with name/path keys (deprecated) - - String tool names (e.g., "calculator") - - File paths (e.g., "/path/to/tool.py") - - Imported Python modules (e.g., a module object) - - Functions decorated with @tool - - Dictionaries with name/path keys - - Instance of an AgentTool Returns: List of tool names that were processed. @@ -55,62 +67,90 @@ def process_tools(self, tools: List[Any]) -> List[str]: tool_names = [] def add_tool(tool: Any) -> None: - # Case 1: String file path - if isinstance(tool, str): - # Extract tool name from path - tool_name = os.path.basename(tool).split(".")[0] - self.load_tool_from_filepath(tool_name=tool_name, tool_path=tool) - tool_names.append(tool_name) - - # Case 2: Dictionary with name and path - elif isinstance(tool, dict) and "name" in tool and "path" in tool: - self.load_tool_from_filepath(tool_name=tool["name"], tool_path=tool["path"]) - tool_names.append(tool["name"]) - - # Case 3: Dictionary with path only - elif isinstance(tool, dict) and "path" in tool: - tool_name = os.path.basename(tool["path"]).split(".")[0] - self.load_tool_from_filepath(tool_name=tool_name, tool_path=tool["path"]) - tool_names.append(tool_name) - - # Case 4: Imported Python module - elif hasattr(tool, "__file__") and inspect.ismodule(tool): - # Get the module file path - module_path = tool.__file__ - # Extract the tool name from the module name - tool_name = tool.__name__.split(".")[-1] - - # Check for TOOL_SPEC in module to validate it's a Strands tool - if hasattr(tool, "TOOL_SPEC") and hasattr(tool, tool_name) and module_path: - self.load_tool_from_filepath(tool_name=tool_name, tool_path=module_path) - tool_names.append(tool_name) + try: + # String based tool + # Can be a file path, a module path, or a module path with a targeted function. 
Examples: + # './path/to/tool.py' + # 'my.module.tool' + # 'my.module.tool:tool_name' + if isinstance(tool, str): + tools = load_tool_from_string(tool) + for a_tool in tools: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Dictionary with name and path + elif isinstance(tool, dict) and "name" in tool and "path" in tool: + tools = load_tool_from_string(tool["path"]) + + tool_found = False + for a_tool in tools: + if a_tool.tool_name == tool["name"]: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + tool_found = True + + if not tool_found: + raise ValueError(f'Tool "{tool["name"]}" not found in "{tool["path"]}"') + + # Dictionary with path only + elif isinstance(tool, dict) and "path" in tool: + tools = load_tool_from_string(tool["path"]) + + for a_tool in tools: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Imported Python module + elif hasattr(tool, "__file__") and inspect.ismodule(tool): + # Extract the tool name from the module name + module_tool_name = tool.__name__.split(".")[-1] + + tools = load_tools_from_module(tool, module_tool_name) + for a_tool in tools: + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Case 5: AgentTools (which also covers @tool) + elif isinstance(tool, AgentTool): + self.register_tool(tool) + tool_names.append(tool.tool_name) + + # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool + elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): + for t in tool: + add_tool(t) + + # Case 5: ToolProvider + elif isinstance(tool, ToolProvider): + self._tool_providers.append(tool) + tool.add_consumer(self._registry_id) + + async def get_tools() -> Sequence[AgentTool]: + return await tool.load_tools() + + provider_tools = run_async(get_tools) + + for provider_tool in provider_tools: + self.register_tool(provider_tool) + tool_names.append(provider_tool.tool_name) else: - function_tools = self._scan_module_for_tools(tool) - for function_tool in function_tools: - self.register_tool(function_tool) - tool_names.append(function_tool.tool_name) - - if not function_tools: - logger.warning("tool_name=<%s>, module_path=<%s> | invalid agent tool", tool_name, module_path) - - # Case 5: AgentTools (which also covers @tool) - elif isinstance(tool, AgentTool): - self.register_tool(tool) - tool_names.append(tool.tool_name) - # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool - elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): - for t in tool: - add_tool(t) - else: - logger.warning("tool=<%s> | unrecognized tool specification", tool) + logger.warning("tool=<%s> | unrecognized tool specification", tool) - for a_tool in tools: - add_tool(a_tool) + except Exception as e: + exception_str = str(e) + logger.exception("tool_name=<%s> | failed to load tool", tool) + raise ValueError(f"Failed to load tool {tool}: {exception_str}") from e + for tool in tools: + add_tool(tool) return tool_names def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: - """Load a tool from a file path. + """DEPRECATED: Load a tool from a file path. Args: tool_name: Name of the tool. @@ -120,6 +160,13 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: FileNotFoundError: If the tool file is not found. ValueError: If the tool cannot be loaded. 
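The deprecated path-based loaders above are superseded by passing strings straight into the tools list; a minimal sketch (the local path is hypothetical, the module path mirrors the say_tool fixture added later in this diff):

```python
from strands import Agent

agent = Agent(
    tools=[
        "./tools/weather.py",           # local file path (hypothetical)
        "tests.fixtures.say_tool:say",  # one specific @tool-decorated function
    ]
)
```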
""" + warnings.warn( + "load_tool_from_filepath is deprecated and will be removed in Strands SDK 2.0. " + "`process_tools` automatically handles loading tools from a filepath.", + DeprecationWarning, + stacklevel=2, + ) + from .loader import ToolLoader try: @@ -496,6 +543,21 @@ def get_all_tool_specs(self) -> list[ToolSpec]: tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] return tools + def register_dynamic_tool(self, tool: AgentTool) -> None: + """Register a tool dynamically for temporary use. + + Args: + tool: The tool to register dynamically + + Raises: + ValueError: If a tool with this name already exists + """ + if tool.tool_name in self.registry or tool.tool_name in self.dynamic_tools: + raise ValueError(f"Tool '{tool.tool_name}' already exists") + + self.dynamic_tools[tool.tool_name] = tool + logger.debug("Registered dynamic tool: %s", tool.tool_name) + def validate_tool_spec(self, tool_spec: ToolSpec) -> None: """Validate tool specification against required schema. @@ -612,3 +674,20 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) return tools + + def cleanup(self, **kwargs: Any) -> None: + """Synchronously clean up all tool providers in this registry.""" + # Attempt cleanup of all providers even if one fails to minimize resource leakage + exceptions = [] + for provider in self._tool_providers: + try: + provider.remove_consumer(self._registry_id) + logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) + except Exception as e: + exceptions.append(e) + logger.error( + "provider=<%s>, error=<%s> | failed to remove provider consumer", type(provider).__name__, e + ) + + if exceptions: + raise exceptions[0] diff --git a/src/strands/tools/structured_output/__init__.py b/src/strands/tools/structured_output/__init__.py new file mode 100644 index 000000000..777d5d846 --- /dev/null +++ b/src/strands/tools/structured_output/__init__.py @@ -0,0 +1,5 @@ +"""Structured output tools for the Strands Agents framework.""" + +from .structured_output_utils import convert_pydantic_to_tool_spec + +__all__ = ["convert_pydantic_to_tool_spec"] diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py new file mode 100644 index 000000000..f33a06915 --- /dev/null +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -0,0 +1,143 @@ +"""Context management for structured output in the event loop.""" + +import logging +from typing import TYPE_CHECKING, Optional, Type + +from pydantic import BaseModel + +from ...types.tools import ToolChoice, ToolSpec, ToolUse +from .structured_output_tool import StructuredOutputTool + +if TYPE_CHECKING: + from ..registry import ToolRegistry + +logger = logging.getLogger(__name__) + + +class StructuredOutputContext: + """Per-invocation context for structured output execution.""" + + def __init__(self, structured_output_model: Type[BaseModel] | None = None): + """Initialize a new structured output context. + + Args: + structured_output_model: Optional Pydantic model type for structured output. 
+ """ + self.results: dict[str, BaseModel] = {} + self.structured_output_model: Type[BaseModel] | None = structured_output_model + self.structured_output_tool: StructuredOutputTool | None = None + self.forced_mode: bool = False + self.force_attempted: bool = False + self.tool_choice: ToolChoice | None = None + self.stop_loop: bool = False + self.expected_tool_name: Optional[str] = None + + if structured_output_model: + self.structured_output_tool = StructuredOutputTool(structured_output_model) + self.expected_tool_name = self.structured_output_tool.tool_name + + @property + def is_enabled(self) -> bool: + """Check if structured output is enabled for this context. + + Returns: + True if a structured output model is configured, False otherwise. + """ + return self.structured_output_model is not None + + def store_result(self, tool_use_id: str, result: BaseModel) -> None: + """Store a validated structured output result. + + Args: + tool_use_id: Unique identifier for the tool use. + result: Validated Pydantic model instance. + """ + self.results[tool_use_id] = result + + def get_result(self, tool_use_id: str) -> BaseModel | None: + """Retrieve a stored structured output result. + + Args: + tool_use_id: Unique identifier for the tool use. + + Returns: + The validated Pydantic model instance, or None if not found. + """ + return self.results.get(tool_use_id) + + def set_forced_mode(self, tool_choice: dict | None = None) -> None: + """Mark this context as being in forced structured output mode. + + Args: + tool_choice: Optional tool choice configuration. + """ + if not self.is_enabled: + return + self.forced_mode = True + self.force_attempted = True + self.tool_choice = tool_choice or {"any": {}} + + def has_structured_output_tool(self, tool_uses: list[ToolUse]) -> bool: + """Check if any tool uses are for the structured output tool. + + Args: + tool_uses: List of tool use dictionaries to check. + + Returns: + True if any tool use matches the expected structured output tool name, + False if no structured output tool is present or expected. + """ + if not self.expected_tool_name: + return False + return any(tool_use.get("name") == self.expected_tool_name for tool_use in tool_uses) + + def get_tool_spec(self) -> Optional[ToolSpec]: + """Get the tool specification for structured output. + + Returns: + Tool specification, or None if no structured output model. + """ + if self.structured_output_tool: + return self.structured_output_tool.tool_spec + return None + + def extract_result(self, tool_uses: list[ToolUse]) -> BaseModel | None: + """Extract and remove structured output result from stored results. + + Args: + tool_uses: List of tool use dictionaries from the current execution cycle. + + Returns: + The structured output result if found, or None if no result available. + """ + if not self.has_structured_output_tool(tool_uses): + return None + + for tool_use in tool_uses: + if tool_use.get("name") == self.expected_tool_name: + tool_use_id = str(tool_use.get("toolUseId", "")) + result = self.results.pop(tool_use_id, None) + if result is not None: + logger.debug("Extracted structured output for %s", tool_use.get("name")) + return result + return None + + def register_tool(self, registry: "ToolRegistry") -> None: + """Register the structured output tool with the registry. + + Args: + registry: The tool registry to register the tool with. 
+ """ + if self.structured_output_tool and self.structured_output_tool.tool_name not in registry.dynamic_tools: + registry.register_dynamic_tool(self.structured_output_tool) + logger.debug("Registered structured output tool: %s", self.structured_output_tool.tool_name) + + def cleanup(self, registry: "ToolRegistry") -> None: + """Clean up the registered structured output tool from the registry. + + Args: + registry: The tool registry to clean up the tool from. + """ + if self.structured_output_tool and self.structured_output_tool.tool_name in registry.dynamic_tools: + del registry.dynamic_tools[self.structured_output_tool.tool_name] + logger.debug("Cleaned up structured output tool: %s", self.structured_output_tool.tool_name) diff --git a/src/strands/tools/structured_output/structured_output_tool.py b/src/strands/tools/structured_output/structured_output_tool.py new file mode 100644 index 000000000..25173d048 --- /dev/null +++ b/src/strands/tools/structured_output/structured_output_tool.py @@ -0,0 +1,158 @@ +"""Structured output tool implementation. + +This module provides a real tool implementation for structured output that integrates +with the existing tool execution and error handling infrastructure. +""" + +import logging +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Type + +from pydantic import BaseModel, ValidationError +from typing_extensions import override + +from ...types._events import ToolResultEvent +from ...types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse +from .structured_output_utils import convert_pydantic_to_tool_spec + +logger = logging.getLogger(__name__) + +_TOOL_SPEC_CACHE: dict[Type[BaseModel], ToolSpec] = {} + +if TYPE_CHECKING: + from ._structured_output_context import StructuredOutputContext + + +class StructuredOutputTool(AgentTool): + """Tool implementation for structured output validation.""" + + def __init__(self, structured_output_model: Type[BaseModel]) -> None: + """Initialize a structured output tool. + + Args: + structured_output_model: The Pydantic model class that defines the expected output structure. + """ + super().__init__() + self._structured_output_type = structured_output_model + self._tool_spec = self._get_tool_spec(structured_output_model) + self._tool_spec["description"] = ( + "IMPORTANT: This StructuredOutputTool should only be invoked as the last and final tool " + f"before returning the completed result to the caller. " + f"{self._tool_spec.get('description', '')}" + ) + self._tool_name = self._tool_spec.get("name", "StructuredOutputTool") + + @classmethod + def _get_tool_spec(cls, structured_output_model: Type[BaseModel]) -> ToolSpec: + """Get a cached tool spec for the given output type. + + Args: + structured_output_model: The Pydantic model class that defines the expected output structure. + + Returns: + Cached tool specification for the output type. + """ + if structured_output_model not in _TOOL_SPEC_CACHE: + _TOOL_SPEC_CACHE[structured_output_model] = convert_pydantic_to_tool_spec(structured_output_model) + return deepcopy(_TOOL_SPEC_CACHE[structured_output_model]) + + @property + def tool_name(self) -> str: + """Get the name of the tool. + + Returns: + The name of the tool (same as the Pydantic model class name). + """ + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification for this structured output tool. + + Returns: + The tool specification generated from the Pydantic model. 
+ """ + return self._tool_spec + + @property + def tool_type(self) -> str: + """Identifies this as a structured output tool implementation. + + Returns: + "structured_output". + """ + return "structured_output" + + @property + def structured_output_model(self) -> Type[BaseModel]: + """Get the Pydantic model type for this tool. + + Returns: + The Pydantic model class. + """ + return self._structured_output_type + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Validate the structured output and return appropriate result. + + Args: + tool_use: The tool use request containing the data to validate. + invocation_state: Context for the tool invocation (kept for compatibility). + **kwargs: Additional keyword arguments, including structured_output_context. + + Yields: + Tool events with the last being the tool result (success or error). + """ + tool_input: dict[str, Any] = tool_use.get("input", {}) + tool_use_id = str(tool_use.get("toolUseId", "")) + + context: StructuredOutputContext = kwargs.get("structured_output_context") # type: ignore + try: + validated_object = self._structured_output_type(**tool_input) + logger.debug("tool_name=<%s> | structured output validated", self._tool_name) + context.store_result(tool_use_id, validated_object) + + result: ToolResult = { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": f"Successfully validated {self._tool_name} structured output"}], + } + + yield ToolResultEvent(result) + + except ValidationError as e: + error_details = [] + for error in e.errors(): + field_path = " -> ".join(str(loc) for loc in error["loc"]) if error["loc"] else "root" + error_details.append(f"Field '{field_path}': {error['msg']}") + + error_message = f"Validation failed for {self._tool_name}. Please fix the following errors:\n" + "\n".join( + f"- {detail}" for detail in error_details + ) + logger.error( + "tool_name=<%s> | structured output validation failed | error_message=<%s>", + self._tool_name, + error_message, + ) + + # Create error result that will be sent back to the LLM so it can decide if it needs to retry + validation_error_result: ToolResult = { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": error_message}], + } + + yield ToolResultEvent(validation_error_result) + + except Exception as e: + error_message = f"Unexpected error validating {self._tool_name}: {str(e)}" + logger.exception(error_message) + + exception_result: ToolResult = { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": error_message}], + } + + yield ToolResultEvent(exception_result) diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output/structured_output_utils.py similarity index 99% rename from src/strands/tools/structured_output.py rename to src/strands/tools/structured_output/structured_output_utils.py index 2c5922925..093d67f7c 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output/structured_output_utils.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from ..types.tools import ToolSpec +from ...types.tools import ToolSpec def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index e20bf658a..36977e90f 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,10 +5,12 @@ agent lifecycle. 
""" -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Sequence, cast +from pydantic import BaseModel from typing_extensions import override +from ..interrupt import Interrupt from ..telemetry import EventLoopMetrics from .citations import Citation from .content import Message @@ -220,6 +222,8 @@ def __init__( message: Message, metrics: "EventLoopMetrics", request_state: Any, + interrupts: Sequence[Interrupt] | None = None, + structured_output: BaseModel | None = None, ) -> None: """Initialize with the final execution results. @@ -228,8 +232,10 @@ def __init__( message: Final message from the model metrics: Execution metrics and performance data request_state: Final state of the agent execution + interrupts: Interrupts raised by user during agent execution. + structured_output: Optional structured output result """ - super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts, structured_output)}) @property @override @@ -237,6 +243,18 @@ def is_callback_event(self) -> bool: return False +class StructuredOutputEvent(TypedEvent): + """Event emitted when structured output is detected and processed.""" + + def __init__(self, structured_output: BaseModel) -> None: + """Initialize with the structured output result. + + Args: + structured_output: The parsed structured output instance + """ + super().__init__({"structured_output": structured_output}) + + class EventLoopThrottleEvent(TypedEvent): """Event emitted when the event loop is throttled due to rate limiting.""" @@ -313,12 +331,30 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancelled_event")).get("tool_use")).get("toolUseId")) + return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) @property def message(self) -> str: """The tool cancellation message.""" - return cast(str, self["message"]) + return cast(str, self["tool_cancel_event"]["message"]) + + +class ToolInterruptEvent(TypedEvent): + """Event emitted when a tool is interrupted.""" + + def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: + """Set interrupt in the event payload.""" + super().__init__({"tool_interrupt_event": {"tool_use": tool_use, "interrupts": interrupts}}) + + @property + def tool_use_id(self) -> str: + """The id of the tool interrupted.""" + return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId")) + + @property + def interrupts(self) -> list[Interrupt]: + """The interrupt instances.""" + return cast(list[Interrupt], self["tool_interrupt_event"]["interrupts"]) class ModelMessageEvent(TypedEvent): diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index 151c88f89..a2a4c7dce 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -6,5 +6,6 @@ from typing import TypeAlias from .content import ContentBlock, Messages +from .interrupt import InterruptResponse -AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None +AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2c240972b..2a7ad344e 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -23,20 +23,24 @@ 
class Usage(TypedDict, total=False): cacheWriteInputTokens: int -class Metrics(TypedDict): +class Metrics(TypedDict, total=False): """Performance metrics for model interactions. Attributes: latencyMs (int): Latency of the model request in milliseconds. + timeToFirstByteMs (int): Latency from sending model request to first + content chunk (contentBlockDelta or contentBlockStart) from the model in milliseconds. """ - latencyMs: int + latencyMs: Required[int] + timeToFirstByteMs: int StopReason = Literal[ "content_filtered", "end_turn", "guardrail_intervened", + "interrupt", "max_tokens", "stop_sequence", "tool_use", @@ -46,6 +50,7 @@ class Metrics(TypedDict): - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened +- "interrupt": Agent was interrupted for human input - "max_tokens": Maximum token limit reached - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 90f2b8d7f..b9c5bc769 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -75,3 +75,22 @@ class SessionException(Exception): """Exception raised when session operations fail.""" pass + + +class ToolProviderException(Exception): + """Exception raised when a tool provider fails to load or cleanup tools.""" + + pass + + +class StructuredOutputException(Exception): + """Exception raised when structured output validation fails after maximum retry attempts.""" + + def __init__(self, message: str): + """Initialize the exception with details about the failure. + + Args: + message: The error message describing the structured output failure + """ + self.message = message + super().__init__(message) diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py new file mode 100644 index 000000000..001ce6993 --- /dev/null +++ b/src/strands/types/interrupt.py @@ -0,0 +1,142 @@ +"""Interrupt related type definitions for human-in-the-loop workflows. 
+ +Interrupt Flow: + ```mermaid + flowchart TD + A[Invoke Agent] --> B[Execute Hook/Tool] + B --> C{Interrupts Raised?} + C -->|No| D[Continue Agent Loop] + C -->|Yes| E[Stop Agent Loop] + E --> F[Return Interrupts] + F --> G[Respond to Interrupts] + G --> H[Execute Hook/Tool with Responses] + H --> I{New Interrupts?} + I -->|Yes| E + I -->|No| D + ``` + +Example: + ```Python + from typing import Any + + from strands import Agent, tool + from strands.hooks import BeforeToolCallEvent, HookProvider, HookRegistry + + + @tool + def delete_tool(key: str) -> bool: + print("DELETE_TOOL | deleting") + return True + + + class ToolInterruptHook(HookProvider): + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + registry.add_callback(BeforeToolCallEvent, self.approve) + + def approve(self, event: BeforeToolCallEvent) -> None: + if event.tool_use["name"] != "delete_tool": + return + + approval = event.interrupt("for_delete_tool", reason="APPROVAL") + if approval != "A": + event.cancel_tool = "approval was not granted" + + agent = Agent( + hooks=[ToolInterruptHook()], + tools=[delete_tool], + system_prompt="You delete objects given their keys.", + callback_handler=None, + ) + result = agent(f"delete object with key 'X'") + + if result.stop_reason == "interrupt": + responses = [] + for interrupt in result.interrupts: + if interrupt.name == "for_delete_tool": + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "A"}) + + result = agent(responses) + + ... + ``` + +Details: + - User raises interrupt on their hook event by calling `event.interrupt()`. + - User can raise one interrupt per hook callback. + - Interrupts stop the agent event loop. + - Interrupts are returned to the user in AgentResult. + - User resumes by invoking agent with interrupt responses. + - Second call to `event.interrupt()` returns user response. + - Process repeats if user raises additional interrupts. + - Interrupts are session managed in-between return and user response. +""" + +from typing import TYPE_CHECKING, Any, Protocol, TypedDict + +from ..interrupt import Interrupt, InterruptException + +if TYPE_CHECKING: + from ..agent import Agent + + +class _Interruptible(Protocol): + """Interface that adds interrupt support to hook events and tools.""" + + agent: "Agent" + + def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any: + """Trigger the interrupt with a reason. + + Args: name: User defined name for the interrupt. + Must be unique across hook callbacks. + reason: User provided reason for the interrupt. + response: Preemptive response from user if available. + + Returns: + The response from a human user when resuming from an interrupt state. + + Raises: + InterruptException: If human input is required. + """ + id = self._interrupt_id(name) + state = self.agent._interrupt_state + + interrupt_ = state.interrupts.setdefault(id, Interrupt(id, name, reason, response)) + if interrupt_.response: + return interrupt_.response + + raise InterruptException(interrupt_) + + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + reason: User provided reason for the interrupt. + + Returns: + Interrupt id. + """ + ... + + +class InterruptResponse(TypedDict): + """User response to an interrupt. + + Attributes: + interruptId: Unique identifier for the interrupt. + response: User response to the interrupt. 
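An InterruptResponse travels back to the agent wrapped in an interruptResponse content block; a minimal sketch of the resume step, continuing the hook example above:

```python
result = agent("delete object with key 'X'")

if result.stop_reason == "interrupt":
    responses = [
        {"interruptResponse": {"interruptId": interrupt.id, "response": "A"}}
        for interrupt in result.interrupts
    ]
    result = agent(responses)
```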
+ """ + + interruptId: str + response: Any + + +class InterruptResponseContent(TypedDict): + """Content block containing a user response to an interrupt. + + Attributes: + interruptResponse: User response to an interrupt event. + """ + + interruptResponse: InterruptResponse diff --git a/src/strands/types/session.py b/src/strands/types/session.py index e51816f74..926480f2c 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,8 +5,9 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional +from ..agent.interrupt import InterruptState from .content import Message if TYPE_CHECKING: @@ -104,11 +105,20 @@ def to_dict(self) -> dict[str, Any]: @dataclass class SessionAgent: - """Agent that belongs to a Session.""" + """Agent that belongs to a Session. + + Attributes: + agent_id: Unique id for the agent. + state: User managed state. + conversation_manager_state: State for conversation management. + created_at: Created at time. + updated_at: Updated at time. + """ agent_id: str - state: Dict[str, Any] - conversation_manager_state: Dict[str, Any] + state: dict[str, Any] + conversation_manager_state: dict[str, Any] + _internal_state: dict[str, Any] = field(default_factory=dict) # Strands managed state created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -121,6 +131,9 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": agent_id=agent.agent_id, conversation_manager_state=agent.conversation_manager.get_state(), state=agent.state.get(), + _internal_state={ + "interrupt_state": agent._interrupt_state.to_dict(), + }, ) @classmethod @@ -132,6 +145,11 @@ def to_dict(self) -> dict[str, Any]: """Convert the SessionAgent to a dictionary representation.""" return asdict(self) + def initialize_internal_state(self, agent: "Agent") -> None: + """Initialize internal state of agent.""" + if "interrupt_state" in self._internal_state: + agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"]) + @dataclass class Session: diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 18c7013ee..8343647b2 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -5,12 +5,14 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union from typing_extensions import NotRequired, TypedDict +from .interrupt import _Interruptible from .media import DocumentContent, ImageContent if TYPE_CHECKING: @@ -126,7 +128,7 @@ class ToolChoiceTool(TypedDict): @dataclass -class ToolContext: +class ToolContext(_Interruptible): """Context object containing framework-provided data for decorated tools. This object provides access to framework-level information that may be useful @@ -148,6 +150,17 @@ class ToolContext: agent: "Agent" invocation_state: dict[str, Any] + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. 
+ """ + return f"v1:tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + # Individual ToolChoice type aliases ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto] diff --git a/tests/fixtures/mock_agent_tool.py b/tests/fixtures/mock_agent_tool.py new file mode 100644 index 000000000..eed33731f --- /dev/null +++ b/tests/fixtures/mock_agent_tool.py @@ -0,0 +1,27 @@ +from typing import Any + +from strands.types.content import ToolUse +from strands.types.tools import AgentTool, ToolSpec + + +class MockAgentTool(AgentTool): + """Mock AgentTool implementation for testing.""" + + def __init__(self, name: str): + super().__init__() + self._tool_name = name + + @property + def tool_name(self) -> str: + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + return ToolSpec(name=self._tool_name, description="Mock tool", input_schema={}) + + @property + def tool_type(self) -> str: + return "mock" + + def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any): + yield f"Mock result for {self._tool_name}" diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py new file mode 100644 index 000000000..727d28a48 --- /dev/null +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -0,0 +1,41 @@ +from typing import Iterator, Literal, Tuple, Type + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import ( + HookEvent, + HookProvider, + HookRegistry, +) + + +class MockMultiAgentHookProvider(HookProvider): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + MultiAgentInitializedEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, + AfterMultiAgentInvocationEvent, + ] + + self.events_received = [] + self.events_types = event_types + + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + return len(self.events_received), iter(self.events_received) + + def register_hooks(self, registry: HookRegistry) -> None: + for event_type in self.events_types: + registry.add_callback(event_type, self.add_event) + + def add_event(self, event: HookEvent) -> None: + self.events_received.append(event) diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index c05089f34..56817a6e4 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union from pydantic import BaseModel @@ -25,8 +25,8 @@ class MockedModelProvider(Model): to stream mock responses as events. 
""" - def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]): - self.agent_responses = agent_responses + def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + self.agent_responses = [*agent_responses] self.index = 0 def format_chunk(self, event: Any) -> StreamEvent: @@ -53,7 +53,11 @@ async def structured_output( pass async def stream( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: Optional[Any] = None, ) -> AsyncGenerator[Any, None]: events = self.map_agent_message_to_events(self.agent_responses[self.index]) for event in events: diff --git a/tests/fixtures/say_tool.py b/tests/fixtures/say_tool.py new file mode 100644 index 000000000..4607b2501 --- /dev/null +++ b/tests/fixtures/say_tool.py @@ -0,0 +1,17 @@ +from strands import tool + + +@tool +def say(input: str) -> str: + """Say something.""" + return f"Hello {input}!" + + +@tool +def dont_say(input: str) -> str: + """Dont say something.""" + return "Didnt say anything!" + + +def not_a_tool() -> str: + return "Not a tool!" diff --git a/tests/fixtures/tool_with_spec_but_no_function.py b/tests/fixtures/tool_with_spec_but_no_function.py new file mode 100644 index 000000000..75f8bf6f6 --- /dev/null +++ b/tests/fixtures/tool_with_spec_but_no_function.py @@ -0,0 +1 @@ +TOOL_SPEC = {"hello": "world!"} diff --git a/tests/fixtures/tool_with_spec_but_non_callable_function.py b/tests/fixtures/tool_with_spec_but_non_callable_function.py new file mode 100644 index 000000000..0ca2f092c --- /dev/null +++ b/tests/fixtures/tool_with_spec_but_non_callable_function.py @@ -0,0 +1,3 @@ +TOOL_SPEC = {"hello": "world"} + +tool_with_spec_but_non_callable_function = "not a function!" 
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..c1ff13412 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +import warnings from uuid import uuid4 import pytest @@ -16,6 +17,8 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize @@ -326,6 +329,7 @@ def test_agent__call__( ], [tool.tool_spec], system_prompt, + tool_choice=None, ), unittest.mock.call( [ @@ -362,6 +366,7 @@ def test_agent__call__( ], [tool.tool_spec], system_prompt, + tool_choice=None, ), ], ) @@ -481,6 +486,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener expected_messages, unittest.mock.ANY, unittest.mock.ANY, + tool_choice=None, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -624,6 +630,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene expected_messages, unittest.mock.ANY, unittest.mock.ANY, + tool_choice=None, ) assert conversation_manager_spy.reduce_context.call_count == 2 @@ -887,10 +894,6 @@ def test_agent_tool_names(tools, agent): assert actual == expected -def test_agent__del__(agent): - del agent - - def test_agent_init_with_no_model_or_model_id(): agent = Agent() assert agent.model is not None @@ -1877,3 +1880,283 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text + + +def test_agent__call__handles_none_invocation_state(mock_model, agent): + """Test that agent handles None invocation_state without AttributeError.""" + mock_model.mock_stream.return_value = [ + {"contentBlockDelta": {"delta": {"text": "test response"}}}, + {"contentBlockStop": {}}, + ] + + # This should not raise AttributeError: 'NoneType' object has no attribute 'get' + result = agent("test", invocation_state=None) + + assert result.message["content"][0]["text"] == "test response" + assert result.stop_reason == "end_turn" + + +def test_agent__call__invocation_state_with_kwargs_deprecation_warning(agent, mock_event_loop_cycle): + """Test that kwargs trigger deprecation warning and are merged correctly with invocation_state.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + # Should have nested structure when both invocation_state and kwargs are provided + assert invocation_state["invocation_state"] == {"my": "state"} + assert invocation_state["other_kwarg"] == "foobar" + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar") + + # Verify deprecation warning was issued + assert len(captured_warnings) == 1 + assert issubclass(captured_warnings[0].category, UserWarning) + assert 
"`**kwargs` parameter is deprecating, use `invocation_state` instead." in str(captured_warnings[0].message) + + +def test_agent__call__invocation_state_only_no_warning(agent, mock_event_loop_cycle): + """Test that using only invocation_state does not trigger warning and passes state directly.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + + assert invocation_state["my"] == "state" + assert "agent" in invocation_state + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}) + + assert len(captured_warnings) == 0 + + +def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): + tool_use_message = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_decorated", + "input": {"random_string": "test input"}, + } + }, + ], + } + agent = Agent( + messages=[tool_use_message], + model=mock_model, + tools=[tool_decorated], + ) + + interrupt = Interrupt( + id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": []}) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + mock_model.mock_stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "resumed"}}}, + {"contentBlockStop": {}}, + ] + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test response", + } + } + ] + agent(prompt) + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "test input"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state + + +def test_agent__call__resume_interrupt_invalid_prompt(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"prompt_type= \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + agent("invalid") + + +def test_agent__call__resume_interrupt_invalid_content(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + agent([{"text": "invalid"}]) + + +def test_agent__call__resume_interrupt_invalid_id(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"interrupt_id= \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + agent([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) + + +def 
test_agent_structured_output_interrupt(user): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot call structured output during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.structured_output(type(user), "invalid") + + +def test_agent_tool_caller_interrupt(): + @strands.tool(context=True) + def test_tool(tool_context): + tool_context.interrupt("test-interrupt") + + agent = Agent(tools=[test_tool]) + + exp_message = r"cannot raise interrupt in direct tool call" + with pytest.raises(RuntimeError, match=exp_message): + agent.tool.test_tool(agent=agent) + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state + + tru_messages = agent.messages + exp_messages = [] + assert tru_messages == exp_messages + + +def test_agent_tool_caller_interrupt_activated(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot directly call tool during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.tool.test_tool() + + +def test_latest_message_tool_use_skips_model_invoke(tool_decorated): + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "I see the tool result"}]}]) + + messages: Messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "tool_decorated", "input": {"random_string": "Hello"}}} + ], + } + ] + agent = Agent(model=mock_model, tools=[tool_decorated], messages=messages) + + agent() + + assert mock_model.index == 1 + assert len(agent.messages) == 3 + assert agent.messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "Hello" + assert agent.messages[2]["content"][0]["text"] == "I see the tool result" + + +def test_agent_del_before_tool_registry_set(): + """Test that Agent.__del__ doesn't fail if called before tool_registry is set.""" + agent = Agent() + del agent.tool_registry + agent.__del__() # Should not raise + + +def test_agent__call__invalid_tool_name(): + @strands.tool + def shell(command: str): + pass + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_use_id", + "name": "invalid tool", + "input": "{}", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ) + + agent = Agent(tools=[shell], model=model) + result = agent("Test") + + # Ensure the stop_reason is + assert result.stop_reason == "end_turn" + + # Assert that there exists a message with a toolResponse + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "Error: tool_name= | invalid tool name pattern"}], + "status": "error", + "toolUseId": "tool_use_id", + } + } + ], + "role": "user", + } + + # And that it continued to the LLM call + assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"} diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 6c5625e0b..32266c3eb 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -124,7 +124,10 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): assert length == 6 assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( 
agent=agent, @@ -170,7 +173,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -231,7 +237,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 409b08a2d..3a3a3f5f7 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -1,7 +1,8 @@ import unittest.mock -from typing import cast +from typing import Optional, cast import pytest +from pydantic import BaseModel from strands.agent.agent_result import AgentResult from strands.telemetry.metrics import EventLoopMetrics @@ -48,6 +49,7 @@ def test__init__(mock_metrics, simple_message: Message): assert result.message == simple_message assert result.metrics == mock_metrics assert result.state == state + assert result.structured_output is None def test__str__simple(mock_metrics, simple_message: Message): @@ -95,3 +97,107 @@ def test__str__non_dict_content(mock_metrics): message_string = str(result) assert message_string == "Valid text\nMore valid text\n" + + +def test_to_dict(mock_metrics, simple_message: Message): + """Test that to_dict serializes AgentResult correctly.""" + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={"key": "value"}) + + data = result.to_dict() + + assert data == { + "type": "agent_result", + "message": simple_message, + "stop_reason": "end_turn", + } + + +def test_from_dict(): + """Test that from_dict works with valid data.""" + data = { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + result = AgentResult.from_dict(data) + + assert result.message == data["message"] + assert result.stop_reason == data["stop_reason"] + assert isinstance(result.metrics, EventLoopMetrics) + assert result.state == {} + + +def test_roundtrip_serialization(mock_metrics, complex_message: Message): + """Test that to_dict() and from_dict() work together correctly.""" + original = AgentResult( + stop_reason="max_tokens", message=complex_message, metrics=mock_metrics, state={"test": "data"} + ) + + # Serialize and deserialize + data = original.to_dict() + restored = AgentResult.from_dict(data) + + assert restored.message == original.message + assert restored.stop_reason == original.stop_reason + assert isinstance(restored.metrics, EventLoopMetrics) + assert restored.state == {} # State is not serialized + + +# Tests for structured output functionality +class StructuredOutputModel(BaseModel): + """Test model for structured output.""" + + name: str + value: int + optional_field: Optional[str] = None + + +def test__init__with_structured_output(mock_metrics, simple_message: Message): + """Test that AgentResult 
can be initialized with structured_output.""" + stop_reason: StopReason = "end_turn" + state = {"key": "value"} + structured_output = StructuredOutputModel(name="test", value=42) + + result = AgentResult( + stop_reason=stop_reason, + message=simple_message, + metrics=mock_metrics, + state=state, + structured_output=structured_output, + ) + + assert result.stop_reason == stop_reason + assert result.message == simple_message + assert result.metrics == mock_metrics + assert result.state == state + assert result.structured_output == structured_output + assert isinstance(result.structured_output, StructuredOutputModel) + assert result.structured_output.name == "test" + assert result.structured_output.value == 42 + + +def test__init__structured_output_defaults_to_none(mock_metrics, simple_message: Message): + """Test that structured_output defaults to None when not provided.""" + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + + assert result.structured_output is None + + +def test__str__with_structured_output(mock_metrics, simple_message: Message): + """Test that str() is not affected by structured_output.""" + structured_output = StructuredOutputModel(name="test", value=42) + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + structured_output=structured_output, + ) + + # The string representation should only include the message text, not structured output + message_string = str(result) + assert message_string == "Hello world!\n" + assert "test" not in message_string + assert "42" not in message_string diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py new file mode 100644 index 000000000..b679faed0 --- /dev/null +++ b/tests/strands/agent/test_agent_structured_output.py @@ -0,0 +1,414 @@ +"""Tests for Agent structured output functionality.""" + +from typing import Optional +from unittest import mock +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from strands import Agent +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output.structured_output_tool import StructuredOutputTool +from strands.types._events import EventLoopStopEvent +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +class UserModel(BaseModel): + """Test user model for structured output.""" + + name: str + age: int + email: str + + +class ProductModel(BaseModel): + """Test product model for structured output.""" + + title: str + price: float + description: Optional[str] = None + + +@pytest.fixture +def mock_model(): + """Create a mock model.""" + model = Mock() + + async def mock_stream(*args, **kwargs): + yield {"contentBlockDelta": {"delta": {"text": "test response"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + model.stream.side_effect = lambda *args, **kwargs: mock_stream(*args, **kwargs) + return model + + +@pytest.fixture +def mock_metrics(): + return mock.Mock(spec=EventLoopMetrics) + + +@pytest.fixture +def user_model(): + """Return the test user model class.""" + return UserModel + + +@pytest.fixture +def product_model(): + """Return the test product model class.""" + return ProductModel + + +class TestAgentStructuredOutputInit: + """Test Agent initialization with structured output model.""" + + 
def test_agent_init_with_structured_output_model(self, user_model): + """Test that Agent can be initialized with a structured_output_model.""" + agent = Agent(structured_output_model=user_model) + + assert agent._default_structured_output_model == user_model + assert agent.model is not None + + def test_agent_init_without_structured_output_model(self): + """Test that Agent can be initialized without structured_output_model.""" + agent = Agent() + + assert agent._default_structured_output_model is None + assert agent.model is not None + + +class TestAgentStructuredOutputInvocation: + """Test Agent invocation with structured output.""" + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_structured_output_model(self, mock_event_loop, user_model, mock_model, mock_metrics): + """Test Agent.__call__ with structured_output_model parameter.""" + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == user_model + + # Return a successful result + test_user = UserModel(name="John", age=30, email="john@example.com") + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=test_user, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent and call with structured_output_model + agent = Agent(model=mock_model) + agent("Extract user info", structured_output_model=user_model) + + # Verify event_loop_cycle was called with correct context + mock_event_loop.assert_called_once() + call_kwargs = mock_event_loop.call_args[1] + assert "structured_output_context" in call_kwargs + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_default_structured_output_model( + self, mock_event_loop, product_model, mock_model, mock_metrics + ): + """Test Agent.__call__ uses default structured_output_model when not specified.""" + + # Setup mock event loop + pm = ProductModel(title="Widget", price=9.99) + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == product_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=pm, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default structured_output_model + agent = Agent(model=mock_model, structured_output_model=product_model) + result = agent("Get product info") + + # Verify result uses default model + assert result.structured_output is pm + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_override_default_structured_output_model( + self, mock_event_loop, user_model, product_model, mock_model, mock_metrics + ): + """Test that invocation-level structured_output_model overrides default.""" + + # Setup mock event loop + um = UserModel(name="Jane", age=25, email="jane@example.com") + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + # Should use user_model, not the default product_model + assert structured_output_context.structured_output_model == user_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + 
message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=um, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default product_model, but override with user_model + agent = Agent(model=mock_model, structured_output_model=product_model) + result = agent("Get user info", structured_output_model=user_model) + + # Verify result uses override model + assert result.structured_output is um + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_invoke_async_with_structured_output( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.invoke_async with structured_output_model.""" + + # Setup mock event loop + um = UserModel(name="Alice", age=28, email="alice@example.com") + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == user_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=um, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent and call async + agent = Agent(model=mock_model) + result = await agent.invoke_async("Get user", structured_output_model=user_model) + + # Verify result + assert result.structured_output is um + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_stream_async_with_structured_output( + self, mock_event_loop, product_model, mock_model, mock_metrics + ): + """Test Agent.stream_async with structured_output_model.""" + + # Setup mock event loop + pm = ProductModel(title="Gadget", price=19.99, description="Cool gadget") + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == product_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=pm, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent and stream async + agent = Agent(model=mock_model) + events = [] + async for event in agent.stream_async("Get product", structured_output_model=product_model): + events.append(event) + + # Verify we got result event + assert len(events) > 0 + result_event = events[-1] + assert "result" in result_event + result = result_event["result"] + assert result.structured_output is pm + + +class TestAgentStructuredOutputContext: + """Test StructuredOutputContext integration with Agent.""" + + @patch("strands.agent.agent.event_loop_cycle") + def test_structured_output_context_created_with_model(self, mock_event_loop, user_model, mock_model, mock_metrics): + """Test that StructuredOutputContext is created when structured_output_model is provided.""" + context = None + + async def mock_cycle(*args, **kwargs): + nonlocal context + context = kwargs.get("structured_output_context") + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + agent("Test", 
structured_output_model=user_model) + + # Verify context was created and passed + assert context is not None + assert isinstance(context, StructuredOutputContext) + assert context.structured_output_model == user_model + assert context.is_enabled is True + + @patch("strands.agent.agent.event_loop_cycle") + def test_structured_output_context_none_without_model(self, mock_event_loop, mock_model, mock_metrics): + """Test that StructuredOutputContext is created with None when no model provided.""" + context = None + + async def mock_cycle(*args, **kwargs): + nonlocal context + context = kwargs.get("structured_output_context") + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + agent("Test") # No structured_output_model + + # Verify context was created but disabled + assert context is not None + assert isinstance(context, StructuredOutputContext) + assert context.structured_output_model is None + assert context.is_enabled is False + + @patch("strands.tools.registry.ToolRegistry.register_dynamic_tool") + @patch("strands.agent.agent.event_loop_cycle") + def test_structured_output_tool_registered_dynamically( + self, mock_event_loop, mock_register, user_model, mock_model, mock_metrics + ): + """Test that StructuredOutputTool is registered dynamically when structured output is used.""" + captured_tool = None + + def capture_tool(tool): + nonlocal captured_tool + captured_tool = tool + + mock_register.side_effect = capture_tool + + async def mock_cycle(*args, **kwargs): + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + agent("Test", structured_output_model=user_model) + + # Verify tool was registered + mock_register.assert_called_once() + assert captured_tool is not None + assert isinstance(captured_tool, StructuredOutputTool) + assert captured_tool.structured_output_model == user_model + + +class TestAgentStructuredOutputEdgeCases: + """Test edge cases for structured output in Agent.""" + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_with_no_structured_output(self, mock_event_loop, mock_model, mock_metrics): + """Test that agent works normally when no structured output is specified.""" + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model is None + assert structured_output_context.is_enabled is False + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Normal response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + result = agent("Normal query") + + # Result should not have structured output + assert result.structured_output is None + assert result.message["content"][0]["text"] == "Normal response" + + def test_agent_multiple_structured_output_models(self, user_model, product_model, mock_metrics): + """Test that agent can switch between different structured output models.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "User response"}]}, + 
{"role": "assistant", "content": [{"text": "Product response"}]}, + ] + ) + + agent = Agent(model=model) + + # First call with user model + with patch("strands.agent.agent.event_loop_cycle") as mock_event_loop: + um = UserModel(name="Bob", age=40, email="bob@example.com") + + async def mock_user_cycle(*args, **kwargs): + ctx = kwargs.get("structured_output_context") + assert ctx.structured_output_model == user_model + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "User response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=um, + ) + + mock_event_loop.side_effect = mock_user_cycle + result1 = agent("Get user", structured_output_model=user_model) + assert result1.structured_output is um + + # Second call with product model + with patch("strands.agent.agent.event_loop_cycle") as mock_event_loop: + pm = ProductModel(title="Item", price=5.99) + + async def mock_product_cycle(*args, **kwargs): + ctx = kwargs.get("structured_output_context") + assert ctx.structured_output_model == product_model + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Product response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=pm, + ) + + mock_event_loop.side_effect = mock_product_cycle + result2 = agent("Get product", structured_output_model=product_model) + assert result2.structured_output is pm diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py new file mode 100644 index 000000000..e248c29a6 --- /dev/null +++ b/tests/strands/agent/test_interrupt.py @@ -0,0 +1,61 @@ +import pytest + +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt + + +@pytest.fixture +def interrupt(): + return Interrupt(id="test_id", name="test_name", reason="test reason") + + +def test_interrupt_activate(): + interrupt_state = InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_deactivate(): + interrupt_state = InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(interrupt): + interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = InterruptState.from_dict(data) + exp_state = InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index 6003a1710..4b69e6653 100644 --- 
a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -19,6 +19,8 @@ def __init__(self, summary_response="This is a summary of the conversation."): self.messages = [] self.model = Mock() self.call_tracker = Mock() + self.tool_registry = Mock() + self.tool_names = [] def __call__(self, prompt): """Mock agent call that returns a summary.""" @@ -608,3 +610,30 @@ def test_summarizing_conversation_manager_properly_records_removed_message_count # so we dont count this toward the total: # 4 (Previously removed messages) + 2 (removed messages) - 1 (Previous summary message) = 5 assert manager.removed_message_count == 5 + + +@patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") +def test_summarizing_conversation_manager_generate_summary_with_noop_tool(mock_registry_cls, summarizing_manager): + mock_registry = mock_registry_cls.return_value + + messages = [{"role": "user", "content": [{"text": "test"}]}] + agent = create_mock_agent() + + original_tool_registry = agent.tool_registry + summarizing_manager._generate_summary(messages, agent) + + assert original_tool_registry == agent.tool_registry + mock_registry.register_tool.assert_called_once() + + +@patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") +def test_summarizing_conversation_manager_generate_summary_with_tools(mock_registry_cls, summarizing_manager): + mock_registry = mock_registry_cls.return_value + + messages = [{"role": "user", "content": [{"text": "test"}]}] + agent = create_mock_agent() + agent.tool_names = ["test_tool"] + + summarizing_manager._generate_summary(messages, agent) + + mock_registry.register_tool.assert_not_called() diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2b71f3502..72c63e897 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,20 +1,24 @@ import concurrent import unittest.mock -from unittest.mock import MagicMock, call, patch +from unittest.mock import ANY, MagicMock, call, patch import pytest import strands import strands.telemetry +from strands.agent.interrupt import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, + BeforeToolCallEvent, HookRegistry, MessageAddedEvent, ) +from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry +from strands.types._events import EventLoopStopEvent from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -22,6 +26,7 @@ ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -138,6 +143,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor + mock._interrupt_state = InterruptState() return mock @@ -169,7 +175,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": 
[{"text": "test text"}]} @@ -201,7 +207,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -239,7 +245,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -330,7 +336,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -372,6 +378,7 @@ async def test_event_loop_cycle_tool_result( ], tool_registry.get_all_tool_specs(), "p1", + tool_choice=None, ) @@ -445,7 +452,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -739,6 +746,8 @@ async def test_event_loop_cycle_with_parent_span( async def test_request_state_initialization(alist): # Create a mock agent mock_agent = MagicMock() + # not setting this to False results in endless recursion + mock_agent._interrupt_state.activated = False mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) # Call without providing request_state @@ -747,7 +756,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -759,7 +768,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -862,3 +871,196 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert next(events) == MessageAddedEvent( agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} ) + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt(agent, model, tool_stream, agenerator, alist): + def interrupt_callback(event): + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator(tool_stream)] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, tru_interrupts, _ = events[-1]["stop"] + exp_stop_reason = "interrupt" + exp_interrupts = [ + Interrupt( + id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + 
name="test_name", + reason="test reason", + ), + ] + + assert tru_stop_reason == exp_stop_reason and tru_interrupts == exp_interrupts + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": True, + "context": { + "tool_results": [], + "tool_use_message": { + "content": [ + { + "toolUse": { + "input": {"random_string": "abcdEfghI123"}, + "name": "tool_for_testing", + "toolUseId": "t1", + }, + }, + ], + "role": "assistant", + }, + }, + "interrupts": { + "v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9": { + "id": "v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + "name": "test_name", + "reason": "test reason", + "response": None, + }, + }, + } + assert tru_state == exp_state + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_2, agenerator, alist): + interrupt = Interrupt( + id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + + tool_use_message = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_for_testing", + "input": {"random_string": "test input"}, + } + }, + { + "toolUse": { + "toolUseId": "t2", + "name": "tool_times_2", + "input": {}, + } + }, + ], + } + tool_results = [ + { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + ] + + agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": tool_results}) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator([{"contentBlockStop": {}}])] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + }, + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "test input"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state + + +@pytest.mark.asyncio +async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): + model.stream = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_use_id", + "name": "invalid tool", + "input": "{}", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ).stream + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + + # ensure that we got end_turn and not tool_use + assert events[-1] == EventLoopStopEvent( + stop_reason="end_turn", + message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"}, + metrics=ANY, + request_state={}, + 
) + + # Ensure that an "invalid tool name" message was added properly + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "Error: tool_name= | invalid tool name pattern"}], + "status": "error", + "toolUseId": "tool_use_id", + } + } + ], + "role": "user", + } diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py new file mode 100644 index 000000000..6d3e3a9b5 --- /dev/null +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -0,0 +1,439 @@ +"""Tests for structured output integration in the event loop.""" + +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.registry import ToolRegistry +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.types._events import EventLoopStopEvent, StructuredOutputEvent + + +class UserModel(BaseModel): + """Test model for structured output.""" + + name: str + age: int + email: str + + +class ProductModel(BaseModel): + """Another test model.""" + + title: str + price: float + in_stock: bool + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with required attributes.""" + agent = Mock(name="agent") + agent.model = Mock() + agent.system_prompt = "Test system prompt" + agent.messages = [] + agent.tool_registry = ToolRegistry() + agent.event_loop_metrics = EventLoopMetrics() + agent.hooks = Mock() + agent.hooks.invoke_callbacks = Mock() + agent.trace_span = None + agent.tool_executor = Mock() + agent._append_message = Mock() + + # Set up _interrupt_state properly + agent._interrupt_state = Mock() + agent._interrupt_state.activated = False + agent._interrupt_state.context = {} + + return agent + + +@pytest.fixture +def structured_output_context(): + """Create a structured output context with a test model.""" + return StructuredOutputContext(structured_output_model=UserModel) + + +@pytest.fixture +def agenerator(): + """Helper to create async generators.""" + + def _agenerator(items): + async def gen(): + for item in items: + yield item + + return gen() + + return _agenerator + + +@pytest.fixture +def alist(): + """Helper to consume async generators.""" + + async def _alist(async_gen): + items = [] + async for item in async_gen: + items.append(item) + return items + + return _alist + + +@pytest.mark.asyncio +async def test_event_loop_cycle_with_structured_output_context(mock_agent, agenerator, alist): + """Test event_loop_cycle with structured output context passed but not enabled.""" + # Create a context that's not enabled (no model) + structured_output_context = StructuredOutputContext() + + # Setup model to return a text response + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user data"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + # Run event loop cycle with structured output context + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should have received events + assert len(events) > 0 + + # The context should be passed through but not enabled + assert not structured_output_context.is_enabled + + +@pytest.mark.asyncio +async def 
test_event_loop_forces_structured_output_on_end_turn( + mock_agent, structured_output_context, agenerator, alist +): + """Test that event loop forces structured output tool when model returns end_turn.""" + # First call returns end_turn without using structured output tool + mock_agent.model.stream.side_effect = [ + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), + # Second call (forced) uses the structured output tool + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "UserModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"name": "John", "age": 30, "email": "john@example.com"}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), + ] + + # Mock tool executor to handle the structured output tool + mock_agent.tool_executor._execute = Mock( + return_value=agenerator( + [ + # Tool execution events would go here + ] + ) + ) + + # Mock recurse_event_loop to return final result + with patch("strands.event_loop.event_loop.recurse_event_loop") as mock_recurse: + # Create a mock EventLoopStopEvent with the expected structure + mock_stop_event = Mock() + mock_stop_event.stop = ( + "end_turn", + {"role": "assistant", "content": [{"text": "Done"}]}, + mock_agent.event_loop_metrics, + {}, + None, + UserModel(name="John", age=30, email="john@example.com"), + ) + mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key] + + mock_recurse.return_value = agenerator([mock_stop_event]) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + # Should have appended a message to force structured output + mock_agent._append_message.assert_called_once() + args = mock_agent._append_message.call_args[0][0] + assert args["role"] == "user" + + # Should have called recurse_event_loop with the context + mock_recurse.assert_called_once() + call_kwargs = mock_recurse.call_args[1] + assert call_kwargs["structured_output_context"] == structured_output_context + + +@pytest.mark.asyncio +async def test_structured_output_tool_execution_extracts_result( + mock_agent, structured_output_context, agenerator, alist +): + """Test that structured output result is extracted from tool execution.""" + # Model uses the structured output tool + mock_agent.model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "UserModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"name": "Alice", "age": 25, "email": "alice@test.com"}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + # Mock the tool executor to return an async generator + mock_agent.tool_executor._execute = Mock(return_value=agenerator([])) + + # Mock extract_result to return a model instance + test_result = UserModel(name="Alice", age=25, email="alice@test.com") + structured_output_context.extract_result = Mock(return_value=test_result) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should yield StructuredOutputEvent + structured_output_events = [e for e in events if isinstance(e, StructuredOutputEvent)] + assert 
len(structured_output_events) == 1 + assert structured_output_events[0]["structured_output"] == test_result + + # Extract_result should have been called + structured_output_context.extract_result.assert_called_once() + + +@pytest.mark.asyncio +async def test_structured_output_context_not_enabled(mock_agent, agenerator, alist): + """Test event loop with structured output context that's not enabled.""" + # Create a context that's not enabled (no model) + structured_output_context = StructuredOutputContext() + assert not structured_output_context.is_enabled + + # Model returns end_turn + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Regular response"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should complete normally without forcing structured output + stop_events = [e for e in events if isinstance(e, EventLoopStopEvent)] + assert len(stop_events) == 1 + assert stop_events[0]["stop"][-1] is None + + +@pytest.mark.asyncio +async def test_structured_output_forced_mode(mock_agent, agenerator, alist): + """Test event loop with structured output in forced mode.""" + # Create context in forced mode + structured_output_context = StructuredOutputContext(structured_output_model=ProductModel) + structured_output_context.set_forced_mode(tool_choice={"tool": {"name": "ProductModel"}}) + + # Model should be called with only the structured output tool spec + mock_agent.model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "ProductModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"title": "Book", "price": 19.99, "in_stock": true}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + # Mock tool executor + mock_agent.tool_executor._execute = Mock(return_value=agenerator([])) + + # Mock extract_result + test_result = ProductModel(title="Book", price=19.99, in_stock=True) + structured_output_context.extract_result = Mock(return_value=test_result) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + # Verify model.stream was called with the forced tool spec + mock_agent.model.stream.assert_called_once() + call_args = mock_agent.model.stream.call_args + + # The model.stream method signature (from streaming.py) is: + # model.stream(messages, tool_specs, system_prompt, tool_choice=tool_choice) + tool_specs = call_args.args[1] if len(call_args.args) > 1 else None + + # In forced mode, only the structured output tool spec should be passed + assert tool_specs is not None, "Expected tool_specs to be provided" + assert isinstance(tool_specs, list), f"Expected tool_specs to be a list, got {type(tool_specs)}" + assert len(tool_specs) == 1 + assert tool_specs[0]["name"] == "ProductModel" + + +@pytest.mark.asyncio +async def test_recurse_event_loop_with_structured_output(mock_agent, structured_output_context, agenerator, alist): + """Test recurse_event_loop preserves structured output context.""" + invocation_state = { + "event_loop_cycle_trace": Mock(), + "request_state": {}, + } + + # Mock event_loop_cycle to verify it receives the context + with 
patch("strands.event_loop.event_loop.event_loop_cycle") as mock_cycle: + # Create a mock EventLoopStopEvent with the expected structure + mock_stop_event = Mock(spec=EventLoopStopEvent) + mock_stop_event.stop = ( + "end_turn", + {"role": "assistant", "content": [{"text": "Done"}]}, + mock_agent.event_loop_metrics, + {}, + None, + UserModel(name="Test", age=20, email="test@example.com"), + ) + mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key] + + mock_cycle.return_value = agenerator([mock_stop_event]) + + stream = recurse_event_loop( + agent=mock_agent, + invocation_state=invocation_state, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Verify event_loop_cycle was called with the context + mock_cycle.assert_called_once() + call_kwargs = mock_cycle.call_args[1] + assert call_kwargs["structured_output_context"] == structured_output_context + + # Verify the result includes structured output + stop_events = [ + e for e in events if isinstance(e, EventLoopStopEvent) or (hasattr(e, "stop") and hasattr(e, "__getitem__")) + ] + assert len(stop_events) == 1 + stop_event = stop_events[0] + if hasattr(stop_event, "__getitem__"): + assert stop_event["stop"][5].name == "Test" + else: + assert stop_event.stop[5].name == "Test" + + +@pytest.mark.asyncio +async def test_structured_output_stops_loop_after_extraction(mock_agent, structured_output_context, agenerator, alist): + """Test that loop stops after structured output is extracted.""" + # Model uses the structured output tool + mock_agent.model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "UserModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"name": "Bob", "age": 35, "email": "bob@test.com"}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + # Mock tool executor + mock_agent.tool_executor._execute = Mock(return_value=agenerator([])) + + # Mock extract_result to return a result and set stop_loop + test_result = UserModel(name="Bob", age=35, email="bob@test.com") + + def mock_extract(tool_uses): + structured_output_context.stop_loop = True + return test_result + + structured_output_context.extract_result = Mock(side_effect=mock_extract) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should have a StructuredOutputEvent with the result + structured_output_events = [e for e in events if isinstance(e, StructuredOutputEvent)] + assert len(structured_output_events) == 1 + assert structured_output_events[0]["structured_output"] == test_result + + # Verify stop_loop was set + assert structured_output_context.stop_loop + + # Extract_result should have been called + structured_output_context.extract_result.assert_called_once() diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 1de957619..92bf0de96 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -491,7 +491,7 @@ def test_extract_usage_metrics_with_cache_tokens(): "content": [], }, {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, + {"latencyMs": 0, "timeToFirstByteMs": 0}, ), }, ], @@ -781,7 +781,7 @@ async def test_stream_messages(agenerator, alist): "end_turn", {"role": "assistant", "content": [{"text": "test"}]}, 
{"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, + {"latencyMs": 0, "timeToFirstByteMs": 0}, ) }, ] @@ -791,6 +791,7 @@ async def test_stream_messages(agenerator, alist): [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], None, "test prompt", + tool_choice=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py new file mode 100644 index 000000000..e17044527 --- /dev/null +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -0,0 +1,157 @@ +"""Tests for streaming.py with structured output support.""" + +import unittest.mock + +import pytest +from pydantic import BaseModel + +import strands.event_loop.streaming +from strands.tools.structured_output.structured_output_tool import StructuredOutputTool +from strands.types._events import TypedEvent + + +class SampleModel(BaseModel): + """Sample model for structured output.""" + + name: str + age: int + + +@pytest.fixture(autouse=True) +def moto_autouse(moto_env, moto_mock_aws): + _ = moto_env + _ = moto_mock_aws + + +@pytest.mark.asyncio +async def test_stream_messages_with_tool_choice(agenerator, alist): + """Test stream_messages with tool_choice parameter for structured output.""" + mock_model = unittest.mock.MagicMock() + mock_model.stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "test-123", "name": "SampleModel"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "test", "age": 25}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + }, + ] + ) + + # Create a structured output tool and get its spec + structured_tool = StructuredOutputTool(SampleModel) + tool_spec = structured_tool.tool_spec + tool_choice = {"tool": {"name": "SampleModel"}} + + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "user", "content": [{"text": "Generate a test model"}]}], + tool_specs=[tool_spec], + tool_choice=tool_choice, + ) + + tru_events = await alist(stream) + + # Verify the model.stream was called with tool_choice + mock_model.stream.assert_called_with( + [{"role": "user", "content": [{"text": "Generate a test model"}]}], + [tool_spec], + "test prompt", + tool_choice=tool_choice, + ) + + # Verify we get the expected events + assert len(tru_events) > 0 + + # Find the stop event + stop_event = None + for event in tru_events: + if isinstance(event, dict) and "stop" in event: + stop_event = event + break + + assert stop_event is not None + assert stop_event["stop"][0] == "tool_use" + + # Ensure that we're getting typed events + non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] + assert non_typed_events == [] + + +@pytest.mark.asyncio +async def test_stream_messages_with_forced_structured_output(agenerator, alist): + """Test stream_messages with forced structured output tool.""" + mock_model = unittest.mock.MagicMock() + + # Simulate a response with tool use + mock_model.stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "SampleModel"}}}}, + 
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "Alice", "age": 30}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + "metrics": {"latencyMs": 150}, + } + }, + ] + ) + + # Create a structured output tool and get its spec + structured_tool = StructuredOutputTool(SampleModel) + tool_spec = structured_tool.tool_spec + tool_choice = {"any": {}} + + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="Extract user information", + messages=[{"role": "user", "content": [{"text": "Alice is 30 years old"}]}], + tool_specs=[tool_spec], + tool_choice=tool_choice, + ) + + tru_events = await alist(stream) + + # Verify the model.stream was called with the forced tool choice + mock_model.stream.assert_called_with( + [{"role": "user", "content": [{"text": "Alice is 30 years old"}]}], + [tool_spec], + "Extract user information", + tool_choice=tool_choice, + ) + + assert len(tru_events) > 0 + + # Find the stop event and verify it contains the extracted data + stop_event = None + for event in tru_events: + if isinstance(event, dict) and "stop" in event: + stop_event = event + break + + assert stop_event is not None + stop_reason, message, usage, metrics = stop_event["stop"] + + assert stop_reason == "tool_use" + assert message["role"] == "assistant" + assert len(message["content"]) > 0 + + # Check that the tool use contains the expected data + tool_use_content = None + for content in message["content"]: + if "toolUse" in content: + tool_use_content = content["toolUse"] + break + + assert tool_use_content is not None + assert tool_use_content["name"] == "SampleModel" + assert tool_use_content["input"] == {"name": "Alice", "age": 30} diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/multiagent/test_events.py b/tests/strands/experimental/hooks/multiagent/test_events.py new file mode 100644 index 000000000..6c4d7c4e7 --- /dev/null +++ b/tests/strands/experimental/hooks/multiagent/test_events.py @@ -0,0 +1,107 @@ +"""Tests for multi-agent execution lifecycle events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import BaseHookEvent + + +@pytest.fixture +def orchestrator(): + """Mock orchestrator for testing.""" + return Mock() + + +def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): + """Test MultiAgentInitializedEvent creation with orchestrator only.""" + event = MultiAgentInitializedEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_multi_agent_initialization_event_with_invocation_state(orchestrator): + """Test MultiAgentInitializedEvent creation with invocation state.""" + invocation_state = {"key": "value"} + event = MultiAgentInitializedEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_after_node_invocation_event_with_required_fields(orchestrator): + """Test AfterNodeCallEvent creation 
with required fields.""" + node_id = "node_1" + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_node_invocation_event_with_invocation_state(orchestrator): + """Test AfterNodeCallEvent creation with invocation state.""" + node_id = "node_2" + invocation_state = {"result": "success"} + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state == invocation_state + + +def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with orchestrator only.""" + event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_multi_agent_invocation_event_with_invocation_state(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with invocation state.""" + invocation_state = {"final_state": "completed"} + event = AfterMultiAgentInvocationEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_before_node_call_event(orchestrator): + """Test BeforeNodeCallEvent creation.""" + node_id = "node_1" + event = BeforeNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_before_multi_agent_invocation_event(orchestrator): + """Test BeforeMultiAgentInvocationEvent creation.""" + event = BeforeMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_events_should_reverse_callbacks(orchestrator): + """Test that After events have should_reverse_callbacks property set to True.""" + after_node_event = AfterNodeCallEvent(source=orchestrator, node_id="test") + after_invocation_event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert after_node_event.should_reverse_callbacks is True + assert after_invocation_event.should_reverse_callbacks is True diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py new file mode 100644 index 000000000..e6188079b --- /dev/null +++ b/tests/strands/experimental/test_agent_config.py @@ -0,0 +1,172 @@ +"""Tests for experimental config_to_agent function.""" + +import json +import os +import tempfile + +import pytest + +from strands.experimental import config_to_agent + + +def test_config_to_agent_with_dict(): + """Test config_to_agent can be created with dict config.""" + config = {"model": "test-model"} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "test-model" + + +def test_config_to_agent_with_system_prompt(): + """Test config_to_agent handles system prompt correctly.""" + config = {"model": "test-model", "prompt": "Test prompt"} + agent = config_to_agent(config) + assert agent.system_prompt == "Test prompt" + + +def test_config_to_agent_with_tools_list(): + """Test config_to_agent handles tools list without failing.""" + # Use a 
simple test that doesn't require actual tool loading + config = {"model": "test-model", "tools": []} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "test-model" + + +def test_config_to_agent_with_kwargs_override(): + """Test that kwargs can override config values.""" + config = {"model": "test-model", "prompt": "Config prompt"} + agent = config_to_agent(config, system_prompt="Override prompt") + assert agent.system_prompt == "Override prompt" + + +def test_config_to_agent_file_prefix_not_required(): + """Test that file paths without file:// prefix work.""" + import json + import tempfile + + config_data = {"model": "test-model"} + temp_path = "" + + # Create the file with delete=False (and clean it up manually) so it can be reopened on Windows + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + agent = config_to_agent(temp_path) + assert agent.model.config["model_id"] == "test-model" + finally: + # Clean up the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_config_to_agent_file_prefix_valid(): + """Test that file:// prefix is properly handled.""" + config_data = {"model": "test-model", "prompt": "Test prompt"} + temp_path = "" + + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + agent = config_to_agent(f"file://{temp_path}") + assert agent.model.config["model_id"] == "test-model" + assert agent.system_prompt == "Test prompt" + finally: + # Clean up the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_config_to_agent_file_not_found(): + """Test that FileNotFoundError is raised for missing files.""" + with pytest.raises(FileNotFoundError, match="Configuration file not found"): + config_to_agent("/nonexistent/path/config.json") + + +def test_config_to_agent_invalid_json(): + """Test that JSONDecodeError is raised for invalid JSON.""" + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + f.write("invalid json content") + temp_path = f.name + + with pytest.raises(json.JSONDecodeError): + config_to_agent(temp_path) + finally: + # Clean up the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_config_to_agent_invalid_config_type(): + """Test that ValueError is raised for invalid config types.""" + with pytest.raises(ValueError, match="Config must be a file path string or dictionary"): + config_to_agent(123) + + +def test_config_to_agent_with_name(): + """Test config_to_agent handles agent name.""" + config = {"model": "test-model", "name": "TestAgent"} + agent = config_to_agent(config) + assert agent.name == "TestAgent" + + +def test_config_to_agent_ignores_none_values(): + """Test that None values in config are ignored.""" + config = {"model": "test-model", "prompt": None, "name": None} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "test-model" + # Agent should use its defaults for None values + + +def test_config_to_agent_validation_error_invalid_field(): + """Test that invalid fields raise validation errors.""" + config = {"model": "test-model", "invalid_field": "value"} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_wrong_type(): + """Test that wrong field types raise validation errors.""" + config = {"model": "test-model",
"tools": "not-a-list"} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_invalid_tool_item(): + """Test that invalid tool items raise validation errors.""" + config = {"model": "test-model", "tools": ["valid-tool", 123]} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_invalid_tool(): + """Test that invalid tools raise helpful error messages.""" + config = {"model": "test-model", "tools": ["nonexistent_tool"]} + with pytest.raises(ValueError, match="Failed to load tool nonexistent_tool"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_missing_module(): + """Test that missing modules raise helpful error messages.""" + config = {"model": "test-model", "tools": ["nonexistent.module.tool"]} + with pytest.raises(ValueError, match="Failed to load tool nonexistent.module.tool"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_missing_function(): + """Test that missing functions in existing modules raise helpful error messages.""" + config = {"model": "test-model", "tools": ["json.nonexistent_function"]} + with pytest.raises(ValueError, match="Failed to load tool json.nonexistent_function"): + config_to_agent(config) + + +def test_config_to_agent_with_tool(): + """Test that a tool referenced by a module path is loaded onto the agent.""" + config = {"model": "test-model", "tools": ["tests.fixtures.say_tool:say"]} + agent = config_to_agent(config) + assert "say" in agent.tool_names diff --git a/tests/strands/experimental/tools/__init__.py b/tests/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/hooks/__init__.py b/tests/strands/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py new file mode 100644 index 000000000..6918bd2ee --- /dev/null +++ b/tests/strands/hooks/test_registry.py @@ -0,0 +1,73 @@ +import unittest.mock + +import pytest + +from strands.agent.interrupt import InterruptState +from strands.hooks import BeforeToolCallEvent, HookRegistry +from strands.interrupt import Interrupt + + +@pytest.fixture +def registry(): + return HookRegistry() + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = InterruptState() + return instance + + +def test_hook_registry_invoke_callbacks_interrupt(registry, agent): + event = BeforeToolCallEvent( + agent=agent, + selected_tool=None, + tool_use={"toolUseId": "test_tool_id", "name": "test_tool_name", "input": {}}, + invocation_state={}, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name_1", "test reason 1")) + callback2 = unittest.mock.Mock() + callback3 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name_2", "test reason 2")) + + registry.add_callback(BeforeToolCallEvent, callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + registry.add_callback(BeforeToolCallEvent, callback3) + + _, tru_interrupts = registry.invoke_callbacks(event) + exp_interrupts = [ + Interrupt( + id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", + name="test_name_1", + reason="test reason 1", + ), + Interrupt( + id="v1:before_tool_call:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4", + name="test_name_2", + reason="test
reason 2", + ), + ] + assert tru_interrupts == exp_interrupts + + callback1.assert_called_once_with(event) + callback2.assert_called_once_with(event) + callback3.assert_called_once_with(event) + + +def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): + event = BeforeToolCallEvent( + agent=agent, + selected_tool=None, + tool_use={"toolUseId": "test_tool_id", "name": "test_tool_name", "input": {}}, + invocation_state={}, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name", "test reason 1")) + callback2 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name", "test reason 2")) + + registry.add_callback(BeforeToolCallEvent, callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + + with pytest.raises(ValueError, match="interrupt_name= | interrupt name used more than once"): + registry.invoke_callbacks(event) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 96fee67fa..4a6a0f9b0 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,5 +1,6 @@ import os import sys +import traceback import unittest.mock from unittest.mock import ANY @@ -10,6 +11,7 @@ from botocore.exceptions import ClientError, EventStreamError import strands +from strands import _exception_notes from strands.models import BedrockModel from strands.models.bedrock import ( _DEFAULT_BEDROCK_MODEL_ID, @@ -533,6 +535,40 @@ async def test_stream_throttling_exception_from_general_exception(bedrock_client ) +@pytest.mark.asyncio +async def test_stream_throttling_exception_lowercase(bedrock_client, model, messages, alist): + """Test that lowercase throttlingException is converted to ModelThrottledException.""" + error_message = "throttlingException: Rate exceeded for ConverseStream" + bedrock_client.converse_stream.side_effect = ClientError( + {"Error": {"Message": error_message, "Code": "throttlingException"}}, "Any" + ) + + with pytest.raises(ModelThrottledException) as excinfo: + await alist(model.stream(messages)) + + assert error_message in str(excinfo.value) + bedrock_client.converse_stream.assert_called_once_with( + modelId="m1", messages=messages, system=[], inferenceConfig={} + ) + + +@pytest.mark.asyncio +async def test_stream_throttling_exception_lowercase_non_streaming(bedrock_client, messages, alist): + """Test that lowercase throttlingException is converted to ModelThrottledException in non-streaming mode.""" + error_message = "throttlingException: Rate exceeded for Converse" + bedrock_client.converse.side_effect = ClientError( + {"Error": {"Message": error_message, "Code": "throttlingException"}}, "Any" + ) + + model = BedrockModel(model_id="test-model", streaming=False) + with pytest.raises(ModelThrottledException) as excinfo: + await alist(model.stream(messages)) + + assert error_message in str(excinfo.value) + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + @pytest.mark.asyncio async def test_general_exception_is_raised(bedrock_client, model, messages, alist): error_message = "Should be raised up" @@ -1209,6 +1245,23 @@ async def test_add_note_on_client_error(bedrock_client, model, alist, messages): assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] +@pytest.mark.asyncio +async def test_add_note_on_client_error_without_add_notes(bedrock_client, model, alist, messages): + """Test that when add_note is not used, the region & model are still included in the 
error output.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + # Mock the client error response + error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + + # Call the stream method which should catch and add notes to the exception + with pytest.raises(ClientError) as err: + await alist(model.stream(messages)) + + error_str = "".join(traceback.format_exception(err.value)) + assert "└ Bedrock region: us-west-2" in error_str + assert "└ Model id: m1" in error_str + + @pytest.mark.asyncio async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index bc81fc819..3a427f759 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -3,9 +3,11 @@ import pydantic import pytest +from litellm.exceptions import ContextWindowExceededError import strands from strands.models.litellm import LiteLLMModel +from strands.types.exceptions import ContextWindowOverflowException @pytest.fixture @@ -140,39 +142,71 @@ def test_format_request_message_content(content, exp_result): @pytest.mark.asyncio async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) mock_delta_1 = unittest.mock.Mock( reasoning_content="", content=None, tool_calls=None, ) + mock_delta_2 = unittest.mock.Mock( reasoning_content="\nI'm thinking", content=None, tool_calls=None, ) mock_delta_3 = unittest.mock.Mock( + reasoning_content=None, + content="One second", + tool_calls=None, + ) + mock_delta_4 = unittest.mock.Mock( + reasoning_content="\nI'm think", + content=None, + tool_calls=None, + ) + mock_delta_5 = unittest.mock.Mock( + reasoning_content="ing again", + content=None, + tool_calls=None, + ) + + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_6 = unittest.mock.Mock( content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None ) mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_4 = unittest.mock.Mock( + mock_delta_7 = unittest.mock.Mock( content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None ) - mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) + mock_delta_8 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)]) mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)]) - mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)]) - mock_event_6 = unittest.mock.Mock() + mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, 
delta=mock_delta_5)]) + mock_event_6 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_6)]) + mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)]) + mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)]) + mock_event_9 = unittest.mock.Mock() litellm_acompletion.side_effect = unittest.mock.AsyncMock( - return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]) + return_value=agenerator( + [ + mock_event_1, + mock_event_2, + mock_event_3, + mock_event_4, + mock_event_5, + mock_event_6, + mock_event_7, + mock_event_8, + mock_event_9, + ] + ) ) messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}] @@ -182,6 +216,15 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm thinking"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "One second"}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm think"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "ing again"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, {"contentBlockDelta": {"delta": {"text": "I'll calculate"}}}, {"contentBlockDelta": {"delta": {"text": "that for you"}}}, {"contentBlockStop": {}}, @@ -209,9 +252,9 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, { "metadata": { "usage": { - "inputTokens": mock_event_6.usage.prompt_tokens, - "outputTokens": mock_event_6.usage.completion_tokens, - "totalTokens": mock_event_6.usage.total_tokens, + "inputTokens": mock_event_9.usage.prompt_tokens, + "outputTokens": mock_event_9.usage.completion_tokens, + "totalTokens": mock_event_9.usage.total_tokens, }, "metrics": {"latencyMs": 0}, } @@ -251,8 +294,6 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene tru_events = await alist(response) exp_events = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, ] @@ -290,15 +331,27 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c @pytest.mark.asyncio -async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls): +async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function.arguments = '{"name": "John", "age": 30}' + + mock_choice = unittest.mock.Mock() + mock_choice.finish_reason = "tool_calls" + mock_choice.message.tool_calls = [mock_tool_call] + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + litellm_acompletion.return_value = mock_response + with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False): - with pytest.raises(ValueError, match="Model does not support response_format"): - stream = model.structured_output(test_output_model_cls, messages) - await stream.__anext__() + stream = 
model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + tru_result = events[-1] - litellm_acompletion.assert_not_called() + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings): @@ -332,3 +385,13 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, tool_choice=None) assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_context_window_maps_to_typed_exception(litellm_acompletion, model): + """Test that a typed ContextWindowExceededError is mapped correctly.""" + litellm_acompletion.side_effect = ContextWindowExceededError(message="test error", model="x", llm_provider="y") + + with pytest.raises(ContextWindowOverflowException): + async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]): + pass diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 219561025..b8249f504 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -128,3 +128,48 @@ async def stream(self, messages, tool_specs=None, system_prompt=None): assert len(events) == 3 assert events[1]["contentBlockDelta"]["delta"]["text"] == "Legacy model works" + + +@pytest.mark.asyncio +async def test_stream_with_tool_choice_parameter(messages, tool_specs, system_prompt, alist): + """Test that model can accept tool_choice parameter.""" + + class ModernModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs): + yield {"output": output_model(name="test", age=20)} + + async def stream(self, messages, tool_specs=None, system_prompt=None, *, tool_choice=None, **kwargs): + yield {"messageStart": {"role": "assistant"}} + if tool_choice: + yield {"contentBlockDelta": {"delta": {"text": f"Tool choice: {tool_choice}"}}} + else: + yield {"contentBlockDelta": {"delta": {"text": "No tool choice"}}} + yield {"messageStop": {"stopReason": "end_turn"}} + + model = ModernModel() + + # Test with tool_choice="auto" + response = model.stream(messages, tool_specs, system_prompt, tool_choice="auto") + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Tool choice: auto" + + # Test with tool_choice="any" + response = model.stream(messages, tool_specs, system_prompt, tool_choice="any") + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Tool choice: any" + + # Test with tool_choice={"tool": {"name": "SampleModel"}} + response = model.stream(messages, tool_specs, system_prompt, tool_choice={"tool": {"name": "SampleModel"}}) + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Tool choice: {'tool': {'name': 'SampleModel'}}" + + # Test without tool_choice + response = model.stream(messages, tool_specs, system_prompt) + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "No tool choice" diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index a5662ecdc..72ebf01c6 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -112,11 +112,13 @@ def test_init_with_all_params(self, boto_session): "endpoint_name": "test-endpoint",
"inference_component_name": "test-component", "region_name": "us-west-2", + "additional_args": {"test_req_arg_name": "test_req_arg_value"}, } payload_config = { "stream": False, "max_tokens": 1024, "temperature": 0.7, + "additional_args": {"test_payload_arg_name": "test_payload_arg_value"}, } client_config = BotocoreConfig(user_agent_extra="test-agent") @@ -129,9 +131,11 @@ def test_init_with_all_params(self, boto_session): assert model.endpoint_config["endpoint_name"] == "test-endpoint" assert model.endpoint_config["inference_component_name"] == "test-component" + assert model.endpoint_config["additional_args"]["test_req_arg_name"] == "test_req_arg_value" assert model.payload_config["stream"] is False assert model.payload_config["max_tokens"] == 1024 assert model.payload_config["temperature"] == 0.7 + assert model.payload_config["additional_args"]["test_payload_arg_name"] == "test_payload_arg_value" boto_session.client.assert_called_once_with( service_name="sagemaker-runtime", @@ -239,6 +243,30 @@ def test_get_config(self, model, endpoint_config): # assert "tools" in payload # assert payload["tools"] == [] + def test_format_request_with_additional_args(self, boto_session, endpoint_config, messages, payload_config): + """Test formatting a request's `additional_args` where provided""" + endpoint_config_ext = { + **endpoint_config, + "additional_args": { + "extra_request_key": "extra_request_value", + }, + } + payload_config_ext = { + **payload_config, + "additional_args": { + "extra_payload_key": "extra_payload_value", + }, + } + model = SageMakerAIModel( + boto_session=boto_session, + endpoint_config=endpoint_config_ext, + payload_config=payload_config_ext, + ) + request = model.format_request(messages) + assert request.get("extra_request_key") == "extra_request_value" + payload = json.loads(request["Body"]) + assert payload.get("extra_payload_key") == "extra_payload_value" + @pytest.mark.asyncio async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages): """Test streaming response with streaming enabled.""" diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index d21aa6e14..4e8a5dd06 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -28,6 +28,9 @@ def test_node_result_initialization_and_properties(agent_result): assert node_result.accumulated_metrics == {"latencyMs": 0.0} assert node_result.execution_count == 0 + default_node = NodeResult(result=agent_result) + assert default_node.status == Status.PENDING + # With custom metrics custom_usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300} custom_metrics = {"latencyMs": 250.0} @@ -95,6 +98,7 @@ def test_multi_agent_result_initialization(agent_result): assert result.accumulated_metrics == {"latencyMs": 0.0} assert result.execution_count == 0 assert result.execution_time == 0 + assert result.status == Status.PENDING # Custom values`` node_result = NodeResult(result=agent_result) @@ -141,6 +145,12 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) @@ -159,17 +169,73 @@ async def invoke_async(self, task, invocation_state, **kwargs): self.invoke_async_called = True 
self.received_task = task self.received_kwargs = kwargs + self.received_invocation_state = invocation_state return MultiAgentResult( status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + agent = TestMultiAgent() # Test with string task - result = agent("test task", param1="value1", param2="value2") + result = agent("test task", param1="value1", param2="value2", invocation_state={"value3": "value4"}) assert agent.invoke_async_called assert agent.received_task == "test task" - assert agent.received_kwargs == {"param1": "value1", "param2": "value2"} + assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"} assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED + + +def test_node_result_to_dict(agent_result): + """Test NodeResult to_dict method.""" + node_result = NodeResult(result=agent_result, execution_time=100, status=Status.COMPLETED) + result_dict = node_result.to_dict() + + assert result_dict["execution_time"] == 100 + assert result_dict["status"] == "completed" + assert result_dict["result"]["type"] == "agent_result" + assert result_dict["result"]["stop_reason"] == agent_result.stop_reason + assert result_dict["result"]["message"] == agent_result.message + + exception_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + result_dict = exception_result.to_dict() + + assert result_dict["result"]["type"] == "exception" + assert result_dict["result"]["message"] == "Test error" + assert result_dict["status"] == "failed" + + +def test_multi_agent_result_to_dict(agent_result): + """Test MultiAgentResult to_dict method.""" + node_result = NodeResult(result=agent_result) + multi_result = MultiAgentResult(status=Status.COMPLETED, results={"test_node": node_result}, execution_time=200) + + result_dict = multi_result.to_dict() + + assert result_dict["status"] == "completed" + assert result_dict["execution_time"] == 200 + assert "test_node" in result_dict["results"] + assert result_dict["results"]["test_node"]["result"]["type"] == "agent_result" + + +def test_serialize_node_result_for_persist(agent_result): + """Test serialize_node_result_for_persist method.""" + + node_result = NodeResult(result=agent_result) + serialized = node_result.to_dict() + + assert "result" in serialized + assert "execution_time" in serialized + assert "status" in serialized + + exception_node_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + serialized_exception = exception_node_result.to_dict() + assert "result" in serialized_exception + assert serialized_exception["result"]["type"] == "exception" + assert serialized_exception["result"]["message"] == "Test error" diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 8097d944e..c4c1a664f 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -310,7 +310,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): result = await graph.invoke_async([{"text": "Original task"}]) # Verify entry node was called with original task - entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}]) + entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}], invocation_state={}) assert result.status == Status.COMPLETED 
mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() @@ -906,7 +906,7 @@ def __init__(self, name): self._session_manager = None self.hooks = HookRegistry() - async def invoke_async(self, input_data): + async def invoke_async(self, input_data, invocation_state=None): # Increment execution count in state count = self.state.get("execution_count") or 0 self.state.set("execution_count", count + 1) @@ -1300,7 +1300,9 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state) + kwargs_agent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing"}], invocation_state=test_invocation_state + ) assert result.status == Status.COMPLETED @@ -1335,5 +1337,7 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = graph("Test kwargs passing sync", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) + kwargs_agent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing sync"}], invocation_state=test_invocation_state + ) assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 7d3e69695..0968fd30c 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -558,7 +558,7 @@ async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await swarm.invoke_async("Test kwargs passing", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs} assert result.status == Status.COMPLETED @@ -572,5 +572,5 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): test_kwargs = {"custom_param": "test_value", "another_param": 42} result = swarm("Test kwargs passing sync", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs} assert result.status == Status.COMPLETED diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 2c25fcc38..923b13daa 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,6 +5,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.interrupt import InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -95,6 +96,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): agent_id="existing-agent", state={"key": "value"}, 
conversation_manager_state=SlidingWindowConversationManager().get_state(), + _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, ) session_manager.session_repository.create_agent("test-session", session_agent) @@ -116,6 +118,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" + assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 12db81908..e87277eed 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -109,6 +109,18 @@ def metrics(request): return Metrics(**params) +@pytest.fixture +def metrics_with_ttfb(request): + params = { + "latencyMs": 1, + "timeToFirstByteMs": 10, + } + if hasattr(request, "param"): + params.update(request.param) + + return Metrics(**params) + + @pytest.mark.parametrize("end_time", [None, 1]) @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") def test_trace_end(mock_time, end_time, trace): @@ -132,8 +144,8 @@ def mock_get_meter_provider(): mock_create_counter = mock.MagicMock() mock_meter.create_counter.return_value = mock_create_counter - mock_create_histogram = mock.MagicMock() - mock_meter.create_histogram.return_value = mock_create_histogram + # Create separate mock objects for each histogram call + mock_meter.create_histogram.side_effect = lambda *args, **kwargs: mock.MagicMock() meter_provider_mock.get_meter.return_value = mock_meter mock_get_meter_provider.return_value = meter_provider_mock @@ -326,9 +338,9 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met metrics_client.event_loop_cache_write_input_tokens.record.assert_called() -def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): +def test_event_loop_metrics_update_metrics(metrics_with_ttfb, event_loop_metrics, mock_get_meter_provider): for _ in range(3): - event_loop_metrics.update_metrics(metrics) + event_loop_metrics.update_metrics(metrics_with_ttfb) tru_metrics = event_loop_metrics.accumulated_metrics exp_metrics = Metrics( @@ -338,6 +350,7 @@ def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get assert tru_metrics == exp_metrics mock_get_meter_provider.return_value.get_meter.assert_called() event_loop_metrics._metrics_client.event_loop_latency.record.assert_called_with(1) + event_loop_metrics._metrics_client.model_time_to_first_token.record.assert_called_with(10) def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_get_meter_provider): diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 4e9872100..05dbe387f 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -11,7 +11,7 @@ from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.content import ContentBlock -from strands.types.streaming import StopReason, Usage +from strands.types.streaming import Metrics, StopReason, Usage @pytest.fixture(autouse=True) @@ -153,7 +153,7 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert 
mock_tracer.start_span.call_args[1]["name"] == "chat" - assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.CLIENT + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) @@ -173,14 +173,22 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer): mock_span = mock.MagicMock() mock_tracer.start_span.return_value = mock_span - messages = [{"role": "user", "content": [{"text": "Hello"}]}] + messages = [ + {"role": "user", "content": [{"text": "Hello 2025-1993"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"input": '"expression": "2025-1993"', "name": "calculator", "toolUseId": "123"}} + ], + }, + ] model_id = "test-model" span = tracer.start_model_invoke_span(messages=messages, agent_name="TestAgent", model_id=model_id) mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" - assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.CLIENT + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) @@ -191,8 +199,19 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer): [ { "role": messages[0]["role"], - "parts": [{"type": "text", "content": messages[0]["content"]}], - } + "parts": [{"type": "text", "content": "Hello 2025-1993"}], + }, + { + "role": messages[1]["role"], + "parts": [ + { + "type": "tool_call", + "name": "calculator", + "id": "123", + "arguments": '"expression": "2025-1993"', + } + ], + }, ] ) }, @@ -205,17 +224,18 @@ def test_end_model_invoke_span(mock_span): tracer = Tracer() message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) stop_reason: StopReason = "end_turn" - tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) + mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -231,17 +251,18 @@ def test_end_model_invoke_span_latest_conventions(mock_span): tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) stop_reason: 
StopReason = "end_turn" - tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) + mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -249,7 +270,7 @@ def test_end_model_invoke_span_latest_conventions(mock_span): [ { "role": "assistant", - "parts": [{"type": "text", "content": message["content"]}], + "parts": [{"type": "text", "content": "Response"}], "finish_reason": "end_turn", } ] @@ -318,7 +339,7 @@ def test_start_tool_call_span_latest_conventions(mock_tracer): "type": "tool_call", "name": tool["name"], "id": tool["toolUseId"], - "arguments": [{"content": tool["input"]}], + "arguments": tool["input"], } ], } @@ -398,7 +419,7 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer) "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": [{"text": "Original Task: foo bar"}]}]}] + [{"role": "user", "parts": [{"type": "text", "content": "Original Task: foo bar"}]}] ) }, ) @@ -486,7 +507,7 @@ def test_end_tool_call_span_latest_conventions(mock_span): """Test ending a tool call span with the latest semantic conventions.""" tracer = Tracer() tracer.use_latest_genai_conventions = True - tool_result = {"status": "success", "content": [{"text": "Tool result"}]} + tool_result = {"status": "success", "content": [{"text": "Tool result"}, {"json": {"foo": "bar"}}]} tracer.end_tool_call_span(mock_span, tool_result) @@ -502,7 +523,7 @@ def test_end_tool_call_span_latest_conventions(mock_span): { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "result": tool_result.get("content"), + "response": tool_result.get("content"), } ], } @@ -558,9 +579,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ - "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": messages[0]["content"]}]}] - ) + "gen_ai.input.messages": serialize([{"role": "user", "parts": [{"type": "text", "content": "Hello"}]}]) }, ) assert span is not None @@ -570,7 +589,12 @@ def test_end_event_loop_cycle_span(mock_span): """Test ending an event loop cycle span.""" tracer = Tracer() message = {"role": "assistant", "content": [{"text": "Response"}]} - tool_result_message = {"role": "assistant", "content": [{"toolResult": {"response": "Success"}}]} + tool_result_message = { + "role": "assistant", + "content": [ + {"toolResult": {"toolUseId": "123", "status": "success", "content": [{"text": "Weather is sunny"}]}} + ], + } tracer.end_event_loop_cycle_span(mock_span, message, tool_result_message) @@ -590,7 +614,12 @@ 
def test_end_event_loop_cycle_span_latest_conventions(mock_span): tracer = Tracer() tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} - tool_result_message = {"role": "assistant", "content": [{"toolResult": {"response": "Success"}}]} + tool_result_message = { + "role": "assistant", + "content": [ + {"toolResult": {"toolUseId": "123", "status": "success", "content": [{"text": "Weather is sunny"}]}} + ], + } tracer.end_event_loop_cycle_span(mock_span, message, tool_result_message) @@ -601,7 +630,13 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span): [ { "role": "assistant", - "parts": [{"type": "text", "content": tool_result_message["content"]}], + "parts": [ + { + "type": "tool_call_response", + "id": "123", + "response": [{"text": "Weather is sunny"}], + } + ], } ] ) @@ -635,6 +670,7 @@ def test_start_agent_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) @@ -676,7 +712,7 @@ def test_start_agent_span_latest_conventions(mock_tracer): "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": [{"text": "test prompt"}]}]}] + [{"role": "user", "parts": [{"type": "text", "content": "test prompt"}]}] ) }, ) @@ -766,8 +802,9 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): cacheWriteInputTokens=3, ) stop_reason: StopReason = "end_turn" + metrics = Metrics(latencyMs=10, timeToFirstByteMs=5) - tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) @@ -776,6 +813,8 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 5) mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 3) + mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 10) + mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 5) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() diff --git a/tests/strands/test_async.py b/tests/strands/test_async.py new file mode 100644 index 000000000..2a98a953c --- /dev/null +++ b/tests/strands/test_async.py @@ -0,0 +1,25 @@ +"""Tests for _async module.""" + +import pytest + +from strands._async import run_async + + +def test_run_async_with_return_value(): + """Test run_async returns correct value.""" + + async def async_with_value(): + return 42 + + result = run_async(async_with_value) + assert result == 42 + + +def test_run_async_exception_propagation(): + """Test that exceptions are properly propagated.""" + + async def async_with_exception(): + raise ValueError("test exception") + + with pytest.raises(ValueError, match="test exception"): + run_async(async_with_exception) diff --git 
a/tests/strands/test_exception_notes.py b/tests/strands/test_exception_notes.py new file mode 100644 index 000000000..936cf0848 --- /dev/null +++ b/tests/strands/test_exception_notes.py @@ -0,0 +1,51 @@ +"""Tests for exception note utilities.""" + +import sys +import traceback +import unittest.mock + +import pytest + +from strands import _exception_notes +from strands._exception_notes import add_exception_note + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +def test_add_exception_note_python_311_plus(): + """Test add_exception_note uses add_note in Python 3.11+.""" + exception = ValueError("original message") + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: original message\n", "test note\n"] + + +def test_add_exception_note_python_310(): + """Test add_exception_note modifies args in Python 3.10.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + exception = ValueError("original message") + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: original message\ntest note\n"] + + +def test_add_exception_note_python_310_no_args(): + """Test add_exception_note handles exception with no args in Python 3.10.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + exception = ValueError() + exception.args = () + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: test note\n"] + + +def test_add_exception_note_python_310_multiple_args(): + """Test add_exception_note preserves additional args in Python 3.10.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + exception = ValueError("original message", "second arg") + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: ('original message\\ntest note', 'second arg')\n"] diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py new file mode 100644 index 000000000..8ce972103 --- /dev/null +++ b/tests/strands/test_interrupt.py @@ -0,0 +1,24 @@ +import pytest + +from strands.interrupt import Interrupt + + +@pytest.fixture +def interrupt(): + return Interrupt( + id="test_id:test_name", + name="test_name", + reason={"reason": "test"}, + response={"response": "test"}, + ) + + +def test_interrupt_to_dict(interrupt): + tru_dict = interrupt.to_dict() + exp_dict = { + "id": "test_id:test_name", + "name": "test_name", + "reason": {"reason": "test"}, + "response": {"response": "test"}, + } + assert tru_dict == exp_dict diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index be90226f6..d25cf14bd 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,8 +4,10 @@ import pytest import strands +from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.tools.registry import ToolRegistry +from strands.types.tools import ToolContext @pytest.fixture @@ -78,12 +80,22 @@ def func(): @pytest.fixture -def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool): +def interrupt_tool(): + @strands.tool(name="interrupt_tool", context=True) + def func(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_name", reason="test reason") 
+ + return func + + +@pytest.fixture +def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool): registry = ToolRegistry() registry.register_tool(weather_tool) registry.register_tool(temperature_tool) registry.register_tool(exception_tool) registry.register_tool(thread_tool) + registry.register_tool(interrupt_tool) return registry @@ -92,6 +104,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry + mock_agent._interrupt_state = InterruptState() return mock_agent @@ -111,5 +124,5 @@ def cycle_span(): @pytest.fixture -def invocation_state(): - return {} +def invocation_state(agent): + return {"agent": agent} diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index f7fc64b25..ce07ee4ce 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,8 +1,10 @@ import pytest +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolUse +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.types._events import ToolInterruptEvent, ToolResultEvent @pytest.fixture @@ -10,15 +12,22 @@ def executor(): return ConcurrentToolExecutor() +@pytest.fixture +def structured_output_context(): + return StructuredOutputContext(structured_output_model=None) + + @pytest.mark.asyncio async def test_concurrent_executor_execute( - executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist ): - tool_uses: list[ToolUse] = [ + tool_uses = [ {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] - stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ @@ -30,3 +39,40 @@ async def test_concurrent_executor_execute( tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_concurrent_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist +): + interrupt = Interrupt( + id="v1:before_tool_call:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + if event.tool_use["name"] == "weather_tool": + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "test_tool_id_1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, + ] + + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) + + tru_events = sorted(await alist(stream), key=lambda 
event: event.tool_use_id) + exp_events = [ + ToolInterruptEvent(tool_uses[0], [interrupt]), + ToolResultEvent({"toolUseId": "test_tool_id_2", "status": "success", "content": [{"text": "75F"}]}), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[1].tool_result] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 2a0a44e10..a11e2eab2 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -5,9 +5,10 @@ import strands from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor -from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import ToolUse @@ -36,6 +37,7 @@ async def test_executor_stream_yields_result( executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist ): tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) @@ -250,3 +252,210 @@ def cancel_callback(event): tru_results = tool_results exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_sets_span_attributes( + executor, agent, tool_results, invocation_state, weather_tool, alist +): + """Test that span attributes are set correctly when tool_spec is available.""" + with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api: + mock_span = unittest.mock.MagicMock() + mock_trace_api.get_current_span.return_value = mock_span + + # Mock tool_spec with inputSchema containing json field + with unittest.mock.patch.object( + type(weather_tool), "tool_spec", new_callable=unittest.mock.PropertyMock + ) as mock_tool_spec: + mock_tool_spec.return_value = { + "name": "weather_tool", + "description": "Get weather information", + "inputSchema": {"json": {"type": "object", "properties": {}}, "type": "object"}, + } + + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + await alist(stream) + + # Verify set_attribute was called with correct values + calls = mock_span.set_attribute.call_args_list + assert len(calls) == 2 + + # Check description attribute + assert calls[0][0][0] == "gen_ai.tool.description" + assert calls[0][0][1] == "Get weather information" + + # Check json_schema attribute + assert calls[1][0][0] == "gen_ai.tool.json_schema" + # The serialize function should have been called on the json field + + +@pytest.mark.asyncio +async def test_executor_stream_handles_missing_json_in_input_schema( + executor, agent, tool_results, invocation_state, weather_tool, alist +): + """Test that span attributes handle inputSchema without json field gracefully.""" + with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api: + mock_span = unittest.mock.MagicMock() + mock_trace_api.get_current_span.return_value = mock_span + + # Mock tool_spec with inputSchema but no json field + with unittest.mock.patch.object( + type(weather_tool), "tool_spec", 
new_callable=unittest.mock.PropertyMock + ) as mock_tool_spec: + mock_tool_spec.return_value = { + "name": "weather_tool", + "description": "Get weather information", + "inputSchema": {"type": "object", "properties": {}}, + } + + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + # Should not raise an error - json_schema attribute just won't be set + await alist(stream) + + # Verify only description attribute was set (not json_schema) + calls = mock_span.set_attribute.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == "gen_ai.tool.description" + + +@pytest.mark.asyncio +async def test_executor_stream_no_span_attributes_when_no_tool_spec( + executor, agent, tool_results, invocation_state, alist +): + """Test that no span attributes are set when tool_spec is None.""" + with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api: + mock_span = unittest.mock.MagicMock() + mock_trace_api.get_current_span.return_value = mock_span + + # Use unknown tool which will have no tool_spec + tool_use: ToolUse = {"name": "unknown_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + await alist(stream) + + # Verify set_attribute was not called since tool_spec is None + mock_span.set_attribute.assert_not_called() + + +@pytest.mark.asyncio +async def test_executor_stream_hook_interrupt(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:before_tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + event.interrupt("test_name", reason="test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_hook_interrupt_resume(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:before_tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", reason="test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "sunny"}], + }, + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt(executor, agent, 
tool_results, invocation_state, alist): + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "test response"}], + }, + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index 37e098142..10e3ad484 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,20 @@ import pytest +from pydantic import BaseModel +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt +from strands.tools.decorator import tool from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.types._events import ToolInterruptEvent, ToolResultEvent +from strands.types.tools import ToolUse + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str + age: int @pytest.fixture @@ -9,6 +22,34 @@ def executor(): return SequentialToolExecutor() +@pytest.fixture +def structured_output_context(): + """Create a structured output context with SampleModel.""" + return StructuredOutputContext(structured_output_model=SampleModel) + + +@pytest.fixture +def capture_tool(): + """Create a tool that captures kwargs passed to it.""" + captured_kwargs = {} + + @tool(name="capture_tool") + def func(): + return "captured" + + # Override the stream method to capture kwargs + original_stream = func.stream + + async def capturing_stream(tool_use, invocation_state, **kwargs): + captured_kwargs.update(kwargs) + async for event in original_stream(tool_use, invocation_state, **kwargs): + yield event + + func.stream = capturing_stream + func.captured_kwargs = captured_kwargs + return func + + @pytest.mark.asyncio async def test_sequential_executor_execute( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist @@ -17,7 +58,10 @@ async def test_sequential_executor_execute( {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] - stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) 
+ structured_output_context = StructuredOutputContext(None) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) tru_events = await alist(stream) exp_events = [ @@ -29,3 +73,75 @@ async def test_sequential_executor_execute( tru_results = tool_results exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_sequential_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + interrupt = Interrupt( + id="v1:before_tool_call:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "test_tool_id_1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, + ] + + structured_output_context = StructuredOutputContext(None) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_uses[0], [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_sequential_executor_passes_structured_output_context( + executor, + agent, + tool_results, + cycle_trace, + cycle_span, + invocation_state, + structured_output_context, + capture_tool, + alist, +): + """Test that sequential executor properly passes structured output context to tools.""" + # Register the capture tool + agent.tool_registry.register_tool(capture_tool) + + # Set up tool uses + tool_uses: list[ToolUse] = [ + {"name": "capture_tool", "toolUseId": "1", "input": {}}, + ] + + # Execute tools with structured output context + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) + + # Collect events + events = await alist(stream) + + # Verify the structured_output_context was passed to the tool + assert "structured_output_context" in capture_tool.captured_kwargs + assert capture_tool.captured_kwargs["structured_output_context"] is structured_output_context + + # Verify event was generated + assert len(events) == 1 + assert events[0].tool_use_id == "1" diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 67d8fe558..130a4703e 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,3 +1,4 @@ +import base64 import time from unittest.mock import AsyncMock, MagicMock, patch @@ -541,3 +542,149 @@ def slow_transport(): assert client._background_thread_session is None assert client._background_thread_event_loop is None assert not client._init_future.done() # New future created + + +def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): + """EmbeddedResource.resource (uri + text) should map to plain text content.""" + embedded_resource = { + "type": "resource", # required literal + "resource": { + "uri": "mcp://resource/embedded-text-1", + "text": "inner text", + "mimeType": "text/plain", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, 
content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "inner text" + + +def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock_session): + """EmbeddedResource.resource (uri + blob with textual MIME) should decode to text.""" + + payload = base64.b64encode(b'{"k":"v"}').decode() + + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-blob-1", + # NOTE: blob is a STRING, mimeType is sibling + "blob": payload, + "mimeType": "application/json", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == '{"k":"v"}' + + +def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): + """EmbeddedResource.resource (blob with image MIME) should map to image content.""" + # Read yellow.png file + with open("tests_integ/yellow.png", "rb") as image_file: + png_data = image_file.read() + payload = base64.b64encode(png_data).decode() + + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-image", + "blob": payload, + "mimeType": "image/png", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert "image" in result["content"][0] + assert result["content"][0]["image"]["format"] == "png" + assert "bytes" in result["content"][0]["image"]["source"] + + +def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_session): + """EmbeddedResource.resource (blob with non-textual/unknown MIME) should be dropped.""" + payload = base64.b64encode(b"\x00\x01\x02\x03").decode() + + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-binary", + "blob": payload, + "mimeType": "application/octet-stream", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 0 # Content should be dropped + + +def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_session): + """EmbeddedResource with different textual MIME types should decode to text.""" + + # Test YAML content + yaml_content = base64.b64encode(b"key: value\nlist:\n - item1\n - 
item2").decode() + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-yaml", + "blob": yaml_content, + "mimeType": "application/yaml", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert "key: value" in result["content"][0]["text"] + + +def test_call_tool_sync_embedded_unknown_resource_type_dropped(mock_transport, mock_session): + """EmbeddedResource with unknown resource type should be dropped for forward compatibility.""" + + # Mock an unknown resource type that's neither TextResourceContents nor BlobResourceContents + class UnknownResourceContents: + def __init__(self): + self.uri = "mcp://resource/unknown-type" + self.mimeType = "application/unknown" + self.data = "some unknown data" + + # Create a mock embedded resource with unknown resource type + mock_embedded_resource = MagicMock() + mock_embedded_resource.resource = UnknownResourceContents() + + mock_session.call_tool.return_value = MagicMock( + isError=False, content=[mock_embedded_resource], structuredContent=None + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 0 # Unknown resource type should be dropped diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py new file mode 100644 index 000000000..9cb90167d --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -0,0 +1,826 @@ +"""Unit tests for MCPClient ToolProvider functionality.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest +from mcp.types import Tool as MCPTool + +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.mcp.mcp_client import ToolFilters +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_transport(): + """Create a mock transport callable.""" + + def transport(): + read_stream = MagicMock() + write_stream = MagicMock() + return read_stream, write_stream + + return transport + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + return agent_tool + + +def create_mock_tool(tool_name: str, mcp_tool_name: str | None = None) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = tool_name + tool.tool_spec = { + "name": tool_name, + "description": f"Description for {tool_name}", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + tool.mcp_tool = MagicMock(spec=MCPTool) + tool.mcp_tool.name = mcp_tool_name or tool_name + 
tool.mcp_tool.description = f"Description for {tool_name}" + return tool + + +def test_init_with_tool_filters_and_prefix(mock_transport): + """Test initialization with tool filters and prefix.""" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" + + client = MCPClient(mock_transport, tool_filters=filters, prefix=prefix) + + assert client._tool_filters == filters + assert client._prefix == prefix + assert client._loaded_tools is None + assert client._tool_provider_started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_transport, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_called_once() + assert client._tool_provider_started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_transport, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_transport): + """Test that load_tools raises ToolProviderException when client start fails.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start: + mock_start.side_effect = Exception("Client start failed") + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await client.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_transport, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + # First call + tools1 = await client.load_tools() + # Second call + tools2 = await client.load_tools() + + # Client should only be called once + mock_list_tools.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_transport): + """Test that load_tools handles pagination correctly.""" + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_list_tools.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + tools = await client.load_tools() + + # Should have called list_tools_sync twice + assert mock_list_tools.call_count == 2 + # First call with no token, second call with "page2" token + mock_list_tools.assert_any_call(None, prefix=None, 
tool_filters=None) + mock_list_tools.assert_any_call("page2", prefix=None, tool_filters=None) + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_transport): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results (simulating the filtering) + mock_list_tools.return_value = PaginatedList([tool1]) # Only allowed tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_transport): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only echo tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_transport): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only short tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter_string_match(mock_transport): + """Test rejected filter with string matching.""" + tool1 = create_mock_tool("good_tool") + + filters: ToolFilters = {"rejected": ["bad_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only good tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_prefix_renames_tools(mock_transport): + """Test that prefix properly renames tools.""" + # Create a mock MCP tool (not MCPAgentTool) + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_name" + + client = MCPClient(mock_transport, prefix="prefix") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + 
mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call list_tools_sync directly to test prefix functionality + result = client.list_tools_sync(prefix="prefix") + + # Should create MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="prefix_original_name") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_add_consumer(mock_transport): + """Test adding a provider consumer.""" + client = MCPClient(mock_transport) + + client.add_consumer("consumer1") + + assert "consumer1" in client._consumers + assert len(client._consumers) == 1 + + +def test_remove_consumer_without_cleanup(mock_transport): + """Test removing a provider consumer without triggering cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._consumers.add("consumer2") + client._tool_provider_started = True + + client.remove_consumer("consumer1") + + assert "consumer1" not in client._consumers + assert "consumer2" in client._consumers + assert client._tool_provider_started is True # Should not cleanup yet + + +def test_remove_consumer_with_cleanup(mock_transport): + """Test removing the last provider consumer triggers cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + client._loaded_tools = [MagicMock()] + + with patch.object(client, "stop") as mock_stop: + client.remove_consumer("consumer1") + + assert len(client._consumers) == 0 + assert client._tool_provider_started is False + assert client._loaded_tools is None + mock_stop.assert_called_once_with(None, None, None) + + +def test_remove_consumer_cleanup_failure(mock_transport): + """Test that remove_consumer raises ToolProviderException when cleanup fails.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + + with patch.object(client, "stop") as mock_stop: + mock_stop.side_effect = Exception("Cleanup failed") + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): + client.remove_consumer("consumer1") + + +def test_mcp_client_reuse_across_multiple_agents(mock_transport): + """Test that a single MCPClient can be used across multiple agents.""" + from strands import Agent + + tool1 = create_mock_tool(tool_name="shared_echo", mcp_tool_name="echo") + client = MCPClient(mock_transport, tool_filters={"allowed": ["echo"]}, prefix="shared") + + with ( + patch.object(client, "list_tools_sync") as mock_list_tools, + patch.object(client, "start") as mock_start, + patch.object(client, "stop") as mock_stop, + ): + mock_list_tools.return_value = PaginatedList([tool1]) + + # Create two agents with the same client + agent_1 = Agent(tools=[client]) + agent_2 = Agent(tools=[client]) + + # Both agents should have the same tool + assert "shared_echo" in agent_1.tool_names + assert "shared_echo" in agent_2.tool_names + assert agent_1.tool_names == agent_2.tool_names + + # Client should only be started once + mock_start.assert_called_once() + + # First agent cleanup - client should remain active + agent_1.cleanup() + 
mock_stop.assert_not_called() # Should not stop yet + + # Second agent should still work + assert "shared_echo" in agent_2.tool_names + + # Final cleanup when last agent is removed + agent_2.cleanup() + mock_stop.assert_called_once() # Now it should stop + + +def test_list_tools_sync_prefix_override_constructor_default(mock_transport): + """Test that list_tools_sync can override constructor prefix.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "override_original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with override prefix + result = client.list_tools_sync(prefix="override") + + # Should use override prefix, not constructor prefix + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="override_original_tool") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_prefix_override_with_empty_string(mock_transport): + """Test that list_tools_sync can override constructor prefix with empty string.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with empty string prefix (should override constructor default) + result = client.list_tools_sync(prefix="") + + # Should use no prefix (empty string overrides constructor) + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client) + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_prefix_uses_constructor_default_when_none(mock_transport): + """Test that list_tools_sync uses constructor prefix when None is passed.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with 
constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "constructor_original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with None prefix (should use constructor default) + result = client.list_tools_sync(prefix=None) + + # Should use constructor prefix + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="constructor_original_tool") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_tool_filters_override_constructor_default(mock_transport): + """Test that list_tools_sync can override constructor tool_filters.""" + # Create mock tools + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + # Client with constructor filters that would allow both + constructor_filters: ToolFilters = {"allowed": ["allowed_tool", "rejected_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="allowed_tool"), MagicMock(name="rejected_tool")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Override filters to only allow one tool + override_filters: ToolFilters = {"allowed": ["allowed_tool"]} + result = client.list_tools_sync(tool_filters=override_filters) + + # Should only include the allowed tool based on override filters + assert len(result) == 1 + assert result[0] is tool1 + + +def test_list_tools_sync_tool_filters_override_with_empty_dict(mock_transport): + """Test that list_tools_sync can override constructor filters with empty dict.""" + # Create mock tools + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + # Client with constructor filters that would reject tools + constructor_filters: ToolFilters = {"rejected": ["tool1", "tool2"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, 
"_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="tool1"), MagicMock(name="tool2")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Override with empty filters (should allow all tools) + result = client.list_tools_sync(tool_filters={}) + + # Should include both tools since empty filters allow everything + assert len(result) == 2 + assert result[0] is tool1 + assert result[1] is tool2 + + +def test_list_tools_sync_tool_filters_uses_constructor_default_when_none(mock_transport): + """Test that list_tools_sync uses constructor filters when None is passed.""" + # Create mock tools + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + # Client with constructor filters + constructor_filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="allowed_tool"), MagicMock(name="rejected_tool")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Call with None filters (should use constructor default) + result = client.list_tools_sync(tool_filters=None) + + # Should only include allowed tool based on constructor filters + assert len(result) == 1 + assert result[0] is tool1 + + +def test_list_tools_sync_combined_prefix_and_filter_overrides(mock_transport): + """Test that list_tools_sync can override both prefix and filters simultaneously.""" + # Client with constructor defaults + constructor_filters: ToolFilters = {"allowed": ["echo_tool", "other_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters, prefix="constructor") + + # Create mock tools + mock_echo_tool = MagicMock() + mock_echo_tool.name = "echo_tool" + mock_other_tool = MagicMock() + mock_other_tool.name = "other_tool" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_echo_tool, mock_other_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_echo_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = 
mock_other_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Override both prefix and filters + override_filters: ToolFilters = {"allowed": ["echo_tool"]} + result = client.list_tools_sync(prefix="override", tool_filters=override_filters) + + # Verify prefix override: should use "override" not "constructor" + calls = mock_agent_tool_class.call_args_list + assert len(calls) == 2 + + # First tool should have override prefix + args1, kwargs1 = calls[0] + assert args1 == (mock_echo_tool, client) + assert kwargs1 == {"name_override": "override_echo_tool"} + + # Second tool should have override prefix + args2, kwargs2 = calls[1] + assert args2 == (mock_other_tool, client) + assert kwargs2 == {"name_override": "override_other_tool"} + + # Verify filter override: should only include echo_tool based on override filters + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_direct_usage_without_constructor_defaults(mock_transport): + """Test direct usage of list_tools_sync without constructor defaults.""" + # Client without constructor defaults + client = MCPClient(mock_transport) + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_tool1, mock_tool2] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_tool1 + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_tool2 + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Direct usage with explicit parameters + filters: ToolFilters = {"allowed": ["tool1"]} + result = client.list_tools_sync(prefix="direct", tool_filters=filters) + + # Verify prefix is applied + calls = mock_agent_tool_class.call_args_list + assert len(calls) == 2 + + # Should create tools with direct prefix + args1, kwargs1 = calls[0] + assert args1 == (mock_tool1, client) + assert kwargs1 == {"name_override": "direct_tool1"} + + args2, kwargs2 = calls[1] + assert args2 == (mock_tool2, client) + assert kwargs2 == {"name_override": "direct_tool2"} + + # Verify filtering: should only include tool1 + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_regex_filter_override(mock_transport): + """Test list_tools_sync with regex filter override.""" + # Client without constructor filters + client = MCPClient(mock_transport) + + # Create mock tools + mock_echo_tool = MagicMock() + mock_echo_tool.name = "echo_command" + mock_list_tool = MagicMock() + mock_list_tool.name = "list_files" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_echo_tool, mock_list_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = 
mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_echo_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_list_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Use regex filter to match only echo tools + regex_filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + result = client.list_tools_sync(tool_filters=regex_filters) + + # Should create both tools + assert mock_agent_tool_class.call_count == 2 + + # Should only include echo tool (regex matches "echo_command") + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_callable_filter_override(mock_transport): + """Test list_tools_sync with callable filter override.""" + # Client without constructor filters + client = MCPClient(mock_transport) + + # Create mock tools + mock_short_tool = MagicMock() + mock_short_tool.name = "short" + mock_long_tool = MagicMock() + mock_long_tool.name = "very_long_tool_name" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_short_tool, mock_long_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_short_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_long_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Use callable filter for short names only + def short_names_only(tool) -> bool: + return len(tool.mcp_tool.name) <= 10 + + callable_filters: ToolFilters = {"allowed": [short_names_only]} + result = client.list_tools_sync(tool_filters=callable_filters) + + # Should create both tools + assert mock_agent_tool_class.call_count == 2 + + # Should only include short tool (name length <= 10) + assert len(result) == 1 + assert result[0] is mock_agent_tool1 diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 2c730624e..85d533403 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -340,6 +340,21 @@ def __getattr__(self, name): class TestMCPInstrumentation: + def test_mcp_instrumentation_called_on_client_init(self): + """Test that mcp_instrumentation is called when MCPClient is initialized.""" + with patch("strands.tools.mcp.mcp_client.mcp_instrumentation") as mock_instrumentation: + # Mock transport + def mock_transport(): + read_stream = AsyncMock() + write_stream = AsyncMock() + return read_stream, write_stream + + # Create MCPClient instance - should call mcp_instrumentation + MCPClient(mock_transport) + + # Verify mcp_instrumentation was called + mock_instrumentation.assert_called_once() + def test_mcp_instrumentation_idempotent_with_multiple_clients(self): """Test that mcp_instrumentation is only called once even with multiple MCPClient instances.""" diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py new file mode 100644 index 000000000..a7eb27ca5 --- /dev/null +++ 
b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -0,0 +1,245 @@ +"""Tests for StructuredOutputContext class.""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output.structured_output_tool import StructuredOutputTool + + +class SampleModel(BaseModel): + """Test Pydantic model for testing.""" + + name: str = Field(..., description="Name field") + age: int = Field(..., description="Age field", ge=0) + email: Optional[str] = Field(None, description="Optional email field") + + +class AnotherSampleModel(BaseModel): + """Another test Pydantic model.""" + + value: str + count: int + + +class TestStructuredOutputContext: + """Test suite for StructuredOutputContext.""" + + def test_initialization_with_structured_output_model(self): + """Test initialization with a structured output model.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + assert context.structured_output_model == SampleModel + assert isinstance(context.structured_output_tool, StructuredOutputTool) + assert context.expected_tool_name == "SampleModel" + assert context.results == {} + assert context.forced_mode is False + assert context.tool_choice is None + assert context.stop_loop is False + + def test_initialization_without_structured_output_model(self): + """Test initialization without a structured output model.""" + context = StructuredOutputContext(structured_output_model=None) + + assert context.structured_output_model is None + assert context.structured_output_tool is None + assert context.expected_tool_name is None + assert context.results == {} + assert context.forced_mode is False + assert context.tool_choice is None + assert context.stop_loop is False + + def test_is_enabled_property(self): + """Test the is_enabled property.""" + # Test with model + context_with_model = StructuredOutputContext(structured_output_model=SampleModel) + assert context_with_model.is_enabled is True + + # Test without model + context_without_model = StructuredOutputContext(structured_output_model=None) + assert context_without_model.is_enabled is False + + def test_store_result_and_get_result(self): + """Test storing and retrieving results.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Create test result + test_result = SampleModel(name="John Doe", age=30, email="john@example.com") + tool_use_id = "test_tool_use_123" + + # Store result + context.store_result(tool_use_id, test_result) + assert tool_use_id in context.results + assert context.results[tool_use_id] == test_result + + # Retrieve result + retrieved_result = context.get_result(tool_use_id) + assert retrieved_result == test_result + + # Test retrieving non-existent result + non_existent = context.get_result("non_existent_id") + assert non_existent is None + + def test_set_forced_mode_with_tool_choice(self): + """Test set_forced_mode with custom tool_choice.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + custom_tool_choice = {"specific": {"tool": "SampleModel"}} + context.set_forced_mode(tool_choice=custom_tool_choice) + + assert context.forced_mode is True + assert context.tool_choice == custom_tool_choice + + def test_set_forced_mode_without_tool_choice(self): + """Test set_forced_mode without tool_choice (default).""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + 
context.set_forced_mode() + + assert context.forced_mode is True + assert context.tool_choice == {"any": {}} + + def test_set_forced_mode_when_disabled(self): + """Test set_forced_mode when context is not enabled.""" + context = StructuredOutputContext(structured_output_model=None) + + # Should not change state when not enabled + context.set_forced_mode(tool_choice={"test": "value"}) + + assert context.forced_mode is False + assert context.tool_choice is None + + def test_has_structured_output_tool(self): + """Test has_structured_output_tool method.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Create tool uses with the expected tool + tool_uses_with_output = [ + {"name": "SampleModel", "toolUseId": "123", "input": {}}, + {"name": "OtherTool", "toolUseId": "456", "input": {}}, + ] + + # Should find the structured output tool + assert context.has_structured_output_tool(tool_uses_with_output) is True + + # Create tool uses without the expected tool + tool_uses_without_output = [ + {"name": "OtherTool", "toolUseId": "456", "input": {}}, + {"name": "AnotherTool", "toolUseId": "789", "input": {}}, + ] + + # Should not find the structured output tool + assert context.has_structured_output_tool(tool_uses_without_output) is False + + # Test with empty list + assert context.has_structured_output_tool([]) is False + + def test_has_structured_output_tool_when_disabled(self): + """Test has_structured_output_tool when no expected tool name.""" + context = StructuredOutputContext(structured_output_model=None) + + tool_uses = [ + {"name": "SampleModel", "toolUseId": "123", "input": {}}, + ] + + # Should return False when no expected tool name + assert context.has_structured_output_tool(tool_uses) is False + + def test_get_tool_spec(self): + """Test get_tool_spec method.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + tool_spec = context.get_tool_spec() + assert tool_spec is not None + assert isinstance(tool_spec, dict) + assert "name" in tool_spec + assert tool_spec["name"] == "SampleModel" + assert "description" in tool_spec + assert "inputSchema" in tool_spec + + def test_get_tool_spec_when_disabled(self): + """Test get_tool_spec when no structured output tool.""" + context = StructuredOutputContext(structured_output_model=None) + + tool_spec = context.get_tool_spec() + assert tool_spec is None + + def test_extract_result(self): + """Test extract_result method.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Store some results + result1 = SampleModel(name="Alice", age=25) + result2 = SampleModel(name="Bob", age=30) + context.store_result("tool_use_1", result1) + context.store_result("tool_use_2", result2) + + # Create tool uses with matching tool + tool_uses = [ + {"name": "SampleModel", "toolUseId": "tool_use_1", "input": {}}, + {"name": "OtherTool", "toolUseId": "tool_use_3", "input": {}}, + ] + + # Extract result should return and remove the first matching result + extracted = context.extract_result(tool_uses) + assert extracted == result1 + assert "tool_use_1" not in context.results + assert "tool_use_2" in context.results # Other result should remain + + def test_extract_result_no_matching_tool(self): + """Test extract_result when no matching tool in tool_uses.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + result = SampleModel(name="Alice", age=25) + context.store_result("tool_use_1", result) + + # Tool uses without the expected tool name + tool_uses = [ + 
{"name": "OtherTool", "toolUseId": "tool_use_1", "input": {}}, + ] + + # Should return None + extracted = context.extract_result(tool_uses) + assert extracted is None + assert "tool_use_1" in context.results # Result should remain + + def test_extract_result_no_stored_result(self): + """Test extract_result when no stored result for tool use.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Tool uses with matching tool but no stored result + tool_uses = [ + {"name": "SampleModel", "toolUseId": "tool_use_1", "input": {}}, + ] + + # Should return None + extracted = context.extract_result(tool_uses) + assert extracted is None + + def test_extract_result_multiple_matching_tools(self): + """Test extract_result with multiple matching tool uses.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Store multiple results + result1 = SampleModel(name="Alice", age=25) + result2 = SampleModel(name="Bob", age=30) + context.store_result("tool_use_1", result1) + context.store_result("tool_use_2", result2) + + # Multiple matching tool uses + tool_uses = [ + {"name": "SampleModel", "toolUseId": "tool_use_1", "input": {}}, + {"name": "SampleModel", "toolUseId": "tool_use_2", "input": {}}, + ] + + # Should extract the first matching result + extracted = context.extract_result(tool_uses) + assert extracted == result1 + assert "tool_use_1" not in context.results + assert "tool_use_2" in context.results + + # Extract again for the second result + extracted2 = context.extract_result(tool_uses) + assert extracted2 == result2 + assert "tool_use_2" not in context.results diff --git a/tests/strands/tools/structured_output/test_structured_output_tool.py b/tests/strands/tools/structured_output/test_structured_output_tool.py new file mode 100644 index 000000000..66f1d465d --- /dev/null +++ b/tests/strands/tools/structured_output/test_structured_output_tool.py @@ -0,0 +1,307 @@ +"""Tests for StructuredOutputTool class.""" + +from typing import List, Optional +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output.structured_output_tool import _TOOL_SPEC_CACHE, StructuredOutputTool +from strands.types._events import ToolResultEvent + + +class SimpleModel(BaseModel): + """Simple test model.""" + + name: str = Field(..., description="Name field") + value: int = Field(..., description="Value field") + + +class ComplexModel(BaseModel): + """Complex test model with nested structures.""" + + title: str = Field(..., description="Title field") + count: int = Field(..., ge=0, le=100, description="Count between 0 and 100") + tags: List[str] = Field(default_factory=list, description="List of tags") + metadata: Optional[dict] = Field(None, description="Optional metadata") + + +class ValidationTestModel(BaseModel): + """Model for testing validation.""" + + email: str = Field(..., pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$", description="Email address") + age: int = Field(..., ge=0, le=150, description="Age between 0 and 150") + status: str = Field(..., pattern="^(active|inactive|pending)$", description="Status") + + +class TestStructuredOutputTool: + """Test suite for StructuredOutputTool.""" + + def test_tool_initialization_with_simple_model(self): + """Test tool initialization with a simple Pydantic model.""" + tool = StructuredOutputTool(SimpleModel) + + assert tool.structured_output_model == SimpleModel + 
assert tool.tool_name == "SimpleModel" + assert tool.tool_type == "structured_output" + assert isinstance(tool.tool_spec, dict) + assert tool.tool_spec["name"] == "SimpleModel" + + def test_tool_initialization_with_complex_model(self): + """Test tool initialization with a complex Pydantic model.""" + tool = StructuredOutputTool(ComplexModel) + + assert tool.structured_output_model == ComplexModel + assert tool.tool_name == "ComplexModel" + assert tool.tool_type == "structured_output" + assert isinstance(tool.tool_spec, dict) + assert tool.tool_spec["name"] == "ComplexModel" + + def test_get_tool_spec_caching_mechanism(self): + """Test that tool specs are cached properly.""" + # Clear cache first + _TOOL_SPEC_CACHE.clear() + + # First call should create and cache the spec + tool1 = StructuredOutputTool(SimpleModel) + spec1 = tool1.tool_spec + + # Cache should now contain the spec + assert SimpleModel in _TOOL_SPEC_CACHE + + # Second call with same model should use cached version + tool2 = StructuredOutputTool(SimpleModel) + spec2 = tool2.tool_spec + + # Specs should be equal but not the same object (deepcopy is used) + assert spec1 == spec2 + assert spec1 is not spec2 + + # Cache should still have only one entry for SimpleModel + assert len([k for k in _TOOL_SPEC_CACHE if k == SimpleModel]) == 1 + + def test_tool_name_property(self): + """Test the tool_name property.""" + tool = StructuredOutputTool(SimpleModel) + assert tool.tool_name == "SimpleModel" + + tool2 = StructuredOutputTool(ComplexModel) + assert tool2.tool_name == "ComplexModel" + + def test_tool_spec_property(self): + """Test the tool_spec property.""" + tool = StructuredOutputTool(SimpleModel) + spec = tool.tool_spec + + assert isinstance(spec, dict) + assert "name" in spec + assert "description" in spec + assert "inputSchema" in spec + assert spec["name"] == "SimpleModel" + + # Check that description includes the important message + assert "IMPORTANT: This StructuredOutputTool should only be invoked" in spec["description"] + + def test_tool_type_property(self): + """Test that tool_type property returns 'structured_output'.""" + tool = StructuredOutputTool(SimpleModel) + assert tool.tool_type == "structured_output" + + def test_structured_output_model_property(self): + """Test the structured_output_model property.""" + tool = StructuredOutputTool(SimpleModel) + assert tool.structured_output_model == SimpleModel + + tool2 = StructuredOutputTool(ComplexModel) + assert tool2.structured_output_model == ComplexModel + + @pytest.mark.asyncio + async def test_stream_with_valid_input(self): + """Test stream method with valid input.""" + tool = StructuredOutputTool(SimpleModel) + context = StructuredOutputContext(structured_output_model=SimpleModel) + + tool_use = {"name": "SimpleModel", "toolUseId": "test_123", "input": {"name": "Test Name", "value": 42}} + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Should have one ToolResultEvent + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + + # Check the result + result = events[0].tool_result + assert result["toolUseId"] == "test_123" + assert result["status"] == "success" + assert "Successfully validated SimpleModel" in result["content"][0]["text"] + + # Check that result was stored in context + stored_result = context.get_result("test_123") + assert stored_result is not None + assert stored_result.name == "Test Name" + assert stored_result.value == 42 + + 
@pytest.mark.asyncio + async def test_stream_with_missing_fields(self): + """Test stream method with missing required fields.""" + tool = StructuredOutputTool(SimpleModel) + context = StructuredOutputContext(structured_output_model=SimpleModel) + + tool_use = { + "name": "SimpleModel", + "toolUseId": "test_789", + "input": { + "name": "Test Name" + # Missing required 'value' field + }, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Should have one ToolResultEvent with error + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + + # Check the error result + result = events[0].tool_result + assert result["toolUseId"] == "test_789" + assert result["status"] == "error" + + error_text = result["content"][0]["text"] + assert "Validation failed for SimpleModel" in error_text + assert "Field 'value'" in error_text or "field required" in error_text.lower() + + @pytest.mark.asyncio + async def test_stream_with_unexpected_exception(self): + """Test stream method with unexpected exceptions.""" + tool = StructuredOutputTool(SimpleModel) + context = MagicMock() + + # Mock the context to raise an unexpected exception + context.store_result.side_effect = RuntimeError("Unexpected error") + + tool_use = {"name": "SimpleModel", "toolUseId": "test_error", "input": {"name": "Test", "value": 1}} + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Should have one ToolResultEvent with error + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + + # Check the error result + result = events[0].tool_result + assert result["toolUseId"] == "test_error" + assert result["status"] == "error" + + error_text = result["content"][0]["text"] + assert "Unexpected error validating SimpleModel" in error_text + assert "Unexpected error" in error_text + + @pytest.mark.asyncio + async def test_error_message_formatting_single_error(self): + """Test error message formatting with a single validation error.""" + tool = StructuredOutputTool(SimpleModel) + context = StructuredOutputContext(structured_output_model=SimpleModel) + + tool_use = { + "name": "SimpleModel", + "toolUseId": "test_format_1", + "input": { + "name": "Test", + "value": "not an integer", # Wrong type + }, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + result = events[0].tool_result + error_text = result["content"][0]["text"] + + # Check error formatting + assert "Validation failed for SimpleModel" in error_text + assert "Please fix the following errors:" in error_text + assert "- Field 'value':" in error_text + + @pytest.mark.asyncio + async def test_error_message_formatting_multiple_errors(self): + """Test error message formatting with multiple validation errors.""" + tool = StructuredOutputTool(ValidationTestModel) + context = StructuredOutputContext(structured_output_model=ValidationTestModel) + + tool_use = { + "name": "ValidationTestModel", + "toolUseId": "test_format_2", + "input": {"email": "bad-email", "age": -5, "status": "invalid"}, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + result = events[0].tool_result + error_text = result["content"][0]["text"] + + # Check that multiple errors are formatted 
properly + assert "Validation failed for ValidationTestModel" in error_text + assert "Please fix the following errors:" in error_text + # Should have multiple error lines + error_lines = [line for line in error_text.split("\n") if line.startswith("- Field")] + assert len(error_lines) >= 2 # At least 2 validation errors + + @pytest.mark.asyncio + async def test_stream_with_complex_nested_data(self): + """Test stream method with complex nested data.""" + tool = StructuredOutputTool(ComplexModel) + context = StructuredOutputContext(structured_output_model=ComplexModel) + + tool_use = { + "name": "ComplexModel", + "toolUseId": "test_complex", + "input": { + "title": "Test Title", + "count": 50, + "tags": ["tag1", "tag2", "tag3"], + "metadata": {"key1": "value1", "key2": 123}, + }, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Check success + result = events[0].tool_result + assert result["status"] == "success" + + # Check stored result + stored_result = context.get_result("test_complex") + assert stored_result.title == "Test Title" + assert stored_result.count == 50 + assert stored_result.tags == ["tag1", "tag2", "tag3"] + assert stored_result.metadata == {"key1": "value1", "key2": 123} + + def test_tool_spec_description_modification(self): + """Test that tool spec description is properly modified.""" + tool = StructuredOutputTool(SimpleModel) + spec = tool.tool_spec + + # Check that the IMPORTANT message is prepended + assert spec["description"].startswith("IMPORTANT: This StructuredOutputTool should only be invoked") + assert "last and final tool" in spec["description"] + assert "" in spec["description"] + assert "" in spec["description"] diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 5b4b5cdda..25f9bc39e 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -10,7 +10,9 @@ import strands from strands import Agent -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt +from strands.types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -138,6 +140,67 @@ def identity(a: int, agent: dict = None): assert tru_events == exp_events +@pytest.mark.asyncio +async def test_stream_interrupt(alist): + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + tool_use = {"toolUseId": "test_tool_id"} + + mock_agent = MagicMock() + mock_agent._interrupt_state = InterruptState() + + invocation_state = {"agent": mock_agent} + + @strands.tool(context=True) + def interrupt_tool(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_name", reason="test reason") + + stream = interrupt_tool.stream(tool_use, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, [interrupt])] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume(alist): + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + + tool_use = {"toolUseId": "test_tool_id"} + + mock_agent = MagicMock() + mock_agent._interrupt_state = 
InterruptState(interrupts={interrupt.id: interrupt}) + + invocation_state = {"agent": mock_agent} + + @strands.tool(context=True) + def interrupt_tool(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_name", reason="test reason") + + stream = interrupt_tool.stream(tool_use, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "test response"}], + }, + ), + ] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_basic_tool_creation(alist): """Test basic tool decorator functionality.""" @@ -1363,3 +1426,27 @@ async def async_generator() -> AsyncGenerator: ] assert act_results == exp_results + + +def test_function_tool_metadata_validate_signature_default_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'tool_context'"): + + @strands.tool(context=True) + def my_tool(context: ToolContext): + pass + + +def test_function_tool_metadata_validate_signature_custom_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'my_context'"): + + @strands.tool(context="my_context") + def my_tool(tool_context: ToolContext): + pass + + +def test_function_tool_metadata_validate_signature_missing_context_config(): + with pytest.raises(ValueError, match=r"@tool\(context\) must be set if passing in ToolContext param"): + + @strands.tool + def my_tool(tool_context: ToolContext): + pass diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 6b86d00ee..13aca90c3 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -1,11 +1,12 @@ import os import re +import tempfile import textwrap import pytest from strands.tools.decorator import DecoratedFunctionTool -from strands.tools.loader import ToolLoader +from strands.tools.loader import ToolLoader, load_tools_from_file_path from strands.tools.tools import PythonAgentTool @@ -310,3 +311,9 @@ def test_load_tool_path_returns_single_tool(tool_path): assert loaded_python_tool.tool_name == "alpha" assert loaded_tool.tool_name == "alpha" + + +def test_load_tools_from_file_path_module_spec_missing(): + with tempfile.NamedTemporaryFile() as f: + with pytest.raises(ImportError, match=f"Could not create spec for {os.path.basename(f.name)}"): + load_tools_from_file_path(f.name) diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index f0759ea07..c700016f6 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,13 +2,15 @@ Tests for the SDK tool registry module. 
""" -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest import strands +from strands.experimental.tools import ToolProvider from strands.tools import PythonAgentTool from strands.tools.decorator import DecoratedFunctionTool, tool +from strands.tools.mcp import MCPClient from strands.tools.registry import ToolRegistry @@ -26,7 +28,10 @@ def test_process_tools_with_invalid_path(): tool_registry = ToolRegistry() invalid_path = "not a filepath" - with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): + with pytest.raises( + ValueError, + match=f'Failed to load tool {invalid_path}: Tool string: "{invalid_path}" is not a valid tool string', + ): tool_registry.process_tools([invalid_path]) @@ -164,3 +169,223 @@ def test_register_tool_duplicate_name_with_hot_reload(): # Verify the second tool replaced the first assert tool_registry.registry["hot_reload_tool"] == tool_2 + + +def test_register_strands_tools_from_module(): + tool_registry = ToolRegistry() + tool_registry.process_tools(["tests.fixtures.say_tool"]) + + assert len(tool_registry.registry) == 2 + assert "say" in tool_registry.registry + assert "dont_say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_from_module(): + tool_registry = ToolRegistry() + tool_registry.process_tools(["tests.fixtures.say_tool:say"]) + + assert len(tool_registry.registry) == 1 + assert "say" in tool_registry.registry + assert "dont_say" not in tool_registry.registry + + +def test_register_strands_tools_specific_tool_from_module_tool_missing(): + tool_registry = ToolRegistry() + + with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:nay: "): + tool_registry.process_tools(["tests.fixtures.say_tool:nay"]) + + +def test_register_strands_tools_specific_tool_from_module_not_a_tool(): + tool_registry = ToolRegistry() + + with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:not_a_tool: "): + tool_registry.process_tools(["tests.fixtures.say_tool:not_a_tool"]) + + +def test_register_strands_tools_with_dict(): + tool_registry = ToolRegistry() + tool_registry.process_tools([{"path": "tests.fixtures.say_tool"}]) + + assert len(tool_registry.registry) == 2 + assert "say" in tool_registry.registry + assert "dont_say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_with_dict(): + tool_registry = ToolRegistry() + tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "say"}]) + + assert len(tool_registry.registry) == 1 + assert "say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_with_dict_not_found(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool {'path': 'tests.fixtures.say_tool'" + ", 'name': 'nay'}: Tool \"nay\" not found in \"tests.fixtures.say_tool\"", + ): + tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "nay"}]) + + +def test_register_strands_tools_module_no_spec(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.mocked_model_provider: " + "The module mocked_model_provider is not a valid module", + ): + tool_registry.process_tools(["tests.fixtures.mocked_model_provider"]) + + +def test_register_strands_tools_module_no_function(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool 
tests.fixtures.tool_with_spec_but_no_function: " + "Module-based tool tool_with_spec_but_no_function missing function tool_with_spec_but_no_function", + ): + tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_no_function"]) + + +def test_register_strands_tools_module_non_callable_function(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.tool_with_spec_but_non_callable_function:" + " Tool tool_with_spec_but_non_callable_function function is not callable", + ): + tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_non_callable_function"]) + + +def test_tool_registry_cleanup_with_mcp_client(): + """Test that ToolRegistry cleanup properly handles MCP clients without orphaning threads.""" + # Create a mock MCP client that simulates a real tool provider + mock_transport = MagicMock() + mock_client = MCPClient(mock_transport) + + # Mock the client to avoid actual network operations + mock_client.load_tools = AsyncMock(return_value=[]) + + registry = ToolRegistry() + + # Use process_tools to properly register the client + registry.process_tools([mock_client]) + + # Verify the client was registered as a consumer + assert registry._registry_id in mock_client._consumers + + # Test cleanup calls remove_consumer + registry.cleanup() + + # Verify cleanup was attempted + assert registry._registry_id not in mock_client._consumers + + +def test_tool_registry_cleanup_exception_handling(): + """Test that ToolRegistry cleanup attempts all providers even if some fail.""" + # Create mock providers - one that fails, one that succeeds + failing_provider = MagicMock() + failing_provider.remove_consumer.side_effect = Exception("Cleanup failed") + + working_provider = MagicMock() + + registry = ToolRegistry() + registry._tool_providers = [failing_provider, working_provider] + + # Cleanup should attempt both providers and raise the first exception + with pytest.raises(Exception, match="Cleanup failed"): + registry.cleanup() + + # Verify both providers were attempted + failing_provider.remove_consumer.assert_called_once() + working_provider.remove_consumer.assert_called_once() + + +def test_tool_registry_cleanup_idempotent(): + """Test that ToolRegistry cleanup is idempotent.""" + provider = MagicMock(spec=ToolProvider) + provider.load_tools = AsyncMock(return_value=[]) + + registry = ToolRegistry() + + # Use process_tools to properly register the provider + registry.process_tools([provider]) + + # First cleanup should call remove_consumer + registry.cleanup() + provider.remove_consumer.assert_called_once_with(registry._registry_id) + + # Reset mock call count + provider.remove_consumer.reset_mock() + + # Second cleanup should call remove_consumer again (not idempotent yet) + # This test documents current behavior - registry cleanup is not idempotent + registry.cleanup() + provider.remove_consumer.assert_called_once_with(registry._registry_id) + + +def test_tool_registry_process_tools_exception_after_add_consumer(): + """Test that tool provider is still tracked for cleanup even if load_tools fails.""" + # Create a mock tool provider that fails during load_tools + mock_provider = MagicMock(spec=ToolProvider) + mock_provider.add_consumer = MagicMock() + mock_provider.remove_consumer = MagicMock() + + async def failing_load_tools(): + raise Exception("Failed to load tools") + + mock_provider.load_tools = AsyncMock(side_effect=failing_load_tools) + + registry = ToolRegistry() + + # Processing should fail but provider should still 
be tracked + with pytest.raises(ValueError, match="Failed to load tool"): + registry.process_tools([mock_provider]) + + # Verify provider was added to registry for cleanup tracking + assert mock_provider in registry._tool_providers + + # Verify add_consumer was called before the failure + mock_provider.add_consumer.assert_called_once_with(registry._registry_id) + + # Cleanup should still work + registry.cleanup() + mock_provider.remove_consumer.assert_called_once_with(registry._registry_id) + + +def test_tool_registry_add_consumer_before_load_tools(): + """Test that add_consumer is called before load_tools to ensure cleanup tracking.""" + # Create a mock tool provider that tracks call order + mock_provider = MagicMock(spec=ToolProvider) + call_order = [] + + def track_add_consumer(*args, **kwargs): + call_order.append("add_consumer") + + async def track_load_tools(*args, **kwargs): + call_order.append("load_tools") + return [] + + mock_provider.add_consumer.side_effect = track_add_consumer + mock_provider.load_tools = AsyncMock(side_effect=track_load_tools) + + registry = ToolRegistry() + + # Process the tool provider + registry.process_tools([mock_provider]) + + # Verify add_consumer was called before load_tools + assert call_order == ["add_consumer", "load_tools"] + + # Verify the provider was added to the registry for cleanup + assert mock_provider in registry._tool_providers + + # Verify add_consumer was called with the registry ID + mock_provider.add_consumer.assert_called_once_with(registry._registry_id) diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py new file mode 100644 index 000000000..fdf4abb0a --- /dev/null +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -0,0 +1,328 @@ +"""Unit tests for ToolRegistry ToolProvider functionality.""" + +from unittest.mock import patch + +import pytest + +from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools.registry import ToolRegistry +from tests.fixtures.mock_agent_tool import MockAgentTool + + +class MockToolProvider(ToolProvider): + """Mock ToolProvider for testing.""" + + def __init__(self, tools=None, cleanup_error=None): + self._tools = tools or [] + self._cleanup_error = cleanup_error + self.cleanup_called = False + self.remove_consumer_called = False + self.remove_consumer_id = None + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + return self._tools + + def cleanup(self): + self.cleanup_called = True + if self._cleanup_error: + raise self._cleanup_error + + def add_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + def remove_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + if self._cleanup_error: + raise self._cleanup_error + + +@pytest.fixture +def mock_run_async(): + """Fixture for mocking strands.tools.registry.run_async.""" + with patch("strands.tools.registry.run_async") as mock: + yield mock + + +@pytest.fixture +def mock_agent_tool(): + """Fixture factory for creating MockAgentTool instances.""" + return MockAgentTool + + +class TestToolRegistryToolProvider: + """Test ToolRegistry integration with ToolProvider.""" + + def test_process_tools_with_tool_provider(self, mock_run_async, mock_agent_tool): + """Test that process_tools handles ToolProvider correctly.""" + # Create mock tools + mock_tool1 = mock_agent_tool("provider_tool_1") + mock_tool2 = 
mock_agent_tool("provider_tool_2") + + # Create mock provider + provider = MockToolProvider([mock_tool1, mock_tool2]) + + registry = ToolRegistry() + + # Mock run_async to return the tools directly + mock_run_async.return_value = [mock_tool1, mock_tool2] + + tool_names = registry.process_tools([provider]) + + # Verify run_async was called with the provider's load_tools method + mock_run_async.assert_called_once() + + # Verify tools were registered + assert "provider_tool_1" in tool_names + assert "provider_tool_2" in tool_names + assert len(tool_names) == 2 + + # Verify provider was tracked + assert provider in registry._tool_providers + + # Verify tools are in registry + assert registry.registry["provider_tool_1"] is mock_tool1 + assert registry.registry["provider_tool_2"] is mock_tool2 + + def test_process_tools_with_multiple_providers(self, mock_run_async, mock_agent_tool): + """Test that process_tools handles multiple ToolProviders.""" + # Create mock tools for first provider + mock_tool1 = mock_agent_tool("provider1_tool") + provider1 = MockToolProvider([mock_tool1]) + + # Create mock tools for second provider + mock_tool2 = mock_agent_tool("provider2_tool") + provider2 = MockToolProvider([mock_tool2]) + + registry = ToolRegistry() + + # Mock run_async to return appropriate tools for each call + mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] + + tool_names = registry.process_tools([provider1, provider2]) + + # Verify run_async was called twice + assert mock_run_async.call_count == 2 + + # Verify all tools were registered + assert "provider1_tool" in tool_names + assert "provider2_tool" in tool_names + assert len(tool_names) == 2 + + # Verify both providers were tracked + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + assert len(registry._tool_providers) == 2 + + def test_process_tools_with_mixed_tools_and_providers(self, mock_run_async, mock_agent_tool): + """Test that process_tools handles mix of regular tools and providers.""" + # Create regular tool + regular_tool = mock_agent_tool("regular_tool") + + # Create provider tool + provider_tool = mock_agent_tool("provider_tool") + provider = MockToolProvider([provider_tool]) + + registry = ToolRegistry() + + mock_run_async.return_value = [provider_tool] + + tool_names = registry.process_tools([regular_tool, provider]) + + # Verify both tools were registered + assert "regular_tool" in tool_names + assert "provider_tool" in tool_names + assert len(tool_names) == 2 + + # Verify only provider was tracked + assert provider in registry._tool_providers + assert len(registry._tool_providers) == 1 + + def test_process_tools_with_empty_provider(self, mock_run_async): + """Test that process_tools handles provider with no tools.""" + provider = MockToolProvider([]) # Empty tools list + + registry = ToolRegistry() + + mock_run_async.return_value = [] + + tool_names = registry.process_tools([provider]) + + # Verify no tools were registered + assert not tool_names + + # Verify provider was still tracked + assert provider in registry._tool_providers + + def test_tool_providers_public_access(self): + """Test that tool_providers can be accessed directly.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Verify direct access works + assert len(registry._tool_providers) == 2 + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + + def 
test_tool_providers_empty_by_default(self): + """Test that tool_providers is empty by default.""" + registry = ToolRegistry() + + assert not registry._tool_providers + assert isinstance(registry._tool_providers, list) + + def test_process_tools_provider_load_exception(self, mock_run_async): + """Test that process_tools handles exceptions from provider.load_tools().""" + provider = MockToolProvider() + + registry = ToolRegistry() + + # Make load_tools raise an exception + mock_run_async.side_effect = Exception("Load tools failed") + + # Should raise the exception from load_tools + with pytest.raises(Exception, match="Load tools failed"): + registry.process_tools([provider]) + + # Provider should still be tracked even if load_tools failed + assert provider in registry._tool_providers + + def test_tool_provider_tracking_persistence(self, mock_run_async, mock_agent_tool): + """Test that tool providers are tracked across multiple process_tools calls.""" + provider1 = MockToolProvider([mock_agent_tool("tool1")]) + provider2 = MockToolProvider([mock_agent_tool("tool2")]) + + registry = ToolRegistry() + + mock_run_async.side_effect = [ + [mock_agent_tool("tool1")], + [mock_agent_tool("tool2")], + ] + + # Process first provider + registry.process_tools([provider1]) + assert len(registry._tool_providers) == 1 + assert provider1 in registry._tool_providers + + # Process second provider + registry.process_tools([provider2]) + assert len(registry._tool_providers) == 2 + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + + def test_process_tools_provider_async_optimization(self, mock_agent_tool): + """Test that load_tools and add_consumer are called in same async context.""" + mock_tool = mock_agent_tool("test_tool") + + class TestProvider(ToolProvider): + def __init__(self): + self.load_tools_called = False + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + self.load_tools_called = True + return [mock_tool] + + def add_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + def remove_consumer(self, consumer_id): + pass + + provider = TestProvider() + registry = ToolRegistry() + + # Process the provider - this should call both methods + tool_names = registry.process_tools([provider]) + + # Verify both methods were called + assert provider.load_tools_called + assert provider.add_consumer_called + assert provider.add_consumer_id == registry._registry_id + + # Verify tool was registered + assert "test_tool" in tool_names + assert provider in registry._tool_providers + + def test_registry_cleanup(self): + """Test that registry cleanup calls remove_consumer on all providers.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + registry.cleanup() + + # Verify both providers had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + def test_registry_cleanup_with_provider_consumer_removal(self): + """Test that cleanup removes provider consumers correctly.""" + + class TestProvider(ToolProvider): + def __init__(self): + self.remove_consumer_called = False + self.remove_consumer_id = None + + async def load_tools(self): + return [] + + def add_consumer(self, consumer_id): + pass + + def remove_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + + provider = 
TestProvider() + registry = ToolRegistry() + registry._tool_providers = [provider] + + # Call cleanup + registry.cleanup() + + # Verify remove_consumer was called with correct ID + assert provider.remove_consumer_called + assert provider.remove_consumer_id == registry._registry_id + + def test_registry_cleanup_raises_exception_on_provider_error(self): + """Test that cleanup raises exception when provider removal fails.""" + provider1 = MockToolProvider(cleanup_error=RuntimeError("Provider cleanup failed")) + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Cleanup should raise the exception from first provider but still attempt cleanup of all + with pytest.raises(RuntimeError, match="Provider cleanup failed"): + registry.cleanup() + + # Both providers should have had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + def test_registry_cleanup_raises_first_exception_on_multiple_provider_errors(self): + """Test that cleanup raises first exception when multiple providers fail but attempts all.""" + provider1 = MockToolProvider(cleanup_error=RuntimeError("Provider 1 failed")) + provider2 = MockToolProvider(cleanup_error=ValueError("Provider 2 failed")) + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Cleanup should raise first exception but still attempt cleanup of all + with pytest.raises(RuntimeError, match="Provider 1 failed"): + registry.cleanup() + + # Both providers should have had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py new file mode 100644 index 000000000..d64cabb83 --- /dev/null +++ b/tests/strands/types/test__events.py @@ -0,0 +1,467 @@ +"""Tests for event types in the strands.types._events module.""" + +from unittest.mock import MagicMock, Mock + +from pydantic import BaseModel + +from strands.telemetry import EventLoopMetrics +from strands.types._events import ( + AgentResultEvent, + CitationStreamEvent, + EventLoopStopEvent, + EventLoopThrottleEvent, + ForceStopEvent, + InitEventLoopEvent, + ModelMessageEvent, + ModelStopReason, + ModelStreamChunkEvent, + ModelStreamEvent, + ReasoningRedactedContentStreamEvent, + ReasoningSignatureStreamEvent, + ReasoningTextStreamEvent, + StartEvent, + StartEventLoopEvent, + StructuredOutputEvent, + TextStreamEvent, + ToolResultEvent, + ToolResultMessageEvent, + ToolStreamEvent, + ToolUseStreamEvent, + TypedEvent, +) +from strands.types.citations import Citation +from strands.types.content import Message +from strands.types.event_loop import Metrics, StopReason, Usage +from strands.types.streaming import ContentBlockDelta, StreamEvent +from strands.types.tools import ToolResult, ToolUse + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str + value: int + + +class TestTypedEvent: + """Tests for the base TypedEvent class.""" + + def test_initialization_with_data(self): + """Test TypedEvent initialization with data.""" + data = {"key": "value", "number": 42} + event = TypedEvent(data) + assert event["key"] == "value" + assert event["number"] == 42 + + def test_initialization_without_data(self): + """Test TypedEvent initialization without data.""" + event = TypedEvent() + assert len(event) == 0 + + def test_is_callback_event_default(self): + """Test that is_callback_event returns True by 
default.""" + event = TypedEvent() + assert event.is_callback_event is True + + def test_as_dict(self): + """Test as_dict method returns dictionary representation.""" + data = {"test": "data", "nested": {"key": "value"}} + event = TypedEvent(data) + result = event.as_dict() + assert result == data + assert isinstance(result, dict) + + def test_prepare_default_implementation(self): + """Test prepare method default implementation does nothing.""" + event = TypedEvent({"initial": "data"}) + invocation_state = {"state": "value"} + event.prepare(invocation_state) + # Default implementation does nothing + assert event == {"initial": "data"} + + +class TestInitEventLoopEvent: + """Tests for InitEventLoopEvent.""" + + def test_initialization(self): + """Test InitEventLoopEvent initialization.""" + event = InitEventLoopEvent() + assert event["init_event_loop"] is True + + def test_prepare_updates_with_invocation_state(self): + """Test prepare method updates event with invocation state.""" + event = InitEventLoopEvent() + invocation_state = {"request_id": "123", "session": "abc"} + event.prepare(invocation_state) + assert event["request_id"] == "123" + assert event["session"] == "abc" + assert event["init_event_loop"] is True + + +class TestStartEvent: + """Tests for StartEvent (deprecated).""" + + def test_initialization(self): + """Test StartEvent initialization.""" + event = StartEvent() + assert event["start"] is True + + +class TestStartEventLoopEvent: + """Tests for StartEventLoopEvent.""" + + def test_initialization(self): + """Test StartEventLoopEvent initialization.""" + event = StartEventLoopEvent() + assert event["start_event_loop"] is True + + +class TestModelStreamChunkEvent: + """Tests for ModelStreamChunkEvent.""" + + def test_initialization_with_stream_event(self): + """Test ModelStreamChunkEvent initialization with StreamEvent.""" + stream_event = Mock(spec=StreamEvent) + event = ModelStreamChunkEvent(stream_event) + assert event["event"] == stream_event + assert event.chunk == stream_event + + +class TestModelStreamEvent: + """Tests for ModelStreamEvent.""" + + def test_initialization_with_delta_data(self): + """Test ModelStreamEvent initialization with delta data.""" + delta_data = {"type": "text", "content": "hello"} + event = ModelStreamEvent(delta_data) + assert event["type"] == "text" + assert event["content"] == "hello" + + def test_is_callback_event_empty(self): + """Test is_callback_event returns False when empty.""" + event = ModelStreamEvent({}) + assert event.is_callback_event is False + + def test_is_callback_event_non_empty(self): + """Test is_callback_event returns True when non-empty.""" + event = ModelStreamEvent({"data": "value"}) + assert event.is_callback_event is True + + def test_prepare_with_delta(self): + """Test prepare method updates when delta is present.""" + event = ModelStreamEvent({"delta": "content", "other": "data"}) + invocation_state = {"request_id": "456"} + event.prepare(invocation_state) + assert event["request_id"] == "456" + assert event["delta"] == "content" + + def test_prepare_without_delta(self): + """Test prepare method does nothing when delta is not present.""" + event = ModelStreamEvent({"other": "data"}) + invocation_state = {"request_id": "456"} + event.prepare(invocation_state) + assert "request_id" not in event + + +class TestToolUseStreamEvent: + """Tests for ToolUseStreamEvent.""" + + def test_initialization(self): + """Test ToolUseStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + current_tool_use = 
{"toolUseId": "123", "name": "calculator"} + event = ToolUseStreamEvent(delta, current_tool_use) + assert event["delta"] == delta + assert event["current_tool_use"] == current_tool_use + + +class TestTextStreamEvent: + """Tests for TextStreamEvent.""" + + def test_initialization(self): + """Test TextStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + text = "Hello, world!" + event = TextStreamEvent(delta, text) + assert event["data"] == text + assert event["delta"] == delta + + +class TestCitationStreamEvent: + """Tests for CitationStreamEvent.""" + + def test_initialization(self): + """Test CitationStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + citation = Mock(spec=Citation) + event = CitationStreamEvent(delta, citation) + assert event["callback"]["citation"] == citation + assert event["callback"]["delta"] == delta + + +class TestReasoningTextStreamEvent: + """Tests for ReasoningTextStreamEvent.""" + + def test_initialization_with_reasoning_text(self): + """Test ReasoningTextStreamEvent initialization with text.""" + delta = Mock(spec=ContentBlockDelta) + reasoning_text = "Thinking about the problem..." + event = ReasoningTextStreamEvent(delta, reasoning_text) + assert event["reasoningText"] == reasoning_text + assert event["delta"] == delta + assert event["reasoning"] is True + + def test_initialization_with_none(self): + """Test ReasoningTextStreamEvent initialization with None.""" + delta = Mock(spec=ContentBlockDelta) + event = ReasoningTextStreamEvent(delta, None) + assert event["reasoningText"] is None + assert event["reasoning"] is True + + +class TestReasoningRedactedContentStreamEvent: + """Tests for ReasoningRedactedContentStreamEvent.""" + + def test_initialization_with_redacted_content(self): + """Test ReasoningRedactedContentStreamEvent initialization with content.""" + delta = Mock(spec=ContentBlockDelta) + redacted_content = b"[REDACTED]" + event = ReasoningRedactedContentStreamEvent(delta, redacted_content) + assert event["reasoningRedactedContent"] == redacted_content + assert event["delta"] == delta + assert event["reasoning"] is True + + def test_initialization_with_none(self): + """Test ReasoningRedactedContentStreamEvent initialization with None.""" + delta = Mock(spec=ContentBlockDelta) + event = ReasoningRedactedContentStreamEvent(delta, None) + assert event["reasoningRedactedContent"] is None + assert event["reasoning"] is True + + +class TestReasoningSignatureStreamEvent: + """Tests for ReasoningSignatureStreamEvent.""" + + def test_initialization(self): + """Test ReasoningSignatureStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + signature = "signature_xyz123" + event = ReasoningSignatureStreamEvent(delta, signature) + assert event["reasoning_signature"] == signature + assert event["delta"] == delta + assert event["reasoning"] is True + + +class TestModelStopReason: + """Tests for ModelStopReason.""" + + def test_initialization(self): + """Test ModelStopReason initialization.""" + stop_reason = Mock(spec=StopReason) + message = Mock(spec=Message) + usage = Mock(spec=Usage) + metrics = Mock(spec=Metrics) + + event = ModelStopReason(stop_reason, message, usage, metrics) + assert event["stop"] == (stop_reason, message, usage, metrics) + assert event.is_callback_event is False + + +class TestEventLoopStopEvent: + """Tests for EventLoopStopEvent.""" + + def test_initialization_without_structured_output(self): + """Test EventLoopStopEvent initialization without structured output.""" + stop_reason = 
Mock(spec=StopReason) + message = Mock(spec=Message) + metrics = Mock(spec=EventLoopMetrics) + request_state = {"state": "final"} + + event = EventLoopStopEvent(stop_reason, message, metrics, request_state) + assert event["stop"] == (stop_reason, message, metrics, request_state, None, None) + assert event.is_callback_event is False + + def test_initialization_with_structured_output(self): + """Test EventLoopStopEvent initialization with structured output.""" + stop_reason = Mock(spec=StopReason) + message = Mock(spec=Message) + metrics = Mock(spec=EventLoopMetrics) + request_state = {"state": "final"} + structured_output = SampleModel(name="test", value=42) + + event = EventLoopStopEvent(stop_reason, message, metrics, request_state, structured_output) + assert event["stop"] == (stop_reason, message, metrics, request_state, structured_output, None) + assert event.is_callback_event is False + + +class TestStructuredOutputEvent: + """Tests for StructuredOutputEvent.""" + + def test_initialization(self): + """Test StructuredOutputEvent initialization.""" + structured_output = SampleModel(name="output", value=100) + event = StructuredOutputEvent(structured_output) + assert event["structured_output"] == structured_output + assert isinstance(event["structured_output"], SampleModel) + + +class TestEventLoopThrottleEvent: + """Tests for EventLoopThrottleEvent.""" + + def test_initialization(self): + """Test EventLoopThrottleEvent initialization.""" + delay = 5 + event = EventLoopThrottleEvent(delay) + assert event["event_loop_throttled_delay"] == 5 + + def test_prepare_updates_with_invocation_state(self): + """Test prepare method updates event with invocation state.""" + event = EventLoopThrottleEvent(10) + invocation_state = {"request_id": "throttle_123"} + event.prepare(invocation_state) + assert event["request_id"] == "throttle_123" + assert event["event_loop_throttled_delay"] == 10 + + +class TestToolResultEvent: + """Tests for ToolResultEvent.""" + + def test_initialization(self): + """Test ToolResultEvent initialization.""" + tool_result: ToolResult = { + "toolUseId": "tool_123", + "content": [{"text": "Result"}], + "isError": False, + } + event = ToolResultEvent(tool_result) + assert event["tool_result"] == tool_result + assert event.tool_use_id == "tool_123" + assert event.tool_result == tool_result + assert event.is_callback_event is False + + def test_tool_use_id_property(self): + """Test tool_use_id property returns correct ID.""" + tool_result: ToolResult = { + "toolUseId": "unique_id_456", + "content": [], + } + event = ToolResultEvent(tool_result) + assert event.tool_use_id == "unique_id_456" + + +class TestToolStreamEvent: + """Tests for ToolStreamEvent.""" + + def test_initialization(self): + """Test ToolStreamEvent initialization.""" + tool_use: ToolUse = { + "toolUseId": "stream_123", + "name": "streaming_tool", + "input": {}, + } + tool_stream_data = {"progress": 50, "status": "processing"} + event = ToolStreamEvent(tool_use, tool_stream_data) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == tool_stream_data + assert event.tool_use_id == "stream_123" + + def test_tool_use_id_property(self): + """Test tool_use_id property returns correct ID.""" + tool_use: ToolUse = { + "toolUseId": "another_stream_456", + "name": "tool", + "input": {}, + } + event = ToolStreamEvent(tool_use, {}) + assert event.tool_use_id == "another_stream_456" + + +class TestModelMessageEvent: + """Tests for ModelMessageEvent.""" + + def 
test_initialization(self): + """Test ModelMessageEvent initialization.""" + message = Mock(spec=Message) + event = ModelMessageEvent(message) + assert event["message"] == message + + +class TestToolResultMessageEvent: + """Tests for ToolResultMessageEvent.""" + + def test_initialization(self): + """Test ToolResultMessageEvent initialization.""" + message = {"role": "tool", "content": "Tool result message"} + event = ToolResultMessageEvent(message) + assert event["message"] == message + + +class TestForceStopEvent: + """Tests for ForceStopEvent.""" + + def test_initialization_with_string_reason(self): + """Test ForceStopEvent initialization with string reason.""" + reason = "User requested stop" + event = ForceStopEvent(reason) + assert event["force_stop"] is True + assert event["force_stop_reason"] == "User requested stop" + + def test_initialization_with_exception(self): + """Test ForceStopEvent initialization with exception.""" + exception = ValueError("Something went wrong") + event = ForceStopEvent(exception) + assert event["force_stop"] is True + assert event["force_stop_reason"] == "Something went wrong" + + +class TestAgentResultEvent: + """Tests for AgentResultEvent.""" + + def test_initialization(self): + """Test AgentResultEvent initialization.""" + # Mock the AgentResult + agent_result = MagicMock() + agent_result.messages = [] + agent_result.stop_reason = "max_tokens" + + event = AgentResultEvent(agent_result) + assert event["result"] == agent_result + + +class TestEventSerialization: + """Tests for event serialization and conversion.""" + + def test_typed_event_serialization(self): + """Test that TypedEvent can be serialized to dict.""" + event = TypedEvent({"key": "value", "nested": {"data": 123}}) + serialized = event.as_dict() + assert serialized == {"key": "value", "nested": {"data": 123}} + + def test_complex_event_serialization(self): + """Test complex event serialization.""" + delta = Mock(spec=ContentBlockDelta) + delta.to_dict = Mock(return_value={"type": "delta"}) + + event = TextStreamEvent(delta, "Hello") + # The event should be serializable as a dict + assert isinstance(event.as_dict(), dict) + assert event["data"] == "Hello" + + def test_event_inheritance(self): + """Test that all events inherit from TypedEvent.""" + events = [ + InitEventLoopEvent(), + StartEvent(), + StartEventLoopEvent(), + StructuredOutputEvent(SampleModel(name="test", value=1)), + EventLoopThrottleEvent(5), + ForceStopEvent("test"), + ] + + for event in events: + assert isinstance(event, TypedEvent) + assert isinstance(event, dict) + assert hasattr(event, "is_callback_event") + assert hasattr(event, "as_dict") + assert hasattr(event, "prepare") diff --git a/tests/strands/types/test_exceptions.py b/tests/strands/types/test_exceptions.py new file mode 100644 index 000000000..29f68a7d0 --- /dev/null +++ b/tests/strands/types/test_exceptions.py @@ -0,0 +1,387 @@ +"""Tests for exception types in the strands.types.exceptions module.""" + +import pytest + +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + MCPClientInitializationError, + ModelThrottledException, + SessionException, + StructuredOutputException, +) + + +class TestEventLoopException: + """Tests for EventLoopException class.""" + + def test_initialization_with_request_state(self): + """Test EventLoopException initialization with request state.""" + original_exception = ValueError("Original error") + request_state = {"session_id": "123", "user": "test_user"} + + 
exception = EventLoopException(original_exception, request_state) + + assert exception.original_exception == original_exception + assert exception.request_state == request_state + assert str(exception) == "Original error" + + def test_initialization_without_request_state(self): + """Test EventLoopException initialization without request state.""" + original_exception = RuntimeError("Runtime error") + + exception = EventLoopException(original_exception) + + assert exception.original_exception == original_exception + assert exception.request_state == {} + assert str(exception) == "Runtime error" + + def test_initialization_with_none_request_state(self): + """Test EventLoopException initialization with None request state.""" + original_exception = TypeError("Type error") + + exception = EventLoopException(original_exception, None) + + assert exception.original_exception == original_exception + assert exception.request_state == {} + assert str(exception) == "Type error" + + def test_inheritance(self): + """Test that EventLoopException inherits from Exception.""" + original_exception = Exception("Test") + exception = EventLoopException(original_exception) + + assert isinstance(exception, Exception) + assert issubclass(EventLoopException, Exception) + + def test_exception_message_from_original(self): + """Test that exception message comes from original exception.""" + original_exception = ValueError("Custom error message") + exception = EventLoopException(original_exception) + + assert str(exception) == "Custom error message" + assert exception.args[0] == "Custom error message" + + +class TestMaxTokensReachedException: + """Tests for MaxTokensReachedException class.""" + + def test_initialization_with_message(self): + """Test MaxTokensReachedException initialization with message.""" + message = "Maximum tokens limit of 4096 reached" + exception = MaxTokensReachedException(message) + + assert str(exception) == message + assert exception.args[0] == message + + def test_inheritance(self): + """Test that MaxTokensReachedException inherits from Exception.""" + exception = MaxTokensReachedException("Test message") + + assert isinstance(exception, Exception) + assert issubclass(MaxTokensReachedException, Exception) + + def test_exception_with_detailed_message(self): + """Test exception with detailed message about token limits.""" + message = ( + "Model reached maximum token limit of 8192 tokens. " + "Consider reducing input size or increasing max_tokens parameter." 
+ ) + exception = MaxTokensReachedException(message) + + assert str(exception) == message + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(MaxTokensReachedException) as exc_info: + raise MaxTokensReachedException("Token limit exceeded") + + assert str(exc_info.value) == "Token limit exceeded" + + +class TestContextWindowOverflowException: + """Tests for ContextWindowOverflowException class.""" + + def test_initialization(self): + """Test ContextWindowOverflowException initialization.""" + exception = ContextWindowOverflowException() + + assert isinstance(exception, Exception) + assert str(exception) == "" + + def test_initialization_with_message(self): + """Test ContextWindowOverflowException with custom message.""" + exception = ContextWindowOverflowException("Context window exceeded 100k tokens") + + assert str(exception) == "Context window exceeded 100k tokens" + + def test_inheritance(self): + """Test that ContextWindowOverflowException inherits from Exception.""" + exception = ContextWindowOverflowException() + + assert isinstance(exception, Exception) + assert issubclass(ContextWindowOverflowException, Exception) + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(ContextWindowOverflowException) as exc_info: + raise ContextWindowOverflowException("Input too large for model") + + assert str(exc_info.value) == "Input too large for model" + + +class TestMCPClientInitializationError: + """Tests for MCPClientInitializationError class.""" + + def test_initialization(self): + """Test MCPClientInitializationError initialization.""" + exception = MCPClientInitializationError() + + assert isinstance(exception, Exception) + assert str(exception) == "" + + def test_initialization_with_message(self): + """Test MCPClientInitializationError with custom message.""" + exception = MCPClientInitializationError("Failed to connect to MCP server") + + assert str(exception) == "Failed to connect to MCP server" + + def test_inheritance(self): + """Test that MCPClientInitializationError inherits from Exception.""" + exception = MCPClientInitializationError() + + assert isinstance(exception, Exception) + assert issubclass(MCPClientInitializationError, Exception) + + def test_exception_with_detailed_error(self): + """Test exception with detailed initialization error.""" + message = "MCP server initialization failed: Connection refused on port 8080" + exception = MCPClientInitializationError(message) + + assert str(exception) == message + + +class TestModelThrottledException: + """Tests for ModelThrottledException class.""" + + def test_initialization_with_message(self): + """Test ModelThrottledException initialization with message.""" + message = "Rate limit exceeded. Please retry after 60 seconds." 
+ exception = ModelThrottledException(message) + + assert exception.message == message + assert str(exception) == message + assert exception.args[0] == message + + def test_inheritance(self): + """Test that ModelThrottledException inherits from Exception.""" + exception = ModelThrottledException("Throttled") + + assert isinstance(exception, Exception) + assert issubclass(ModelThrottledException, Exception) + + def test_message_property(self): + """Test that message property is accessible.""" + message = "API rate limit: 10 requests per minute" + exception = ModelThrottledException(message) + + assert exception.message == message + assert hasattr(exception, "message") + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(ModelThrottledException) as exc_info: + raise ModelThrottledException("Service temporarily unavailable") + + assert exc_info.value.message == "Service temporarily unavailable" + assert str(exc_info.value) == "Service temporarily unavailable" + + +class TestSessionException: + """Tests for SessionException class.""" + + def test_initialization(self): + """Test SessionException initialization.""" + exception = SessionException() + + assert isinstance(exception, Exception) + assert str(exception) == "" + + def test_initialization_with_message(self): + """Test SessionException with custom message.""" + exception = SessionException("Session expired") + + assert str(exception) == "Session expired" + + def test_inheritance(self): + """Test that SessionException inherits from Exception.""" + exception = SessionException() + + assert isinstance(exception, Exception) + assert issubclass(SessionException, Exception) + + def test_exception_with_detailed_message(self): + """Test exception with detailed session error.""" + message = "Failed to restore session: Invalid session ID or session has expired" + exception = SessionException(message) + + assert str(exception) == message + + +class TestStructuredOutputException: + """Tests for StructuredOutputException class.""" + + def test_initialization_with_message(self): + """Test StructuredOutputException initialization with message.""" + message = "Failed to validate structured output after 3 attempts" + exception = StructuredOutputException(message) + + assert exception.message == message + assert str(exception) == message + assert exception.args[0] == message + + def test_inheritance(self): + """Test that StructuredOutputException inherits from Exception.""" + exception = StructuredOutputException("Validation failed") + + assert isinstance(exception, Exception) + assert issubclass(StructuredOutputException, Exception) + + def test_message_property(self): + """Test that message property is accessible.""" + message = "Pydantic validation error: field 'name' is required" + exception = StructuredOutputException(message) + + assert exception.message == message + assert hasattr(exception, "message") + + def test_exception_with_validation_details(self): + """Test exception with detailed validation error message.""" + message = ( + "Structured output validation failed:\n" + "- Field 'age' must be a positive integer\n" + "- Field 'email' must be a valid email address" + ) + exception = StructuredOutputException(message) + + assert exception.message == message + assert str(exception) == message + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(StructuredOutputException) as exc_info: + raise 
StructuredOutputException("Invalid output format") + + assert exc_info.value.message == "Invalid output format" + assert str(exc_info.value) == "Invalid output format" + + +class TestExceptionInheritance: + """Tests for verifying exception inheritance hierarchy.""" + + def test_all_exceptions_inherit_from_exception(self): + """Test that all custom exceptions inherit from Exception.""" + exception_classes = [ + EventLoopException, + MaxTokensReachedException, + ContextWindowOverflowException, + MCPClientInitializationError, + ModelThrottledException, + SessionException, + StructuredOutputException, + ] + + for exc_class in exception_classes: + assert issubclass(exc_class, Exception), f"{exc_class.__name__} should inherit from Exception" + + def test_exception_instances_are_exceptions(self): + """Test that all exception instances are instances of Exception.""" + exceptions = [ + EventLoopException(ValueError("test")), + MaxTokensReachedException("test"), + ContextWindowOverflowException("test"), + MCPClientInitializationError("test"), + ModelThrottledException("test"), + SessionException("test"), + StructuredOutputException("test"), + ] + + for exception in exceptions: + assert isinstance(exception, Exception), f"{type(exception).__name__} instance should be an Exception" + + def test_exceptions_can_be_caught_as_exception(self): + """Test that all custom exceptions can be caught as generic Exception.""" + exceptions_to_raise = [ + (EventLoopException, ValueError("test"), None), + (MaxTokensReachedException, "test", None), + (ContextWindowOverflowException, "test", None), + (MCPClientInitializationError, "test", None), + (ModelThrottledException, "test", None), + (SessionException, "test", None), + (StructuredOutputException, "test", None), + ] + + for exc_class, *args in exceptions_to_raise: + try: + if exc_class == EventLoopException: + raise exc_class(*args) + else: + raise exc_class(args[0]) + except Exception as e: + assert isinstance(e, exc_class) + assert isinstance(e, Exception) + + +class TestExceptionMessages: + """Tests for exception messages and representations.""" + + def test_exception_str_representations(self): + """Test string representations of all exceptions.""" + exceptions = [ + (EventLoopException(ValueError("event loop error")), "event loop error"), + (MaxTokensReachedException("max tokens"), "max tokens"), + (ContextWindowOverflowException("overflow"), "overflow"), + (MCPClientInitializationError("init error"), "init error"), + (ModelThrottledException("throttled"), "throttled"), + (SessionException("session error"), "session error"), + (StructuredOutputException("output error"), "output error"), + ] + + for exception, expected_str in exceptions: + assert str(exception) == expected_str + + def test_exception_repr_contains_class_name(self): + """Test that repr contains the exception class name.""" + exceptions = [ + EventLoopException(ValueError("test")), + MaxTokensReachedException("test"), + ContextWindowOverflowException("test"), + MCPClientInitializationError("test"), + ModelThrottledException("test"), + SessionException("test"), + StructuredOutputException("test"), + ] + + for exception in exceptions: + class_name = type(exception).__name__ + assert class_name in repr(exception) + + def test_exceptions_with_custom_properties(self): + """Test exceptions with custom properties maintain those properties.""" + # EventLoopException with properties + event_loop_exc = EventLoopException(ValueError("test"), {"key": "value"}) + assert hasattr(event_loop_exc, 
"original_exception") + assert hasattr(event_loop_exc, "request_state") + assert event_loop_exc.original_exception.args[0] == "test" + assert event_loop_exc.request_state == {"key": "value"} + + # ModelThrottledException with message property + throttled_exc = ModelThrottledException("throttle message") + assert hasattr(throttled_exc, "message") + assert throttled_exc.message == "throttle message" + + # StructuredOutputException with message property + structured_exc = StructuredOutputException("validation message") + assert hasattr(structured_exc, "message") + assert structured_exc.message == "validation message" diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py new file mode 100644 index 000000000..ade0fa5e8 --- /dev/null +++ b/tests/strands/types/test_interrupt.py @@ -0,0 +1,80 @@ +import unittest.mock + +import pytest + +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt, InterruptException +from strands.types.interrupt import _Interruptible + + +@pytest.fixture +def interrupt(): + return Interrupt( + id="test_id:test_name", + name="test_name", + reason={"reason": "test"}, + response={"response": "test"}, + ) + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = InterruptState() + return instance + + +@pytest.fixture +def interrupt_hook_event(agent): + class Event(_Interruptible): + def __init__(self): + self.agent = agent + + def _interrupt_id(self, name): + return f"test_id:{name}" + + return Event() + + +def test_interrupt_hook_event_interrupt(interrupt_hook_event): + with pytest.raises(InterruptException) as exception: + interrupt_hook_event.interrupt("custom_test_name", "custom test reason") + + tru_interrupt = exception.value.interrupt + exp_interrupt = Interrupt( + id="test_id:custom_test_name", + name="custom_test_name", + reason="custom test reason", + ) + assert tru_interrupt == exp_interrupt + + +def test_interrupt_hook_event_interrupt_state(agent, interrupt_hook_event): + with pytest.raises(InterruptException): + interrupt_hook_event.interrupt("custom_test_name", "custom test reason") + + exp_interrupt = Interrupt( + id="test_id:custom_test_name", + name="custom_test_name", + reason="custom test reason", + ) + assert exp_interrupt.id in agent._interrupt_state.interrupts + + tru_interrupt = agent._interrupt_state.interrupts[exp_interrupt.id] + assert tru_interrupt == exp_interrupt + + +def test_interrupt_hook_event_interrupt_response(interrupt, agent, interrupt_hook_event): + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + tru_response = interrupt_hook_event.interrupt("test_name") + exp_response = {"response": "test"} + assert tru_response == exp_response + + +def test_interrupt_hook_event_interrupt_response_empty(interrupt, agent, interrupt_hook_event): + interrupt.response = None + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + with pytest.raises(InterruptException): + interrupt_hook_event.interrupt("test_name") diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index c39615c32..26d4062e4 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -1,7 +1,10 @@ import json +import unittest.mock from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager +from strands.agent.interrupt import InterruptState +from strands.agent.state import AgentState from strands.types.session import ( 
Session, SessionAgent, @@ -91,3 +94,38 @@ def test_session_message_with_bytes(): assert original_message["role"] == message["role"] assert original_message["content"][0]["text"] == message["content"][0]["text"] assert original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"] + + +def test_session_agent_from_agent(): + agent = unittest.mock.Mock() + agent.agent_id = "a1" + agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) + agent.state = AgentState({"test": "state"}) + agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + + tru_session_agent = SessionAgent.from_agent(agent) + exp_session_agent = SessionAgent( + agent_id="a1", + conversation_manager_state={"test": "conversation"}, + state={"test": "state"}, + _internal_state={"interrupt_state": {"interrupts": {}, "context": {}, "activated": False}}, + created_at=unittest.mock.ANY, + updated_at=unittest.mock.ANY, + ) + assert tru_session_agent == exp_session_agent + + +def test_session_agent_initialize_internal_state(): + agent = unittest.mock.Mock() + session_agent = SessionAgent( + agent_id="a1", + conversation_manager_state={}, + state={}, + _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, + ) + + session_agent.initialize_internal_state(agent) + + tru_interrupt_state = agent._interrupt_state + exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert tru_interrupt_state == exp_interrupt_state diff --git a/tests_integ/fixtures/say_tool.py b/tests_integ/fixtures/say_tool.py new file mode 100644 index 000000000..454f28240 --- /dev/null +++ b/tests_integ/fixtures/say_tool.py @@ -0,0 +1,7 @@ +from strands import tool + + +@tool +def say(input: str) -> str: + """Say the input""" + return f"Said: {input}" diff --git a/tests_integ/fixtures/test_agent.json b/tests_integ/fixtures/test_agent.json new file mode 100644 index 000000000..e1ffad249 --- /dev/null +++ b/tests_integ/fixtures/test_agent.json @@ -0,0 +1,6 @@ +{ + "model": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "tools": ["tests_integ.fixtures.say_tool:say"], + "prompt": "You use the say tool to communicate", + "name": "Sayer" +} \ No newline at end of file diff --git a/tests_integ/interrupts/__init__.py b/tests_integ/interrupts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/interrupts/test_hook.py b/tests_integ/interrupts/test_hook.py new file mode 100644 index 000000000..f4341ac76 --- /dev/null +++ b/tests_integ/interrupts/test_hook.py @@ -0,0 +1,157 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.interrupt import Interrupt + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] == "weather_tool": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, 
weather_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) + + +def test_interrupt(agent): + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"]) + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "sunny"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "12:00"}, + ], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message + + +def test_interrupt_reject(agent): + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "REJECT", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [{"text": "sunny"}], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "error", + "content": [{"text": "tool rejected"}], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message diff --git a/tests_integ/interrupts/test_session.py b/tests_integ/interrupts/test_session.py new file mode 100644 index 000000000..714363fd8 --- /dev/null +++ b/tests_integ/interrupts/test_session.py @@ -0,0 +1,78 @@ +import json + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.session import FileSessionManager + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] == "weather_tool": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, weather_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) + + +def test_interrupt_session(interrupt_hook, time_tool, weather_tool, tmpdir): + session_manager = 
FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"]) diff --git a/tests_integ/interrupts/test_tool.py b/tests_integ/interrupts/test_tool.py new file mode 100644 index 000000000..e200f50a6 --- /dev/null +++ b/tests_integ/interrupts/test_tool.py @@ -0,0 +1,162 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.interrupt import Interrupt +from strands.types.tools import ToolContext + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] != "time_tool": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool", context=True) + def func(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_interrupt", reason="need time") + + return func + + +@pytest.fixture +def day_tool(): + @tool(name="day_tool", context=True) + def func(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_interrupt", reason="need day") + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func() -> str: + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, day_tool, weather_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool, day_tool, weather_tool]) + + +def test_interrupt(agent): + result = agent("What is the time, day, and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = sorted(result.interrupts, key=lambda interrupt: interrupt.reason) + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="test_interrupt", + reason="need day", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt_approval, interrupt_day = result.interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt_approval.id, + "response": "APPROVE", + }, + }, + { + "interruptResponse": { + "interruptId": interrupt_day.id, + "response": "monday", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = result.interrupts + exp_interrupts = [ + 
Interrupt( + id=ANY, + name="test_interrupt", + reason="need time", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt_time = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt_time.id, + "response": "12:01", + }, + }, + ] + result = agent(responses) + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:01", "monday", "sunny"]) + + tru_tool_results = agent.messages[-2]["content"] + tru_tool_results.sort(key=lambda content: content["toolResult"]["content"][0]["text"]) + + exp_tool_results = [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "12:01"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "monday"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "sunny"}, + ], + }, + }, + ] + assert tru_tool_results == exp_tool_results diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 160ad5af9..e15065a4a 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -15,7 +15,11 @@ $ python echo_server.py """ +import base64 +from typing import Literal + from mcp.server import FastMCP +from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents from pydantic import BaseModel @@ -46,6 +50,48 @@ def echo(to_echo: str) -> str: def echo_with_structured_content(to_echo: str) -> EchoResponse: return EchoResponse(echoed=to_echo, message_length=len(to_echo)) + @mcp.tool(description="Get current weather information for a location") + def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): + """Get weather data including forecasts and alerts for the specified location""" + if location.lower() == "new york": + return [ + EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri="https://weather.api/forecast/nyc", + mimeType="text/plain", + text="Current weather in New York: 72°F, partly cloudy with light winds.", + ), + ) + ] + elif location.lower() == "london": + return [ + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="https://weather.api/data/london.json", + mimeType="application/json", + blob=base64.b64encode( + '{"temperature": 18, "condition": "rainy", "humidity": 85}'.encode() + ).decode(), + ), + ) + ] + elif location.lower() == "tokyo": + # Read yellow.png file for weather icon + with open("tests_integ/yellow.png", "rb") as image_file: + png_data = image_file.read() + return [ + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="https://weather.api/icons/sunny.png", + mimeType="image/png", + blob=base64.b64encode(png_data).decode(), + ), + ) + ] + mcp.run(transport="stdio") diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 9d5ab5f13..2c9bb73e1 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -272,6 +272,100 @@ def transport_callback() -> MCPTransport: assert "Hello, Charlie!" 
in prompt_text +def test_mcp_client_embedded_resources(): + """Test that MCP client properly handles EmbeddedResource content types.""" + embedded_resource_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with embedded_resource_mcp_client: + # Test text embedded resource + text_result = embedded_resource_mcp_client.call_tool_sync( + tool_use_id="test-embedded-text", + name="get_weather", + arguments={"location": "New York"}, + ) + assert text_result["status"] == "success" + assert len(text_result["content"]) == 1 + assert "72°F" in text_result["content"][0]["text"] + assert "partly cloudy" in text_result["content"][0]["text"] + + # Test JSON embedded resource (blob with textual MIME type) + json_result = embedded_resource_mcp_client.call_tool_sync( + tool_use_id="test-embedded-json", + name="get_weather", + arguments={"location": "London"}, + ) + assert json_result["status"] == "success" + assert len(json_result["content"]) == 1 + json_content = json_result["content"][0]["text"] + assert "temperature" in json_content + assert "rainy" in json_content + + # Test image embedded resource + image_result = embedded_resource_mcp_client.call_tool_sync( + tool_use_id="test-embedded-image", + name="get_weather", + arguments={"location": "Tokyo"}, + ) + assert image_result["status"] == "success" + assert len(image_result["content"]) == 1 + assert "image" in image_result["content"][0] + assert image_result["content"][0]["image"]["format"] == "png" + assert "bytes" in image_result["content"][0]["image"]["source"] + + +@pytest.mark.asyncio +async def test_mcp_client_embedded_resources_async(): + """Test that async MCP client properly handles EmbeddedResource content types.""" + embedded_resource_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with embedded_resource_mcp_client: + # Test text embedded resource async + text_result = await embedded_resource_mcp_client.call_tool_async( + tool_use_id="test-embedded-text-async", + name="get_weather", + arguments={"location": "New York"}, + ) + assert text_result["status"] == "success" + assert len(text_result["content"]) == 1 + assert "72°F" in text_result["content"][0]["text"] + + # Test JSON embedded resource async + json_result = await embedded_resource_mcp_client.call_tool_async( + tool_use_id="test-embedded-json-async", + name="get_weather", + arguments={"location": "London"}, + ) + assert json_result["status"] == "success" + assert len(json_result["content"]) == 1 + json_content = json_result["content"][0]["text"] + assert "temperature" in json_content + + +def test_mcp_client_embedded_resources_with_agent(): + """Test that embedded resources work correctly when used with Agent.""" + embedded_resource_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with embedded_resource_mcp_client: + tools = embedded_resource_mcp_client.list_tools_sync() + agent = Agent(tools=tools) + + # Test that agent can successfully use tools that return embedded resources + result = agent("Get the weather for New York and tell me what it says") + + # Check that the agent successfully processed the embedded resource + assert result.message is not None + response_text = " ".join([block["text"] for block in result.message["content"] if "text" in block]).lower() + + # The agent should have received and processed the 
embedded weather content + assert any(["72" in response_text, "partly cloudy" in response_text, "weather" in response_text]) + + def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..7914bb326 --- /dev/null +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -0,0 +1,160 @@ +"""Integration tests for MCPClient ToolProvider functionality with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_client import ToolFilters + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_client_tool_provider_filters(): + """Test MCPClient with various filter combinations.""" + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + } + + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="test", + ) + + agent = Agent(tools=[client]) + tool_names = agent.tool_names + + assert "test_echo_with_delay" not in [name for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_client_tool_provider_execution(): + """Test that MCPClient works with agent execution.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="filtered", + ) + + agent = Agent(tools=[client]) + + assert "filtered_echo" in agent.tool_names + + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_client_tool_provider_reuse(): + """Test that a single MCPClient can be used across multiple agents.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="shared", + ) + + agent1 = Agent(tools=[client]) + assert "shared_echo" in agent1.tool_names + + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + agent2 = Agent(tools=[client]) + assert "shared_echo" in agent2.tool_names + + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + + # Agent 1 cleans up - client should still be active for agent 2 + agent1.cleanup() + + # Agent 2 should still be able to use the tool + result2 = agent2.tool.shared_echo(to_echo="Agent 2 Test") + assert "Agent 2 Test" in str(result2) + + 
agent2.cleanup() + + +def test_mcp_client_multiple_servers(): + """Test MCPClient with multiple MCP servers simultaneously.""" + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo"]}, + prefix="server1", + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo_with_structured_content"]}, + prefix="server2", + ) + + agent = Agent(tools=[client1, client2]) + + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_client_server_startup_failure(): + """Test that MCPClient handles server startup failure gracefully without hanging.""" + from strands.types.exceptions import ToolProviderException + + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[failing_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) + + +def test_mcp_client_server_connection_timeout(): + """Test that MCPClient times out gracefully when server hangs during startup.""" + from strands.types.exceptions import ToolProviderException + + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), + startup_timeout=1, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[hanging_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index c1f442b2a..75cc58f74 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -131,7 +131,7 @@ def __init__(self): id="gemini", environment_variable="GOOGLE_API_KEY", factory=lambda: GeminiModel( - api_key=os.getenv("GOOGLE_API_KEY"), + client_args={"api_key": os.getenv("GOOGLE_API_KEY")}, model_id="gemini-2.5-flash", params={"temperature": 0.7}, ), diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index eaef1eb88..36c21fb7f 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -57,6 +57,21 @@ class Weather(BaseModel): agent = Agent(model) result = agent.structured_output(Weather, "How are you?") + assert isinstance(result, Weather) - assert len(result.time) > 0 - assert len(result.weather) > 0 + +def test_structured_output_is_forced_when_provided_in_agent_invocation(skip_for, model): + """Tests that structured_output is always forced to return a value even if model doesn't have any information.""" + + class UserProfile(BaseModel): + """Basic user profile model.""" + + name: str + age: int + occupation: str + + agent = Agent() + result = agent("Create a profile for John who is a 25 year old dentist", structured_output_model=UserProfile) + assert result.structured_output.name == "John" + assert result.structured_output.age == 25 + assert result.structured_output.occupation == "dentist" diff --git 
a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index 6cfdd3038..b348c29f4 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -1,3 +1,5 @@ +import unittest.mock + import pydantic import pytest @@ -40,6 +42,37 @@ class Weather(pydantic.BaseModel): return Weather(time="12:00", weather="sunny") +class Location(pydantic.BaseModel): + """Location information.""" + + city: str = pydantic.Field(description="The city name") + country: str = pydantic.Field(description="The country name") + + +class WeatherCondition(pydantic.BaseModel): + """Weather condition details.""" + + condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')") + temperature: int = pydantic.Field(description="Temperature in Celsius") + + +class NestedWeather(pydantic.BaseModel): + """Weather report with nested location and condition information.""" + + time: str = pydantic.Field(description="The time in HH:MM format") + location: Location = pydantic.Field(description="Location information") + weather: WeatherCondition = pydantic.Field(description="Weather condition details") + + +@pytest.fixture +def nested_weather(): + return NestedWeather( + time="12:00", + location=Location(city="New York", country="USA"), + weather=WeatherCondition(condition="sunny", temperature=25), + ) + + @pytest.fixture def yellow_color(): class Color(pydantic.BaseModel): @@ -88,6 +121,22 @@ async def test_agent_stream_async(agent): assert all(string in text for string in ["12:00", "sunny"]) +def test_agent_invoke_reasoning(agent, model): + model.update_config( + params={ + "thinking": { + "budget_tokens": 1024, + "type": "enabled", + }, + }, + ) + + result = agent("Please reason about the equation 2+2.") + + assert "reasoningContent" in result.message["content"][0] + assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] + + def test_structured_output(agent, weather): tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather @@ -134,3 +183,31 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): tru_color = agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +def test_structured_output_unsupported_model(model, nested_weather): + # Mock supports_response_schema to return False to test fallback mechanism + with ( + unittest.mock.patch.multiple( + "strands.models.litellm", + supports_response_schema=unittest.mock.DEFAULT, + ) as mocks, + unittest.mock.patch.object( + model, "_structured_output_using_tool", wraps=model._structured_output_using_tool + ) as mock_tool, + unittest.mock.patch.object( + model, "_structured_output_using_response_schema", wraps=model._structured_output_using_response_schema + ) as mock_schema, + ): + mocks["supports_response_schema"].return_value = False + + # Test that structured output still works via tool calling fallback + agent = Agent(model=model) + prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius" + tru_weather = agent.structured_output(NestedWeather, prompt) + exp_weather = nested_weather + assert tru_weather == exp_weather + + # Verify that the tool method was called and schema method was not + mock_tool.assert_called_once() + mock_schema.assert_not_called() diff --git a/tests_integ/test_agent_json.py b/tests_integ/test_agent_json.py new file mode 100644 
index 000000000..387cfd172 --- /dev/null +++ b/tests_integ/test_agent_json.py @@ -0,0 +1,13 @@ +from strands.experimental import config_to_agent + + +def test_load_agent_from_config(): + agent = config_to_agent("file://tests_integ/fixtures/test_agent.json") + + result = agent("Say hello") + + assert "Sayer" == agent.name + assert "You use the say tool to communicate" == agent.system_prompt + assert agent.tool_names[0] == "say" + assert agent.model.get_config().get("model_id") == "global.anthropic.claude-sonnet-4-5-20250929-v1:0" + assert "hello" in str(result).lower() diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py new file mode 100644 index 000000000..188f57777 --- /dev/null +++ b/tests_integ/test_structured_output_agent_loop.py @@ -0,0 +1,330 @@ +""" +Comprehensive integration tests for structured output passed into the agent functionality. +""" + +from typing import List, Optional + +import pytest +from pydantic import BaseModel, Field, field_validator + +from strands import Agent +from strands.tools import tool + +# ========== Pydantic Models from notebook ========== + + +class MathResult(BaseModel): + """Math operation result.""" + + operation: str = Field(description="the performed operation") + result: int = Field(description="the result of the operation") + + +class UserProfile(BaseModel): + """Basic user profile model.""" + + name: str + age: int + occupation: str + active: bool = True + + +class Address(BaseModel): + """Address information.""" + + street: str + city: str + state: str + zip_code: str + + +class Contact(BaseModel): + """Contact information.""" + + email: str + phone: Optional[str] = None + preferred_method: str = "email" + + +class Employee(BaseModel): + """Complex nested employee model.""" + + name: str + employee_id: int + department: str + address: Address + contact: Contact + skills: List[str] + hire_date: str + salary_range: str + + +class ProductReview(BaseModel): + """Product review analysis.""" + + product_name: str + rating: int = Field(ge=1, le=5, description="Rating from 1-5 stars") + sentiment: str = Field(pattern="^(positive|negative|neutral)$") + key_points: List[str] + would_recommend: bool + + +class WeatherForecast(BaseModel): + """Weather forecast data.""" + + location: str + temperature: int + condition: str + humidity: int + wind_speed: int + forecast_date: str + + +class TaskList(BaseModel): + """Task management structure.""" + + project_name: str + tasks: List[str] + priority: str = Field(pattern="^(high|medium|low)$") + due_date: str + estimated_hours: int + + +class Person(BaseModel): + """A person's basic information.""" + + name: str = Field(description="Full name") + age: int = Field(description="Age in years", ge=0, le=150) + + +class Company(BaseModel): + """A company or organization.""" + + name: str = Field(description="Company name") + address: Address = Field(description="Company address") + employees: List[Person] = Field(description="list of persons") + + +class Task(BaseModel): + """A task or todo item.""" + + title: str = Field(description="Task title") + description: str = Field(description="Detailed description") + priority: str = Field(description="Priority level: low, medium, high") + completed: bool = Field(description="Whether task is completed", default=False) + + +class NameWithValidation(BaseModel): + """Name model with validation that forces retry.""" + + first_name: str + + @field_validator("first_name") + @classmethod + def validate_first_name(cls, 
value: str) -> str: + if not value.endswith("abc"): + raise ValueError("You must append 'abc' to the end of my name") + return value + + +# ========== Tool Definitions ========== + + +@tool +def calculator(operation: str, a: float, b: float) -> float: + """Simple calculator tool for testing.""" + if operation == "add": + return a + b + elif operation == "subtract": + return a - b + elif operation == "multiply": + return a * b + elif operation == "divide": + return b / a if a != 0 else 0 + elif operation == "power": + return a**b + else: + return 0 + + +# ========== Test Classes ========== + + +class TestBasicStructuredOutput: + """Test basic structured output functionality.""" + + def test_regular_call_without_structured_output(self): + """Test that regular calls work without structured output.""" + agent = Agent() + result = agent("What can you do for me?") + + assert result.structured_output is None + assert agent._default_structured_output_model is None + + def test_simple_structured_output(self): + """Test basic structured output with UserProfile.""" + agent = Agent() + + result = agent( + "Create a profile for John Doe who is a 25 year old dentist", structured_output_model=UserProfile + ) + + assert result.structured_output is not None + assert isinstance(result.structured_output, UserProfile) + assert result.structured_output.name == "John Doe" + assert result.structured_output.age == 25 + assert result.structured_output.occupation.lower() == "dentist" + + def test_follow_up_without_structured_output(self): + """Test that follow-up calls work without structured output.""" + agent = Agent() + + # First call with structured output + result1 = agent( + "Create a profile for John Doe who is a 25 year old dentist", structured_output_model=UserProfile + ) + assert result1.structured_output is not None + + # Second call without structured output + result2 = agent("what did you just do?") + assert result2.structured_output is None + + +class TestToolUsage: + """Test structured output with tool usage.""" + + def test_tool_use_without_structured_output(self): + """Test tool usage without structured output.""" + agent = Agent(tools=[calculator]) + + result = agent("What is 2 + 2? Use the calculator tool.") + + assert result.structured_output is None + # Check that tool was called (in metrics) + assert result.metrics.tool_metrics is not None + assert len(result.metrics.tool_metrics) > 0 + + def test_tool_use_with_structured_output(self): + """Test tool usage with structured output.""" + agent = Agent(tools=[calculator]) + + result = agent("Calculate 2 + 2 using the calculator tool", structured_output_model=MathResult) + + assert result.structured_output is not None + assert isinstance(result.structured_output, MathResult) + assert result.structured_output.result == 4 + # Check that tool was called + assert result.metrics.tool_metrics is not None + assert len(result.metrics.tool_metrics) > 0 + + +class TestAsyncOperations: + """Test async operations with structured output.""" + + @pytest.mark.asyncio + async def test_async_structured_output(self): + """Test async invocation with structured output.""" + agent = Agent() + + result = await agent.invoke_async( + """ + Analyze this product review: + "This wireless mouse is fantastic! Great battery life, smooth tracking, + and the ergonomic design is perfect for long work sessions. The price + is reasonable too. I'd definitely buy it again and recommend it to others. 
+ Rating: 5 stars" + """, + structured_output_model=ProductReview, + ) + + assert result.structured_output is not None + assert isinstance(result.structured_output, ProductReview) + assert result.structured_output.rating == 5 + assert result.structured_output.sentiment == "positive" + assert result.structured_output.would_recommend is True + + +class TestStreamingOperations: + """Test streaming with structured output.""" + + @pytest.mark.asyncio + async def test_streaming_with_structured_output(self): + """Test streaming with structured output.""" + agent = Agent() + + result_found = False + structured_output_found = False + + async for event in agent.stream_async( + "Generate a weather forecast for Seattle: 68°F, partly cloudy, 55% humidity, 8 mph winds, for tomorrow", + structured_output_model=WeatherForecast, + ): + if "result" in event: + result_found = True + if event["result"].structured_output: + structured_output_found = True + forecast = event["result"].structured_output + assert isinstance(forecast, WeatherForecast) + assert forecast.location == "Seattle" + + assert result_found, "No result event found in stream" + assert structured_output_found, "No structured output found in stream result" + + +class TestMultipleInvocations: + """Test multiple invocations with different structured output models.""" + + def test_multiple_invocations_different_models(self): + """Test using different structured output models in consecutive calls.""" + agent = Agent() + + # First invocation with Person model + person_result = agent("Extract person: John Doe, 35, john@test.com", structured_output_model=Person) + assert person_result.structured_output is not None + assert isinstance(person_result.structured_output, Person) + assert person_result.structured_output.name == "John Doe" + assert person_result.structured_output.age == 35 + + # Second invocation with Task model + task_result = agent("Create task: Review code, high priority, completed", structured_output_model=Task) + assert task_result.structured_output is not None + assert isinstance(task_result.structured_output, Task) + assert task_result.structured_output.title == "Review code" + assert task_result.structured_output.priority == "high" + assert task_result.structured_output.completed is True + + # Third invocation without structured output + normal_result = agent("What tasks do we have?") + assert normal_result.structured_output is None + + +class TestAgentInitialization: + """Test agent initialization with default structured output model.""" + + def test_agent_with_default_structured_output(self): + """Test agent initialized with default structured output model.""" + agent = Agent(structured_output_model=UserProfile) + + result = agent("Create a profile for John Doe who is a 25 year old dentist") + + assert result.structured_output is not None + assert isinstance(result.structured_output, UserProfile) + assert result.structured_output.name == "John Doe" + assert result.structured_output.age == 25 + assert result.structured_output.occupation.lower() == "dentist" + + +class TestValidationRetry: + """Test validation with retry logic.""" + + def test_validation_forces_retry(self): + """Test that validation errors force the model to retry.""" + agent = Agent() + + result = agent("What's Aaron's name?", structured_output_model=NameWithValidation) + + assert result.structured_output is not None + assert isinstance(result.structured_output, NameWithValidation) + # The model should have learned to append 'abc' after validation failure + assert 
result.structured_output.first_name.endswith("abc") + assert "Aaron" in result.structured_output.first_name or "aaron" in result.structured_output.first_name.lower() diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py index b205c723f..91fb5b910 100644 --- a/tests_integ/test_summarizing_conversation_manager_integration.py +++ b/tests_integ/test_summarizing_conversation_manager_integration.py @@ -372,3 +372,39 @@ def test_dedicated_summarization_agent(model, summarization_model): break assert summary_text + + +def test_summarization_with_tool_messages_and_no_tools(): + agent = Agent( + messages=[ + {"role": "user", "content": [{"text": "What is the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "t1", "name": "time_tool", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "content": [{"text": "12:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The current time is 12:00."}]}, + {"role": "user", "content": [{"text": "Thank you"}]}, + {"role": "assistant", "content": [{"text": "You are welcome."}]}, + ], + ) + + conversation_manager = SummarizingConversationManager(summary_ratio=1, preserve_recent_messages=2) + conversation_manager.reduce_context(agent) + + assert len(agent.tool_names) == 0 + assert len(agent.messages) == 3 + + summary = str(agent.messages[0]).lower() + assert "12:00" in summary