From 2a26ffad8bc7379358bc2535d9ce1ec290fea0af Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Tue, 7 Oct 2025 22:43:53 +0400 Subject: [PATCH 01/26] fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException (#994) --- src/strands/models/litellm.py | 31 +++++++++++++++++++++------- tests/strands/models/test_litellm.py | 12 +++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 005eed3df..1763f5dec 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -8,11 +8,13 @@ from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast import litellm +from litellm.exceptions import ContextWindowExceededError from litellm.utils import supports_response_schema from pydantic import BaseModel from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys @@ -135,7 +137,11 @@ async def stream( logger.debug("request=<%s>", request) logger.debug("invoking model") - response = await litellm.acompletion(**self.client_args, **request) + try: + response = await litellm.acompletion(**self.client_args, **request) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow") + raise ContextWindowOverflowException(e) from e logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) @@ -205,15 +211,24 @@ async def structured_output( Yields: Model events with the last being the structured output. """ - if not supports_response_schema(self.get_config()["model_id"]): + supports_schema = supports_response_schema(self.get_config()["model_id"]) + + # If the provider does not support response schemas, we cannot reliably parse structured output. + # In that case we must not call the provider and must raise the documented ValueError. + if not supports_schema: raise ValueError("Model does not support response_format") - response = await litellm.acompletion( - **self.client_args, - model=self.get_config()["model_id"], - messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], - response_format=output_model, - ) + # For providers that DO support response schemas, call litellm and map context-window errors. 
+ try: + response = await litellm.acompletion( + **self.client_args, + model=self.get_config()["model_id"], + messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], + response_format=output_model, + ) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow in structured_output") + raise ContextWindowOverflowException(e) from e if len(response.choices) > 1: raise ValueError("Multiple choices found in the response.") diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index bc81fc819..776ae7bae 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -3,9 +3,11 @@ import pydantic import pytest +from litellm.exceptions import ContextWindowExceededError import strands from strands.models.litellm import LiteLLMModel +from strands.types.exceptions import ContextWindowOverflowException @pytest.fixture @@ -332,3 +334,13 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, tool_choice=None) assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_context_window_maps_to_typed_exception(litellm_acompletion, model): + """Test that a typed ContextWindowExceededError is mapped correctly.""" + litellm_acompletion.side_effect = ContextWindowExceededError(message="test error", model="x", llm_provider="y") + + with pytest.raises(ContextWindowOverflowException): + async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]): + pass From 171779ab50198833b710df29b514f94f0327750e Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 8 Oct 2025 15:03:29 -0400 Subject: [PATCH 02/26] feat: Refactor and update tool loading to support modules (#989) * feat: Refactor and update tool loading to support modules * Update registry.py * feat: Address pr feedback * Update src/strands/tools/registry.py Co-authored-by: Patrick Gray * Update src/strands/tools/loader.py Co-authored-by: Patrick Gray --------- Co-authored-by: Patrick Gray --- .github/workflows/test-lint.yml | 5 + src/strands/tools/loader.py | 152 +++++++++++++++++- src/strands/tools/registry.py | 142 +++++++++------- tests/fixtures/say_tool.py | 17 ++ .../tool_with_spec_but_no_function.py | 1 + ...ool_with_spec_but_non_callable_function.py | 3 + tests/strands/tools/test_loader.py | 9 +- tests/strands/tools/test_registry.py | 98 ++++++++++- 8 files changed, 364 insertions(+), 63 deletions(-) create mode 100644 tests/fixtures/say_tool.py create mode 100644 tests/fixtures/tool_with_spec_but_no_function.py create mode 100644 tests/fixtures/tool_with_spec_but_non_callable_function.py diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 291874dce..e38942b2c 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -66,6 +66,11 @@ jobs: id: tests run: hatch test tests --cover continue-on-error: false + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} lint: name: Lint runs-on: ubuntu-latest diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 5935077db..31e8dc788 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -5,7 +5,10 @@ import os import sys import warnings +from importlib.machinery import ModuleSpec from pathlib import Path +from posixpath import expanduser +from types import ModuleType from typing import List, cast from ..types.tools import 
AgentTool @@ -15,16 +18,151 @@ logger = logging.getLogger(__name__) +def load_tool_from_string(tool_string: str) -> List[AgentTool]: + """Load tools follows strands supported input string formats. + + This function can load a tool based on a string in the following ways: + 1. Local file path to a module based tool: `./path/to/module/tool.py` + 2. Module import path + 2.1. Path to a module based tool: `strands_tools.file_read` + 2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool` + 2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say` + """ + # Case 1: Local file path to a tool + # Ex: ./path/to/my_cool_tool.py + tool_path = expanduser(tool_string) + if os.path.exists(tool_path): + return load_tools_from_file_path(tool_path) + + # Case 2: Module import path + # Ex: test.fixtures.say_tool:say (Load specific @tool decorated function) + # Ex: strands_tools.file_read (Load all @tool decorated functions, or module tool) + return load_tools_from_module_path(tool_string) + + +def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: + """Load module from specified path, and then load tools from that module. + + This function attempts to load the passed in path as a python module, and if it succeeds, + then it tries to import strands tool(s) from that module. + """ + abs_path = str(Path(tool_path).resolve()) + logger.debug("tool_path=<%s> | loading python tool from path", abs_path) + + # Load the module by spec + + # Using this to determine the module name + # ./path/to/my_cool_tool.py -> my_cool_tool + module_name = os.path.basename(tool_path).split(".")[0] + + # This function imports a module based on its path, and gives it the provided name + + spec: ModuleSpec = cast(ModuleSpec, importlib.util.spec_from_file_location(module_name, abs_path)) + if not spec: + raise ImportError(f"Could not create spec for {module_name}") + if not spec.loader: + raise ImportError(f"No loader available for {module_name}") + + module = importlib.util.module_from_spec(spec) + # Load, or re-load, the module + sys.modules[module_name] = module + # Execute the module to run any top level code + spec.loader.exec_module(module) + + return load_tools_from_module(module, module_name) + + +def load_tools_from_module_path(module_tool_path: str) -> list[AgentTool]: + """Load strands tool from a module path. + + Example module paths: + my.module.path + my.module.path:tool_name + """ + if ":" in module_tool_path: + module_path, tool_func_name = module_tool_path.split(":") + else: + module_path, tool_func_name = (module_tool_path, None) + + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + raise AttributeError(f'Tool string: "{module_tool_path}" is not a valid tool string.') from e + + # If a ':' is present in the string, then its a targeted function in a module + if tool_func_name: + if hasattr(module, tool_func_name): + target_tool = getattr(module, tool_func_name) + if isinstance(target_tool, DecoratedFunctionTool): + return [target_tool] + + raise AttributeError(f"Tool {tool_func_name} not found in module {module_path}") + + # Else, try to import all of the @tool decorated tools, or the module based tool + module_name = module_path.split(".")[-1] + return load_tools_from_module(module, module_name) + + +def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTool]: + """Load tools from a module. 
+ + First checks if the passed in module has instances of DecoratedToolFunction classes as atributes to the module. + If so, then it returns them as a list of tools. If not, then it attempts to load the module as a module based tool. + """ + logger.debug("tool_name=<%s>, module=<%s> | loading tools from module", module_name, module_name) + + # Try and see if any of the attributes in the module are function-based tools decorated with @tool + # This means that there may be more than one tool available in this module, so we load them all + + function_tools: List[AgentTool] = [] + # Function tools will appear as attributes in the module + for attr_name in dir(module): + attr = getattr(module, attr_name) + # Check if the module attribute is a DecoratedFunctiontool + if isinstance(attr, DecoratedFunctionTool): + logger.debug("tool_name=<%s>, module=<%s> | found function-based tool in module", attr_name, module_name) + function_tools.append(cast(AgentTool, attr)) + + if function_tools: + return function_tools + + # Finally, if no DecoratedFunctionTools are found in the module, fall back + # to module based tools, and search for TOOL_SPEC + function + module_tool_name = module_name + tool_spec = getattr(module, "TOOL_SPEC", None) + if not tool_spec: + raise AttributeError( + f"The module {module_tool_name} is not a valid module for loading tools." + "This module must contain @tool decorated function(s), or must be a module based tool." + ) + + # If this is a module based tool, the module should have a function with the same name as the module itself + if not hasattr(module, module_tool_name): + raise AttributeError(f"Module-based tool {module_tool_name} missing function {module_tool_name}") + + tool_func = getattr(module, module_tool_name) + if not callable(tool_func): + raise TypeError(f"Tool {module_tool_name} function is not callable") + + return [PythonAgentTool(module_tool_name, tool_spec, tool_func)] + + class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: - """Load a Python tool module and return all discovered function-based tools as a list. + """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). It is the canonical API for retrieving multiple tools from a single Python file. """ + warnings.warn( + "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", + DeprecationWarning, + stacklevel=2, + ) try: # Support module:function style (e.g. package.module:function) if not os.path.exists(tool_path) and ":" in tool_path: @@ -108,7 +246,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: """ warnings.warn( "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " - "Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.", + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", DeprecationWarning, stacklevel=2, ) @@ -127,7 +265,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: """ warnings.warn( "ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. " - "Use ToolLoader.load_tools(...) 
which always returns a list of AgentTool.", + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", DeprecationWarning, stacklevel=2, ) @@ -140,7 +278,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: @classmethod def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: - """Load tools from a file based on its file extension. + """DEPRECATED: Load tools from a file based on its file extension. Args: tool_path: Path to the tool file. @@ -154,6 +292,12 @@ def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: ValueError: If the tool file has an unsupported extension. Exception: For other errors during tool loading. """ + warnings.warn( + "ToolLoader.load_tools is deprecated and will be removed in Strands SDK 2.0. " + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", + DeprecationWarning, + stacklevel=2, + ) ext = Path(tool_path).suffix.lower() abs_path = str(Path(tool_path).resolve()) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 0660337a2..3631c9dee 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,6 +8,7 @@ import logging import os import sys +import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path @@ -18,6 +19,7 @@ from strands.tools.decorator import DecoratedFunctionTool from ..types.tools import AgentTool, ToolSpec +from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -36,18 +38,23 @@ def __init__(self) -> None: self.tool_config: Optional[Dict[str, Any]] = None def process_tools(self, tools: List[Any]) -> List[str]: - """Process tools list that can contain tool names, paths, imported modules, or functions. + """Process tools list. + + Process list of tools that can contain local file path string, module import path string, + imported modules, @tool decorated functions, or instances of AgentTool. Args: tools: List of tool specifications. Can be: + 1. Local file path to a module based tool: `./path/to/module/tool.py` + 2. Module import path + 2.1. Path to a module based tool: `strands_tools.file_read` + 2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool` + 2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say` + 3. A module for a module based tool + 4. Instances of AgentTool (@tool decorated functions) + 5. Dictionaries with name/path keys (deprecated) - - String tool names (e.g., "calculator") - - File paths (e.g., "/path/to/tool.py") - - Imported Python modules (e.g., a module object) - - Functions decorated with @tool - - Dictionaries with name/path keys - - Instance of an AgentTool Returns: List of tool names that were processed. 
@@ -55,62 +62,76 @@ def process_tools(self, tools: List[Any]) -> List[str]: tool_names = [] def add_tool(tool: Any) -> None: - # Case 1: String file path - if isinstance(tool, str): - # Extract tool name from path - tool_name = os.path.basename(tool).split(".")[0] - self.load_tool_from_filepath(tool_name=tool_name, tool_path=tool) - tool_names.append(tool_name) - - # Case 2: Dictionary with name and path - elif isinstance(tool, dict) and "name" in tool and "path" in tool: - self.load_tool_from_filepath(tool_name=tool["name"], tool_path=tool["path"]) - tool_names.append(tool["name"]) - - # Case 3: Dictionary with path only - elif isinstance(tool, dict) and "path" in tool: - tool_name = os.path.basename(tool["path"]).split(".")[0] - self.load_tool_from_filepath(tool_name=tool_name, tool_path=tool["path"]) - tool_names.append(tool_name) - - # Case 4: Imported Python module - elif hasattr(tool, "__file__") and inspect.ismodule(tool): - # Get the module file path - module_path = tool.__file__ - # Extract the tool name from the module name - tool_name = tool.__name__.split(".")[-1] - - # Check for TOOL_SPEC in module to validate it's a Strands tool - if hasattr(tool, "TOOL_SPEC") and hasattr(tool, tool_name) and module_path: - self.load_tool_from_filepath(tool_name=tool_name, tool_path=module_path) - tool_names.append(tool_name) + try: + # String based tool + # Can be a file path, a module path, or a module path with a targeted function. Examples: + # './path/to/tool.py' + # 'my.module.tool' + # 'my.module.tool:tool_name' + if isinstance(tool, str): + tools = load_tool_from_string(tool) + for a_tool in tools: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Dictionary with name and path + elif isinstance(tool, dict) and "name" in tool and "path" in tool: + tools = load_tool_from_string(tool["path"]) + + tool_found = False + for a_tool in tools: + if a_tool.tool_name == tool["name"]: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + tool_found = True + + if not tool_found: + raise ValueError(f'Tool "{tool["name"]}" not found in "{tool["path"]}"') + + # Dictionary with path only + elif isinstance(tool, dict) and "path" in tool: + tools = load_tool_from_string(tool["path"]) + + for a_tool in tools: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Imported Python module + elif hasattr(tool, "__file__") and inspect.ismodule(tool): + # Extract the tool name from the module name + module_tool_name = tool.__name__.split(".")[-1] + + tools = load_tools_from_module(tool, module_tool_name) + for a_tool in tools: + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Case 5: AgentTools (which also covers @tool) + elif isinstance(tool, AgentTool): + self.register_tool(tool) + tool_names.append(tool.tool_name) + + # Case 6: Nested iterable (list, tuple, etc.) 
- add each sub-tool + elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): + for t in tool: + add_tool(t) else: - function_tools = self._scan_module_for_tools(tool) - for function_tool in function_tools: - self.register_tool(function_tool) - tool_names.append(function_tool.tool_name) - - if not function_tools: - logger.warning("tool_name=<%s>, module_path=<%s> | invalid agent tool", tool_name, module_path) - - # Case 5: AgentTools (which also covers @tool) - elif isinstance(tool, AgentTool): - self.register_tool(tool) - tool_names.append(tool.tool_name) - # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool - elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): - for t in tool: - add_tool(t) - else: - logger.warning("tool=<%s> | unrecognized tool specification", tool) + logger.warning("tool=<%s> | unrecognized tool specification", tool) - for a_tool in tools: - add_tool(a_tool) + except Exception as e: + exception_str = str(e) + logger.exception("tool_name=<%s> | failed to load tool", tool) + raise ValueError(f"Failed to load tool {tool}: {exception_str}") from e + for tool in tools: + add_tool(tool) return tool_names def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: - """Load a tool from a file path. + """DEPRECATED: Load a tool from a file path. Args: tool_name: Name of the tool. @@ -120,6 +141,13 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: FileNotFoundError: If the tool file is not found. ValueError: If the tool cannot be loaded. """ + warnings.warn( + "load_tool_from_filepath is deprecated and will be removed in Strands SDK 2.0. " + "`process_tools` automatically handles loading tools from a filepath.", + DeprecationWarning, + stacklevel=2, + ) + from .loader import ToolLoader try: diff --git a/tests/fixtures/say_tool.py b/tests/fixtures/say_tool.py new file mode 100644 index 000000000..4607b2501 --- /dev/null +++ b/tests/fixtures/say_tool.py @@ -0,0 +1,17 @@ +from strands import tool + + +@tool +def say(input: str) -> str: + """Say something.""" + return f"Hello {input}!" + + +@tool +def dont_say(input: str) -> str: + """Dont say something.""" + return "Didnt say anything!" + + +def not_a_tool() -> str: + return "Not a tool!" diff --git a/tests/fixtures/tool_with_spec_but_no_function.py b/tests/fixtures/tool_with_spec_but_no_function.py new file mode 100644 index 000000000..75f8bf6f6 --- /dev/null +++ b/tests/fixtures/tool_with_spec_but_no_function.py @@ -0,0 +1 @@ +TOOL_SPEC = {"hello": "world!"} diff --git a/tests/fixtures/tool_with_spec_but_non_callable_function.py b/tests/fixtures/tool_with_spec_but_non_callable_function.py new file mode 100644 index 000000000..0ca2f092c --- /dev/null +++ b/tests/fixtures/tool_with_spec_but_non_callable_function.py @@ -0,0 +1,3 @@ +TOOL_SPEC = {"hello": "world"} + +tool_with_spec_but_non_callable_function = "not a function!" 
diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 6b86d00ee..13aca90c3 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -1,11 +1,12 @@ import os import re +import tempfile import textwrap import pytest from strands.tools.decorator import DecoratedFunctionTool -from strands.tools.loader import ToolLoader +from strands.tools.loader import ToolLoader, load_tools_from_file_path from strands.tools.tools import PythonAgentTool @@ -310,3 +311,9 @@ def test_load_tool_path_returns_single_tool(tool_path): assert loaded_python_tool.tool_name == "alpha" assert loaded_tool.tool_name == "alpha" + + +def test_load_tools_from_file_path_module_spec_missing(): + with tempfile.NamedTemporaryFile() as f: + with pytest.raises(ImportError, match=f"Could not create spec for {os.path.basename(f.name)}"): + load_tools_from_file_path(f.name) diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index f0759ea07..ee0098adc 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -26,7 +26,10 @@ def test_process_tools_with_invalid_path(): tool_registry = ToolRegistry() invalid_path = "not a filepath" - with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): + with pytest.raises( + ValueError, + match=f'Failed to load tool {invalid_path}: Tool string: "{invalid_path}" is not a valid tool string', + ): tool_registry.process_tools([invalid_path]) @@ -164,3 +167,96 @@ def test_register_tool_duplicate_name_with_hot_reload(): # Verify the second tool replaced the first assert tool_registry.registry["hot_reload_tool"] == tool_2 + + +def test_register_strands_tools_from_module(): + tool_registry = ToolRegistry() + tool_registry.process_tools(["tests.fixtures.say_tool"]) + + assert len(tool_registry.registry) == 2 + assert "say" in tool_registry.registry + assert "dont_say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_from_module(): + tool_registry = ToolRegistry() + tool_registry.process_tools(["tests.fixtures.say_tool:say"]) + + assert len(tool_registry.registry) == 1 + assert "say" in tool_registry.registry + assert "dont_say" not in tool_registry.registry + + +def test_register_strands_tools_specific_tool_from_module_tool_missing(): + tool_registry = ToolRegistry() + + with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:nay: "): + tool_registry.process_tools(["tests.fixtures.say_tool:nay"]) + + +def test_register_strands_tools_specific_tool_from_module_not_a_tool(): + tool_registry = ToolRegistry() + + with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:not_a_tool: "): + tool_registry.process_tools(["tests.fixtures.say_tool:not_a_tool"]) + + +def test_register_strands_tools_with_dict(): + tool_registry = ToolRegistry() + tool_registry.process_tools([{"path": "tests.fixtures.say_tool"}]) + + assert len(tool_registry.registry) == 2 + assert "say" in tool_registry.registry + assert "dont_say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_with_dict(): + tool_registry = ToolRegistry() + tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "say"}]) + + assert len(tool_registry.registry) == 1 + assert "say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_with_dict_not_found(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + 
match="Failed to load tool {'path': 'tests.fixtures.say_tool'" + ", 'name': 'nay'}: Tool \"nay\" not found in \"tests.fixtures.say_tool\"", + ): + tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "nay"}]) + + +def test_register_strands_tools_module_no_spec(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.mocked_model_provider: " + "The module mocked_model_provider is not a valid module", + ): + tool_registry.process_tools(["tests.fixtures.mocked_model_provider"]) + + +def test_register_strands_tools_module_no_function(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.tool_with_spec_but_no_function: " + "Module-based tool tool_with_spec_but_no_function missing function tool_with_spec_but_no_function", + ): + tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_no_function"]) + + +def test_register_strands_tools_module_non_callable_function(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.tool_with_spec_but_non_callable_function:" + " Tool tool_with_spec_but_non_callable_function function is not callable", + ): + tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_non_callable_function"]) From 1790b2d7df56eeb8bb42f401441750e8960c1838 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 8 Oct 2025 17:05:57 -0400 Subject: [PATCH 03/26] Adding Development Tenets to CONTRIBUTING.md (#1009) * Adding Development Tenets to CONTRIBUTING.md * Update CONTRIBUTING.md --- CONTRIBUTING.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d107b1fa8..be83ff85b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,6 +36,18 @@ Before starting work on any issue: 3. Wait for maintainer confirmation before beginning significant work +## Development Tenets +Our team follows these core principles when designing and implementing features. These tenets help us make consistent decisions, resolve trade-offs, and maintain the quality and coherence of the SDK. When contributing, please consider how your changes align with these principles: + +1. **Simple at any scale:** We believe that simple things should be simple. The same clean abstractions that power a weekend prototype should scale effortlessly to production workloads. We reject the notion that enterprise-grade means enterprise-complicated - Strands remains approachable whether it's your first agent or your millionth. +2. **Extensible by design:** We allow for as much configuration as possible, from hooks to model providers, session managers, tools, etc. We meet customers where they are with flexible extension points that are simple to integrate with. +3. **Composability:** Primitives are building blocks with each other. Each feature of Strands is developed with all other features in mind, they are consistent and complement one another. +4. **The obvious path is the happy path:** Through intuitive naming, helpful error messages, and thoughtful API design, we guide developers toward correct patterns and away from common pitfalls. +5. **We are accessible to humans and agents:** Strands is designed for both humans and AI to understand equally well. We don’t take shortcuts on curated DX for humans and we go the extra mile to make sure coding assistants can help you use those interfaces the right way. +6. 
**Embrace common standards:** We respect what came before, and do not want to reinvent something that is already widely adopted or done better. + +When proposing solutions or reviewing code, we reference these principles to guide our decisions. If two approaches seem equally valid, we choose the one that best aligns with our tenets. + ## Development Environment This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management. From 92da54453ee3eadf3f32b1da1522cc3e9b05bb25 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 9 Oct 2025 10:11:56 -0400 Subject: [PATCH 04/26] Revert "feat: implement concurrent message reading for session managers (#897)" (#1013) --- src/strands/session/file_session_manager.py | 20 ++++------------ src/strands/session/s3_session_manager.py | 26 +++++++-------------- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 93adeb7f2..491f7ad60 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -1,6 +1,5 @@ """File-based session manager for local filesystem storage.""" -import asyncio import json import logging import os @@ -232,20 +231,11 @@ def list_messages( else: message_files = message_files[offset:] - return asyncio.run(self._load_messages_concurrently(messages_dir, message_files)) - - async def _load_messages_concurrently(self, messages_dir: str, message_files: list[str]) -> list[SessionMessage]: - """Load multiple message files concurrently using async.""" - if not message_files: - return [] - - async def load_message(filename: str) -> SessionMessage: + # Load only the message files + messages: list[SessionMessage] = [] + for filename in message_files: file_path = os.path.join(messages_dir, filename) - loop = asyncio.get_event_loop() - message_data = await loop.run_in_executor(None, self._read_file, file_path) - return SessionMessage.from_dict(message_data) - - tasks = [load_message(filename) for filename in message_files] - messages = await asyncio.gather(*tasks) + message_data = self._read_file(file_path) + messages.append(SessionMessage.from_dict(message_data)) return messages diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 1f6ffe7f1..c6ce28d80 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -1,6 +1,5 @@ """S3-based session manager for cloud storage.""" -import asyncio import json import logging from typing import Any, Dict, List, Optional, cast @@ -284,23 +283,14 @@ def list_messages( else: message_keys = message_keys[offset:] - # Load message objects concurrently using async - return asyncio.run(self._load_messages_concurrently(message_keys)) + # Load only the required message objects + messages: List[SessionMessage] = [] + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + + return messages except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e - - async def _load_messages_concurrently(self, message_keys: List[str]) -> List[SessionMessage]: - """Load multiple message objects concurrently using async.""" - if not message_keys: - return [] - - async def load_message(key: str) -> Optional[SessionMessage]: - loop = asyncio.get_event_loop() - 
message_data = await loop.run_in_executor(None, self._read_s3_object, key) - return SessionMessage.from_dict(message_data) if message_data else None - - tasks = [load_message(key) for key in message_keys] - loaded_messages = await asyncio.gather(*tasks) - - return [msg for msg in loaded_messages if msg is not None] From 2f04758917d6200edf9962f43cbb57dcc8dc6f55 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 12:47:09 -0400 Subject: [PATCH 05/26] feat(models): use tool for litellm structured_output when supports_response_schema=false (#957) --- src/strands/models/litellm.py | 84 ++++++++++++++++-------- tests/strands/models/test_litellm.py | 22 +++++-- tests_integ/models/test_model_litellm.py | 61 +++++++++++++++++ 3 files changed, 136 insertions(+), 31 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1763f5dec..486f67bf8 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override +from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import StreamEvent @@ -202,6 +203,10 @@ async def structured_output( ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. + Some models do not support native structured output via response_format. + In cases of proxies, we may not have a way to determine support, so we + fallback to using tool calling to achieve structured output. + Args: output_model: The output model to use for the agent. prompt: The prompt messages to use for the agent. @@ -211,42 +216,69 @@ async def structured_output( Yields: Model events with the last being the structured output. """ - supports_schema = supports_response_schema(self.get_config()["model_id"]) + if supports_response_schema(self.get_config()["model_id"]): + logger.debug("structuring output using response schema") + result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt) + else: + logger.debug("model does not support response schema, structuring output using tool approach") + result = await self._structured_output_using_tool(output_model, prompt, system_prompt) + + yield {"output": result} + + async def _structured_output_using_response_schema( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + ) -> T: + """Get structured output using native response_format support.""" + response = await litellm.acompletion( + **self.client_args, + model=self.get_config()["model_id"], + messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], + response_format=output_model, + ) - # If the provider does not support response schemas, we cannot reliably parse structured output. - # In that case we must not call the provider and must raise the documented ValueError. - if not supports_schema: - raise ValueError("Model does not support response_format") + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the response.") + if not response.choices or response.choices[0].finish_reason != "tool_calls": + raise ValueError("No tool_calls found in response") - # For providers that DO support response schemas, call litellm and map context-window errors. 
+ choice = response.choices[0] try: - response = await litellm.acompletion( - **self.client_args, - model=self.get_config()["model_id"], - messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], - response_format=output_model, - ) + # Parse the message content as JSON + tool_call_data = json.loads(choice.message.content) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) except ContextWindowExceededError as e: logger.warning("litellm client raised context window overflow in structured_output") raise ContextWindowOverflowException(e) from e + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + + async def _structured_output_using_tool( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + ) -> T: + """Get structured output using tool calling fallback.""" + tool_spec = convert_pydantic_to_tool_spec(output_model) + request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}})) + args = {**self.client_args, **request, "stream": False} + response = await litellm.acompletion(**args) if len(response.choices) > 1: raise ValueError("Multiple choices found in the response.") + if not response.choices or response.choices[0].finish_reason != "tool_calls": + raise ValueError("No tool_calls found in response") - # Find the first choice with tool_calls - for choice in response.choices: - if choice.finish_reason == "tool_calls": - try: - # Parse the tool call content as JSON - tool_call_data = json.loads(choice.message.content) - # Instantiate the output model with the parsed data - yield {"output": output_model(**tool_call_data)} - return - except (json.JSONDecodeError, TypeError, ValueError) as e: - raise ValueError(f"Failed to parse or load content into model: {e}") from e - - # If no tool_calls found, raise an error - raise ValueError("No tool_calls found in response") + choice = response.choices[0] + try: + # Parse the tool call content as JSON + tool_call = choice.message.tool_calls[0] + tool_call_data = json.loads(tool_call.function.arguments) + # Instantiate the output model with the parsed data + return output_model(**tool_call_data) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow in structured_output") + raise ContextWindowOverflowException(e) from e + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e def _apply_proxy_prefix(self) -> None: """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True. 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 776ae7bae..82023cae3 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -292,15 +292,27 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c @pytest.mark.asyncio -async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls): +async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function.arguments = '{"name": "John", "age": 30}' + + mock_choice = unittest.mock.Mock() + mock_choice.finish_reason = "tool_calls" + mock_choice.message.tool_calls = [mock_tool_call] + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + litellm_acompletion.return_value = mock_response + with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False): - with pytest.raises(ValueError, match="Model does not support response_format"): - stream = model.structured_output(test_output_model_cls, messages) - await stream.__anext__() + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + tru_result = events[-1] - litellm_acompletion.assert_not_called() + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings): diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index 6cfdd3038..c5a09e3e9 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -1,3 +1,5 @@ +import unittest.mock + import pydantic import pytest @@ -40,6 +42,37 @@ class Weather(pydantic.BaseModel): return Weather(time="12:00", weather="sunny") +class Location(pydantic.BaseModel): + """Location information.""" + + city: str = pydantic.Field(description="The city name") + country: str = pydantic.Field(description="The country name") + + +class WeatherCondition(pydantic.BaseModel): + """Weather condition details.""" + + condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')") + temperature: int = pydantic.Field(description="Temperature in Celsius") + + +class NestedWeather(pydantic.BaseModel): + """Weather report with nested location and condition information.""" + + time: str = pydantic.Field(description="The time in HH:MM format") + location: Location = pydantic.Field(description="Location information") + weather: WeatherCondition = pydantic.Field(description="Weather condition details") + + +@pytest.fixture +def nested_weather(): + return NestedWeather( + time="12:00", + location=Location(city="New York", country="USA"), + weather=WeatherCondition(condition="sunny", temperature=25), + ) + + @pytest.fixture def yellow_color(): class Color(pydantic.BaseModel): @@ -134,3 +167,31 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): tru_color = agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +def test_structured_output_unsupported_model(model, nested_weather): + # Mock supports_response_schema to return False to test fallback mechanism + with ( + unittest.mock.patch.multiple( + 
"strands.models.litellm", + supports_response_schema=unittest.mock.DEFAULT, + ) as mocks, + unittest.mock.patch.object( + model, "_structured_output_using_tool", wraps=model._structured_output_using_tool + ) as mock_tool, + unittest.mock.patch.object( + model, "_structured_output_using_response_schema", wraps=model._structured_output_using_response_schema + ) as mock_schema, + ): + mocks["supports_response_schema"].return_value = False + + # Test that structured output still works via tool calling fallback + agent = Agent(model=model) + prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius" + tru_weather = agent.structured_output(NestedWeather, prompt) + exp_weather = nested_weather + assert tru_weather == exp_weather + + # Verify that the tool method was called and schema method was not + mock_tool.assert_called_once() + mock_schema.assert_not_called() From aada326821f0bce2a0ab41b14ead457a78e2f6b4 Mon Sep 17 00:00:00 2001 From: Kyler Middleton Date: Thu, 9 Oct 2025 13:56:25 -0500 Subject: [PATCH 06/26] feat(mcp): Add EmbeddedResource support to mcp (#726) --------- Co-authored-by: Dean Schmigelski --- src/strands/tools/mcp/mcp_client.py | 60 ++++++++- tests/strands/tools/mcp/test_mcp_client.py | 147 +++++++++++++++++++++ tests_integ/mcp/echo_server.py | 46 +++++++ tests_integ/mcp/test_mcp_client.py | 94 +++++++++++++ 4 files changed, 343 insertions(+), 4 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index dec8ec313..8148e149a 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -20,8 +20,9 @@ import anyio from mcp import ClientSession, ListToolsResult +from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents from mcp.types import CallToolResult as MCPCallToolResult -from mcp.types import GetPromptResult, ListPromptsResult +from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent @@ -358,8 +359,7 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) - # Build a typed list of ToolResultContent. Use a clearer local name to avoid shadowing - # and annotate the result for mypy so it knows the intended element type. + # Build a typed list of ToolResultContent. mapped_contents: list[ToolResultContent] = [ mc for content in call_tool_result.content @@ -438,7 +438,7 @@ def _background_task(self) -> None: def _map_mcp_content_to_tool_result_content( self, - content: MCPTextContent | MCPImageContent | Any, + content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, ) -> Union[ToolResultContent, None]: """Maps MCP content types to tool result content types. @@ -462,6 +462,58 @@ def _map_mcp_content_to_tool_result_content( "source": {"bytes": base64.b64decode(content.data)}, } } + elif isinstance(content, MCPEmbeddedResource): + """ + TODO: Include URI information in results. + Models may find it useful to be aware not only of the information, + but the location of the information too. + + This may be difficult without taking an opinionated position. For example, + a content block may need to indicate that the following Image content block + is of particular URI. 
+ """ + + self._log_debug_with_thread("mapping MCP embedded resource content") + + resource = content.resource + if isinstance(resource, TextResourceContents): + return {"text": resource.text} + elif isinstance(resource, BlobResourceContents): + try: + raw_bytes = base64.b64decode(resource.blob) + except Exception: + self._log_debug_with_thread("embedded resource blob could not be decoded - dropping") + return None + + if resource.mimeType and ( + resource.mimeType.startswith("text/") + or resource.mimeType + in ( + "application/json", + "application/xml", + "application/javascript", + "application/yaml", + "application/x-yaml", + ) + or resource.mimeType.endswith(("+json", "+xml")) + ): + try: + return {"text": raw_bytes.decode("utf-8", errors="replace")} + except Exception: + pass + + if resource.mimeType in MIME_TO_FORMAT: + return { + "image": { + "format": MIME_TO_FORMAT[resource.mimeType], + "source": {"bytes": raw_bytes}, + } + } + + self._log_debug_with_thread("embedded resource blob with non-textual/unknown mimeType - dropping") + return None + + return None # type: ignore[unreachable] # Defensive: future MCP resource types else: self._log_debug_with_thread("unhandled content type: %s - dropping content", content.__class__.__name__) return None diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 67d8fe558..130a4703e 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,3 +1,4 @@ +import base64 import time from unittest.mock import AsyncMock, MagicMock, patch @@ -541,3 +542,149 @@ def slow_transport(): assert client._background_thread_session is None assert client._background_thread_event_loop is None assert not client._init_future.done() # New future created + + +def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): + """EmbeddedResource.resource (uri + text) should map to plain text content.""" + embedded_resource = { + "type": "resource", # required literal + "resource": { + "uri": "mcp://resource/embedded-text-1", + "text": "inner text", + "mimeType": "text/plain", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "inner text" + + +def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock_session): + """EmbeddedResource.resource (uri + blob with textual MIME) should decode to text.""" + + payload = base64.b64encode(b'{"k":"v"}').decode() + + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-blob-1", + # NOTE: blob is a STRING, mimeType is sibling + "blob": payload, + "mimeType": "application/json", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == 
'{"k":"v"}' + + +def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): + """EmbeddedResource.resource (blob with image MIME) should map to image content.""" + # Read yellow.png file + with open("tests_integ/yellow.png", "rb") as image_file: + png_data = image_file.read() + payload = base64.b64encode(png_data).decode() + + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-image", + "blob": payload, + "mimeType": "image/png", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert "image" in result["content"][0] + assert result["content"][0]["image"]["format"] == "png" + assert "bytes" in result["content"][0]["image"]["source"] + + +def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_session): + """EmbeddedResource.resource (blob with non-textual/unknown MIME) should be dropped.""" + payload = base64.b64encode(b"\x00\x01\x02\x03").decode() + + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-binary", + "blob": payload, + "mimeType": "application/octet-stream", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 0 # Content should be dropped + + +def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_session): + """EmbeddedResource with different textual MIME types should decode to text.""" + + # Test YAML content + yaml_content = base64.b64encode(b"key: value\nlist:\n - item1\n - item2").decode() + embedded_resource = { + "type": "resource", + "resource": { + "uri": "mcp://resource/embedded-yaml", + "blob": yaml_content, + "mimeType": "application/yaml", + }, + } + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[embedded_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 1 + assert "key: value" in result["content"][0]["text"] + + +def test_call_tool_sync_embedded_unknown_resource_type_dropped(mock_transport, mock_session): + """EmbeddedResource with unknown resource type should be dropped for forward compatibility.""" + + # Mock an unknown resource type that's neither TextResourceContents nor BlobResourceContents + class UnknownResourceContents: + def __init__(self): + self.uri = "mcp://resource/unknown-type" + self.mimeType = "application/unknown" + self.data = "some unknown data" + + # Create a mock embedded resource with unknown resource type + mock_embedded_resource = MagicMock() + mock_embedded_resource.resource = UnknownResourceContents() + + 
mock_session.call_tool.return_value = MagicMock( + isError=False, content=[mock_embedded_resource], structuredContent=None + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) + + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + assert result["status"] == "success" + assert len(result["content"]) == 0 # Unknown resource type should be dropped diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 160ad5af9..e15065a4a 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -15,7 +15,11 @@ $ python echo_server.py """ +import base64 +from typing import Literal + from mcp.server import FastMCP +from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents from pydantic import BaseModel @@ -46,6 +50,48 @@ def echo(to_echo: str) -> str: def echo_with_structured_content(to_echo: str) -> EchoResponse: return EchoResponse(echoed=to_echo, message_length=len(to_echo)) + @mcp.tool(description="Get current weather information for a location") + def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): + """Get weather data including forecasts and alerts for the specified location""" + if location.lower() == "new york": + return [ + EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri="https://weather.api/forecast/nyc", + mimeType="text/plain", + text="Current weather in New York: 72°F, partly cloudy with light winds.", + ), + ) + ] + elif location.lower() == "london": + return [ + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="https://weather.api/data/london.json", + mimeType="application/json", + blob=base64.b64encode( + '{"temperature": 18, "condition": "rainy", "humidity": 85}'.encode() + ).decode(), + ), + ) + ] + elif location.lower() == "tokyo": + # Read yellow.png file for weather icon + with open("tests_integ/yellow.png", "rb") as image_file: + png_data = image_file.read() + return [ + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="https://weather.api/icons/sunny.png", + mimeType="image/png", + blob=base64.b64encode(png_data).decode(), + ), + ) + ] + mcp.run(transport="stdio") diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 9d5ab5f13..2c9bb73e1 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -272,6 +272,100 @@ def transport_callback() -> MCPTransport: assert "Hello, Charlie!" 
in prompt_text +def test_mcp_client_embedded_resources(): + """Test that MCP client properly handles EmbeddedResource content types.""" + embedded_resource_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with embedded_resource_mcp_client: + # Test text embedded resource + text_result = embedded_resource_mcp_client.call_tool_sync( + tool_use_id="test-embedded-text", + name="get_weather", + arguments={"location": "New York"}, + ) + assert text_result["status"] == "success" + assert len(text_result["content"]) == 1 + assert "72°F" in text_result["content"][0]["text"] + assert "partly cloudy" in text_result["content"][0]["text"] + + # Test JSON embedded resource (blob with textual MIME type) + json_result = embedded_resource_mcp_client.call_tool_sync( + tool_use_id="test-embedded-json", + name="get_weather", + arguments={"location": "London"}, + ) + assert json_result["status"] == "success" + assert len(json_result["content"]) == 1 + json_content = json_result["content"][0]["text"] + assert "temperature" in json_content + assert "rainy" in json_content + + # Test image embedded resource + image_result = embedded_resource_mcp_client.call_tool_sync( + tool_use_id="test-embedded-image", + name="get_weather", + arguments={"location": "Tokyo"}, + ) + assert image_result["status"] == "success" + assert len(image_result["content"]) == 1 + assert "image" in image_result["content"][0] + assert image_result["content"][0]["image"]["format"] == "png" + assert "bytes" in image_result["content"][0]["image"]["source"] + + +@pytest.mark.asyncio +async def test_mcp_client_embedded_resources_async(): + """Test that async MCP client properly handles EmbeddedResource content types.""" + embedded_resource_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with embedded_resource_mcp_client: + # Test text embedded resource async + text_result = await embedded_resource_mcp_client.call_tool_async( + tool_use_id="test-embedded-text-async", + name="get_weather", + arguments={"location": "New York"}, + ) + assert text_result["status"] == "success" + assert len(text_result["content"]) == 1 + assert "72°F" in text_result["content"][0]["text"] + + # Test JSON embedded resource async + json_result = await embedded_resource_mcp_client.call_tool_async( + tool_use_id="test-embedded-json-async", + name="get_weather", + arguments={"location": "London"}, + ) + assert json_result["status"] == "success" + assert len(json_result["content"]) == 1 + json_content = json_result["content"][0]["text"] + assert "temperature" in json_content + + +def test_mcp_client_embedded_resources_with_agent(): + """Test that embedded resources work correctly when used with Agent.""" + embedded_resource_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with embedded_resource_mcp_client: + tools = embedded_resource_mcp_client.list_tools_sync() + agent = Agent(tools=tools) + + # Test that agent can successfully use tools that return embedded resources + result = agent("Get the weather for New York and tell me what it says") + + # Check that the agent successfully processed the embedded resource + assert result.message is not None + response_text = " ".join([block["text"] for block in result.message["content"] if "text" in block]).lower() + + # The agent should have received and processed the 
embedded weather content + assert any(["72" in response_text, "partly cloudy" in response_text, "weather" in response_text]) + + def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] From 9632ed57e56d8a00f7a8c985c3a92eaf4a16d32b Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 9 Oct 2025 15:37:07 -0400 Subject: [PATCH 07/26] conversation manager - summarization - noop tool (#1003) --- .../summarizing_conversation_manager.py | 27 +++++++++++++- .../test_summarizing_conversation_manager.py | 29 +++++++++++++++ ...rizing_conversation_manager_integration.py | 36 +++++++++++++++++++ 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index b08b6853e..117626fbe 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -5,6 +5,8 @@ from typing_extensions import override +from ...tools import tool +from ...tools.registry import ToolRegistry from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -23,6 +25,10 @@ - You MUST create a structured and concise summary in bullet-point format. - You MUST NOT respond conversationally. - You MUST NOT address the user directly. +- You MUST NOT comment on tool availability. + +Assumptions: +- You MUST NOT assume tool executions failed unless otherwise stated. Task: Your task is to create a structured summary document: @@ -182,9 +188,10 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: # Choose which agent to use for summarization summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent - # Save original system prompt and messages to restore later + # Save original system prompt, messages, and tool registry to restore later original_system_prompt = summarization_agent.system_prompt original_messages = summarization_agent.messages.copy() + original_tool_registry = summarization_agent.tool_registry try: # Only override system prompt if no agent was provided during initialization @@ -197,6 +204,13 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: ) # Temporarily set the system prompt for summarization summarization_agent.system_prompt = system_prompt + + # Add no-op tool if agent has no tools to satisfy tool spec requirement + if not summarization_agent.tool_names: + tool_registry = ToolRegistry() + tool_registry.register_tool(self._noop_tool) + summarization_agent.tool_registry = tool_registry + summarization_agent.messages = messages # Use the agent to generate summary with rich content (can use tools if needed) @@ -207,6 +221,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: # Restore original agent state summarization_agent.system_prompt = original_system_prompt summarization_agent.messages = original_messages + summarization_agent.tool_registry = original_tool_registry def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. 
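# Illustrative sketch (not applied by this patch; import paths assumed, and running it needs a
# configured default model provider). The hunk above swaps in a temporary ToolRegistry holding
# only the no-op tool when the summarization agent has no tools of its own, then restores the
# original registry. A caller-facing view of the path this enables, mirroring the integration
# test added later in this patch: an agent whose history contains toolUse/toolResult blocks but
# which registers no tools can now have its context reduced without a provider validation error.
from strands import Agent
from strands.agent.conversation_manager import SummarizingConversationManager

agent = Agent(
    messages=[
        {"role": "user", "content": [{"text": "What is the current time?"}]},
        {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "time_tool", "input": {}}}]},
        {"role": "user", "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "12:00"}], "status": "success"}}]},
        {"role": "assistant", "content": [{"text": "The current time is 12:00."}]},
        {"role": "user", "content": [{"text": "Thank you"}]},
        {"role": "assistant", "content": [{"text": "You are welcome."}]},
    ],
)

manager = SummarizingConversationManager(summary_ratio=1, preserve_recent_messages=2)
manager.reduce_context(agent)  # the no-op tool is registered only for the summarization call, then removed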
@@ -249,3 +264,13 @@ def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_poin raise ContextWindowOverflowException("Unable to trim conversation context!") return split_point + + @tool(name="noop", description="MUST NOT call or summarize") + def _noop_tool(self) -> None: + """No-op tool to satisfy tool spec requirement when tool messages are present. + + Some model provides (e.g., Bedrock) will return an error response if tool uses and tool results are present in + messages without any tool specs configured. Consequently, if the summarization agent has no registered tools, + summarization will fail. As a workaround, we register the no-op tool. + """ + pass diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index 6003a1710..4b69e6653 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -19,6 +19,8 @@ def __init__(self, summary_response="This is a summary of the conversation."): self.messages = [] self.model = Mock() self.call_tracker = Mock() + self.tool_registry = Mock() + self.tool_names = [] def __call__(self, prompt): """Mock agent call that returns a summary.""" @@ -608,3 +610,30 @@ def test_summarizing_conversation_manager_properly_records_removed_message_count # so we dont count this toward the total: # 4 (Previously removed messages) + 2 (removed messages) - 1 (Previous summary message) = 5 assert manager.removed_message_count == 5 + + +@patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") +def test_summarizing_conversation_manager_generate_summary_with_noop_tool(mock_registry_cls, summarizing_manager): + mock_registry = mock_registry_cls.return_value + + messages = [{"role": "user", "content": [{"text": "test"}]}] + agent = create_mock_agent() + + original_tool_registry = agent.tool_registry + summarizing_manager._generate_summary(messages, agent) + + assert original_tool_registry == agent.tool_registry + mock_registry.register_tool.assert_called_once() + + +@patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") +def test_summarizing_conversation_manager_generate_summary_with_tools(mock_registry_cls, summarizing_manager): + mock_registry = mock_registry_cls.return_value + + messages = [{"role": "user", "content": [{"text": "test"}]}] + agent = create_mock_agent() + agent.tool_names = ["test_tool"] + + summarizing_manager._generate_summary(messages, agent) + + mock_registry.register_tool.assert_not_called() diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py index b205c723f..91fb5b910 100644 --- a/tests_integ/test_summarizing_conversation_manager_integration.py +++ b/tests_integ/test_summarizing_conversation_manager_integration.py @@ -372,3 +372,39 @@ def test_dedicated_summarization_agent(model, summarization_model): break assert summary_text + + +def test_summarization_with_tool_messages_and_no_tools(): + agent = Agent( + messages=[ + {"role": "user", "content": [{"text": "What is the current time?"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "t1", "name": "time_tool", "input": {}}}], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "content": [{"text": "12:00"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "The 
current time is 12:00."}]}, + {"role": "user", "content": [{"text": "Thank you"}]}, + {"role": "assistant", "content": [{"text": "You are welcome."}]}, + ], + ) + + conversation_manager = SummarizingConversationManager(summary_ratio=1, preserve_recent_messages=2) + conversation_manager.reduce_context(agent) + + assert len(agent.tool_names) == 0 + assert len(agent.messages) == 3 + + summary = str(agent.messages[0]).lower() + assert "12:00" in summary From 419de199713ac3e98b88cd61851191dd969b2990 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 10 Oct 2025 21:29:00 +0800 Subject: [PATCH 08/26] Fix additional_args passing in SageMakerAIModel (#983) * fix(sagemaker): additional_args dict issue Fix error where passing an additional_args dict to SageMakerAIModel would raise an AttributeError because Python dicts have no '__dict__' attribute. Fixes #982 * fix(sagemaker): typing for endpoint_config Fix typing for SageMakerAIModel.endpoint_config which was previously being treated as an arbitrary dictionary due to init assignment. * fix(sagemaker): Typing for payload_config Fix typing for SageMakerAIModel.payload_config, which was previously being treated as a plain dict due to init assignment. * test(sagemaker): tests for ep additional_args Add a test to check for insertion of endpoint config additional_args * fix(sagemaker): include payload additional_args Copy SageMakerAIPayloadSchema's additional_args into request payloads where provided - previously these were being ignored. Includes unit tests. --- src/strands/models/sagemaker.py | 36 ++++++++++++++++---------- tests/strands/models/test_sagemaker.py | 28 ++++++++++++++++++++ 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index d1447732e..25b3ca7ce 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -4,7 +4,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union import boto3 from botocore.config import Config as BotocoreConfig @@ -151,8 +151,8 @@ def __init__( validate_config_keys(payload_config, self.SageMakerAIPayloadSchema) payload_config.setdefault("stream", True) payload_config.setdefault("tool_results_as_user_messages", False) - self.endpoint_config = dict(endpoint_config) - self.payload_config = dict(payload_config) + self.endpoint_config = self.SageMakerAIEndpointConfig(**endpoint_config) + self.payload_config = self.SageMakerAIPayloadSchema(**payload_config) logger.debug( "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config ) @@ -193,7 +193,7 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i Returns: The Amazon SageMaker model configuration. 
""" - return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) + return self.endpoint_config @override def format_request( @@ -238,6 +238,10 @@ def format_request( }, } + payload_additional_args = self.payload_config.get("additional_args") + if payload_additional_args: + payload.update(payload_additional_args) + # Remove tools and tool_choice if tools = [] if not payload["tools"]: payload.pop("tools") @@ -273,16 +277,20 @@ def format_request( } # Add optional SageMaker parameters if provided - if self.endpoint_config.get("inference_component_name"): - request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] - if self.endpoint_config.get("target_model"): - request["TargetModel"] = self.endpoint_config["target_model"] - if self.endpoint_config.get("target_variant"): - request["TargetVariant"] = self.endpoint_config["target_variant"] - - # Add additional args if provided - if self.endpoint_config.get("additional_args"): - request.update(self.endpoint_config["additional_args"].__dict__) + inf_component_name = self.endpoint_config.get("inference_component_name") + if inf_component_name: + request["InferenceComponentName"] = inf_component_name + target_model = self.endpoint_config.get("target_model") + if target_model: + request["TargetModel"] = target_model + target_variant = self.endpoint_config.get("target_variant") + if target_variant: + request["TargetVariant"] = target_variant + + # Add additional request args if provided + additional_args = self.endpoint_config.get("additional_args") + if additional_args: + request.update(additional_args) return request diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index a5662ecdc..72ebf01c6 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -112,11 +112,13 @@ def test_init_with_all_params(self, boto_session): "endpoint_name": "test-endpoint", "inference_component_name": "test-component", "region_name": "us-west-2", + "additional_args": {"test_req_arg_name": "test_req_arg_value"}, } payload_config = { "stream": False, "max_tokens": 1024, "temperature": 0.7, + "additional_args": {"test_payload_arg_name": "test_payload_arg_value"}, } client_config = BotocoreConfig(user_agent_extra="test-agent") @@ -129,9 +131,11 @@ def test_init_with_all_params(self, boto_session): assert model.endpoint_config["endpoint_name"] == "test-endpoint" assert model.endpoint_config["inference_component_name"] == "test-component" + assert model.endpoint_config["additional_args"]["test_req_arg_name"] == "test_req_arg_value" assert model.payload_config["stream"] is False assert model.payload_config["max_tokens"] == 1024 assert model.payload_config["temperature"] == 0.7 + assert model.payload_config["additional_args"]["test_payload_arg_name"] == "test_payload_arg_value" boto_session.client.assert_called_once_with( service_name="sagemaker-runtime", @@ -239,6 +243,30 @@ def test_get_config(self, model, endpoint_config): # assert "tools" in payload # assert payload["tools"] == [] + def test_format_request_with_additional_args(self, boto_session, endpoint_config, messages, payload_config): + """Test formatting a request's `additional_args` where provided""" + endpoint_config_ext = { + **endpoint_config, + "additional_args": { + "extra_request_key": "extra_request_value", + }, + } + payload_config_ext = { + **payload_config, + "additional_args": { + "extra_payload_key": "extra_payload_value", + }, + } + model = SageMakerAIModel( + 
boto_session=boto_session, + endpoint_config=endpoint_config_ext, + payload_config=payload_config_ext, + ) + request = model.format_request(messages) + assert request.get("extra_request_key") == "extra_request_value" + payload = json.loads(request["Body"]) + assert payload.get("extra_payload_key") == "extra_payload_value" + @pytest.mark.asyncio async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages): """Test streaming response with streaming enabled.""" From 7fbc9dc876533d60ff80957510a2dd19a05f5624 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Fri, 10 Oct 2025 23:19:24 +0800 Subject: [PATCH 09/26] feat: replace kwargs with invocation_state in agent APIs (#966) * feat: replace kwargs with invocation_state in agent APIs * fix: handle **kwargs in stream_async. * feat: add a unit test for the change * Update src/strands/agent/agent.py Co-authored-by: Nick Clegg * tool - executors - concurrent - remove no-op gather (#954) * feat(telemetry): updated traces to match OTEL v1.37 semantic conventions (#952) * event loop - handle model execution (#958) * feat: implement concurrent message reading for session managers (#897) Replace sequential message loading with async concurrent reading in both S3SessionManager and FileSessionManager to improve performance for long conversations. Uses asyncio.gather() with run_in_executor() to read multiple messages simultaneously while maintaining proper ordering. Resolves: #874 Co-authored-by: Vamil Gandhi * hooks - before tool call event - cancel tool (#964) * fix(telemetry): removed double serialization for events (#977) * fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException (#994) * feat: add more tests and adjust invocation_state dic structure * Apply suggestion from @Unshure Co-authored-by: Nick Clegg * fix: adjust **kwargs in multiagent primitives --------- Co-authored-by: Nick Clegg Co-authored-by: Patrick Gray Co-authored-by: poshinchen Co-authored-by: Vamil Gandhi Co-authored-by: Vamil Gandhi Co-authored-by: ratish <114130421+Ratish1@users.noreply.github.com> --- src/strands/agent/agent.py | 44 ++++++++++++++------ src/strands/multiagent/base.py | 7 +++- src/strands/multiagent/graph.py | 4 +- src/strands/multiagent/swarm.py | 3 +- tests/strands/agent/test_agent.py | 56 ++++++++++++++++++++++++++ tests/strands/multiagent/test_base.py | 5 ++- tests/strands/multiagent/test_graph.py | 12 ++++-- tests/strands/multiagent/test_swarm.py | 4 +- 8 files changed, 109 insertions(+), 26 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..8607a2601 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -13,6 +13,7 @@ import json import logging import random +import warnings from concurrent.futures import ThreadPoolExecutor from typing import ( Any, @@ -374,7 +375,9 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + def __call__( + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. 
This method implements the conversational interface with multiple input patterns: @@ -389,7 +392,8 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result object containing: @@ -401,13 +405,15 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """ def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, **kwargs)) + return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + async def invoke_async( + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -422,7 +428,8 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result: object containing: @@ -432,7 +439,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async(prompt, **kwargs) + events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs) async for event in events: _ = event @@ -528,9 +535,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async( - self, - prompt: AgentInput = None, - **kwargs: Any, + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -546,7 +551,8 @@ async def stream_async( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass to the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: An async iterator that yields events. 
Each event is a dictionary containing @@ -567,7 +573,19 @@ async def stream_async( yield event["data"] ``` """ - callback_handler = kwargs.get("callback_handler", self.callback_handler) + merged_state = {} + if kwargs: + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + merged_state.update(kwargs) + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state + else: + if invocation_state is not None: + merged_state = invocation_state + + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) messages = self._convert_prompt_to_messages(prompt) @@ -576,10 +594,10 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, invocation_state=kwargs) + events = self._run_loop(messages, invocation_state=merged_state) async for event in events: - event.prepare(invocation_state=kwargs) + event.prepare(invocation_state=merged_state) if event.is_callback_event: as_dict = event.as_dict() diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 03d7de9b4..0dbd85d81 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -4,6 +4,7 @@ """ import asyncio +import warnings from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field @@ -111,8 +112,12 @@ def __call__( if invocation_state is None: invocation_state = {} + if kwargs: + invocation_state.update(kwargs) + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 738dc4d4c..60299c1b5 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -572,11 +572,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) elif isinstance(node.executor, Agent): if self.node_timeout is not None: agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, **invocation_state), + node.executor.invoke_async(node_input, invocation_state=invocation_state), timeout=self.node_timeout, ) else: - agent_response = await node.executor.invoke_async(node_input, **invocation_state) + agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 620fa5e24..42efd5742 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -635,8 +635,7 @@ async def _execute_node( # Execute node result = None node.reset_executor_state() - # Unpacking since this is the agent class. 
Other executors should not unpack - result = await node.executor.invoke_async(node_input, **invocation_state) + result = await node.executor.invoke_async(node_input, invocation_state=invocation_state) execution_time = round((time.time() - start_time) * 1000) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..200584115 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +import warnings from uuid import uuid4 import pytest @@ -1877,3 +1878,58 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text + + +def test_agent__call__handles_none_invocation_state(mock_model, agent): + """Test that agent handles None invocation_state without AttributeError.""" + mock_model.mock_stream.return_value = [ + {"contentBlockDelta": {"delta": {"text": "test response"}}}, + {"contentBlockStop": {}}, + ] + + # This should not raise AttributeError: 'NoneType' object has no attribute 'get' + result = agent("test", invocation_state=None) + + assert result.message["content"][0]["text"] == "test response" + assert result.stop_reason == "end_turn" + + +def test_agent__call__invocation_state_with_kwargs_deprecation_warning(agent, mock_event_loop_cycle): + """Test that kwargs trigger deprecation warning and are merged correctly with invocation_state.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + # Should have nested structure when both invocation_state and kwargs are provided + assert invocation_state["invocation_state"] == {"my": "state"} + assert invocation_state["other_kwarg"] == "foobar" + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar") + + # Verify deprecation warning was issued + assert len(captured_warnings) == 1 + assert issubclass(captured_warnings[0].category, UserWarning) + assert "`**kwargs` parameter is deprecating, use `invocation_state` instead." 
in str(captured_warnings[0].message) + + +def test_agent__call__invocation_state_only_no_warning(agent, mock_event_loop_cycle): + """Test that using only invocation_state does not trigger warning and passes state directly.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + + assert invocation_state["my"] == "state" + assert "agent" in invocation_state + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index d21aa6e14..ab55b2c84 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -159,6 +159,7 @@ async def invoke_async(self, task, invocation_state, **kwargs): self.invoke_async_called = True self.received_task = task self.received_kwargs = kwargs + self.received_invocation_state = invocation_state return MultiAgentResult( status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) @@ -166,10 +167,10 @@ async def invoke_async(self, task, invocation_state, **kwargs): agent = TestMultiAgent() # Test with string task - result = agent("test task", param1="value1", param2="value2") + result = agent("test task", param1="value1", param2="value2", invocation_state={"value3": "value4"}) assert agent.invoke_async_called assert agent.received_task == "test task" - assert agent.received_kwargs == {"param1": "value1", "param2": "value2"} + assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"} assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 8097d944e..c4c1a664f 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -310,7 +310,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): result = await graph.invoke_async([{"text": "Original task"}]) # Verify entry node was called with original task - entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}]) + entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}], invocation_state={}) assert result.status == Status.COMPLETED mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() @@ -906,7 +906,7 @@ def __init__(self, name): self._session_manager = None self.hooks = HookRegistry() - async def invoke_async(self, input_data): + async def invoke_async(self, input_data, invocation_state=None): # Increment execution count in state count = self.state.get("execution_count") or 0 self.state.set("execution_count", count + 1) @@ -1300,7 +1300,9 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state) + kwargs_agent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing"}], invocation_state=test_invocation_state + ) 
assert result.status == Status.COMPLETED @@ -1335,5 +1337,7 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = graph("Test kwargs passing sync", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) + kwargs_agent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing sync"}], invocation_state=test_invocation_state + ) assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 7d3e69695..0968fd30c 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -558,7 +558,7 @@ async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await swarm.invoke_async("Test kwargs passing", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs} assert result.status == Status.COMPLETED @@ -572,5 +572,5 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): test_kwargs = {"custom_param": "test_value", "another_param": 42} result = swarm("Test kwargs passing sync", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs} assert result.status == Status.COMPLETED From 355b3bbaef105c6b44f2610e4d677d3bb74883d1 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Tue, 14 Oct 2025 09:30:53 -0400 Subject: [PATCH 10/26] feat(telemetry): updated semantic conventions, added timeToFirstByteMs into spans and metrics (#997) * feat(telemetry): added timeToFirstByteMs into spans and metrics * chore(trace): updated semantic conventions with tool mappings --- src/strands/event_loop/event_loop.py | 2 +- src/strands/event_loop/streaming.py | 26 ++++-- src/strands/telemetry/metrics.py | 7 +- src/strands/telemetry/metrics_constants.py | 1 + src/strands/telemetry/tracer.py | 93 +++++++++++++++++++--- src/strands/types/event_loop.py | 7 +- tests/strands/event_loop/test_streaming.py | 4 +- tests/strands/telemetry/test_metrics.py | 21 ++++- tests/strands/telemetry/test_tracer.py | 84 +++++++++++++------ 9 files changed, 195 insertions(+), 50 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index d6367e9d9..feb6ac339 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -281,7 +281,7 @@ async def _handle_model_execution( message = recover_message_on_max_tokens_reached(message) if model_invoke_span: - tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! 
Break out of retry loop except Exception as e: diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index f24bd2a76..73f38de8a 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,6 +2,7 @@ import json import logging +import time from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model @@ -267,31 +268,38 @@ def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> N state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] -def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: +def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | None = None) -> tuple[Usage, Metrics]: """Extracts usage metrics from the metadata chunk. Args: event: metadata. + time_to_first_byte_ms: time to get the first byte from the model in milliseconds Returns: The extracted usage metrics and latency. """ usage = Usage(**event["usage"]) metrics = Metrics(**event["metrics"]) + if time_to_first_byte_ms: + metrics["timeToFirstByteMs"] = time_to_first_byte_ms return usage, metrics -async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]: +async def process_stream( + chunks: AsyncIterable[StreamEvent], start_time: float | None = None +) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. + start_time: Time when the model request is initiated Yields: The reason for stopping, the constructed message, and the usage metrics. """ stop_reason: StopReason = "end_turn" + first_byte_time = None state: dict[str, Any] = { "message": {"role": "assistant", "content": []}, @@ -303,10 +311,14 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T state["content"] = state["message"]["content"] usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics: Metrics = Metrics(latencyMs=0) + metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0) async for chunk in chunks: + # Track first byte time when we get first content + if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): + first_byte_time = time.time() yield ModelStreamChunkEvent(chunk=chunk) + if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: @@ -319,7 +331,10 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T elif "messageStop" in chunk: stop_reason = handle_message_stop(chunk["messageStop"]) elif "metadata" in chunk: - usage, metrics = extract_usage_metrics(chunk["metadata"]) + time_to_first_byte_ms = ( + int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None + ) + usage, metrics = extract_usage_metrics(chunk["metadata"], time_to_first_byte_ms) elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], state) @@ -346,7 +361,8 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = remove_blank_messages_content_text(messages) + start_time = time.time() chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt) - async for event in process_stream(chunks): + async for event in process_stream(chunks, start_time): yield event diff --git 
a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 883273f64..abfbbffae 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -286,6 +286,8 @@ def update_metrics(self, metrics: Metrics) -> None: metrics: The metrics data to add to the accumulated totals. """ self._metrics_client.event_loop_latency.record(metrics["latencyMs"]) + if metrics.get("timeToFirstByteMs") is not None: + self._metrics_client.model_time_to_first_token.record(metrics["timeToFirstByteMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] def get_summary(self) -> Dict[str, Any]: @@ -448,7 +450,7 @@ class MetricsClient: event_loop_output_tokens: Histogram event_loop_cache_read_input_tokens: Histogram event_loop_cache_write_input_tokens: Histogram - + model_time_to_first_token: Histogram tool_call_count: Counter tool_success_count: Counter tool_error_count: Counter @@ -507,3 +509,6 @@ def create_instruments(self) -> None: self.event_loop_cache_write_input_tokens = self.meter.create_histogram( name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token" ) + self.model_time_to_first_token = self.meter.create_histogram( + name=constants.STRANDS_MODEL_TIME_TO_FIRST_TOKEN, unit="ms" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py index f8fac34da..2e1047581 100644 --- a/src/strands/telemetry/metrics_constants.py +++ b/src/strands/telemetry/metrics_constants.py @@ -15,3 +15,4 @@ STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens" STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens" +STRANDS_MODEL_TIME_TO_FIRST_TOKEN = "strands.model.time_to_first_token" diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 7cd2d0e7b..907fd454a 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -16,7 +16,7 @@ from ..agent.agent_result import AgentResult from ..types.content import ContentBlock, Message, Messages -from ..types.streaming import StopReason, Usage +from ..types.streaming import Metrics, StopReason, Usage from ..types.tools import ToolResult, ToolUse from ..types.traces import Attributes, AttributeValue @@ -153,6 +153,28 @@ def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> for key, value in attributes.items(): span.set_attribute(key, value) + def _add_optional_usage_and_metrics_attributes( + self, attributes: Dict[str, AttributeValue], usage: Usage, metrics: Metrics + ) -> None: + """Add optional usage and metrics attributes if they have values. 
+ + Args: + attributes: Dictionary to add attributes to + usage: Token usage information from the model call + metrics: Metrics from the model call + """ + if "cacheReadInputTokens" in usage: + attributes["gen_ai.usage.cache_read_input_tokens"] = usage["cacheReadInputTokens"] + + if "cacheWriteInputTokens" in usage: + attributes["gen_ai.usage.cache_write_input_tokens"] = usage["cacheWriteInputTokens"] + + if metrics.get("timeToFirstByteMs", 0) > 0: + attributes["gen_ai.server.time_to_first_token"] = metrics["timeToFirstByteMs"] + + if metrics.get("latencyMs", 0) > 0: + attributes["gen_ai.server.request.duration"] = metrics["latencyMs"] + def _end_span( self, span: Span, @@ -277,7 +299,13 @@ def start_model_invoke_span( return span def end_model_invoke_span( - self, span: Span, message: Message, usage: Usage, stop_reason: StopReason, error: Optional[Exception] = None + self, + span: Span, + message: Message, + usage: Usage, + metrics: Metrics, + stop_reason: StopReason, + error: Optional[Exception] = None, ) -> None: """End a model invocation span with results and metrics. @@ -285,6 +313,7 @@ def end_model_invoke_span( span: The span to end. message: The message response from the model. usage: Token usage information from the model call. + metrics: Metrics from the model call. stop_reason (StopReason): The reason the model stopped generating. error: Optional exception if the model call failed. """ @@ -294,10 +323,11 @@ def end_model_invoke_span( "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.output_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], - "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0), - "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), } + # Add optional attributes if they have values + self._add_optional_usage_and_metrics_attributes(attributes, usage, metrics) + if self.use_latest_genai_conventions: self._add_event( span, @@ -307,7 +337,7 @@ def end_model_invoke_span( [ { "role": message["role"], - "parts": [{"type": "text", "content": message["content"]}], + "parts": self._map_content_blocks_to_otel_parts(message["content"]), "finish_reason": str(stop_reason), } ] @@ -362,7 +392,7 @@ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None "type": "tool_call", "name": tool["name"], "id": tool["toolUseId"], - "arguments": [{"content": tool["input"]}], + "arguments": tool["input"], } ], } @@ -417,7 +447,7 @@ def end_tool_call_span( { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "result": tool_result.get("content"), + "response": tool_result.get("content"), } ], } @@ -504,7 +534,7 @@ def end_event_loop_cycle_span( [ { "role": tool_result_message["role"], - "parts": [{"type": "text", "content": tool_result_message["content"]}], + "parts": self._map_content_blocks_to_otel_parts(tool_result_message["content"]), } ] ) @@ -634,19 +664,23 @@ def start_multiagent_span( ) span = self._start_span(operation, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) - content = serialize(task) if isinstance(task, list) else task if self.use_latest_genai_conventions: + parts: list[dict[str, Any]] = [] + if isinstance(task, list): + parts = self._map_content_blocks_to_otel_parts(task) + else: + parts = [{"type": "text", "content": task}] self._add_event( span, "gen_ai.client.inference.operation.details", - {"gen_ai.input.messages": serialize([{"role": "user", "parts": [{"type": "text", "content": task}]}])}, + 
{"gen_ai.input.messages": serialize([{"role": "user", "parts": parts}])}, ) else: self._add_event( span, "gen_ai.user.message", - event_attributes={"content": content}, + event_attributes={"content": serialize(task) if isinstance(task, list) else task}, ) return span @@ -718,7 +752,7 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: input_messages: list = [] for message in messages: input_messages.append( - {"role": message["role"], "parts": [{"type": "text", "content": message["content"]}]} + {"role": message["role"], "parts": self._map_content_blocks_to_otel_parts(message["content"])} ) self._add_event( span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize(input_messages)} @@ -731,6 +765,41 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: {"content": serialize(message["content"])}, ) + def _map_content_blocks_to_otel_parts(self, content_blocks: list[ContentBlock]) -> list[dict[str, Any]]: + """Map ContentBlock objects to OpenTelemetry parts format.""" + parts: list[dict[str, Any]] = [] + + for block in content_blocks: + if "text" in block: + # Standard TextPart + parts.append({"type": "text", "content": block["text"]}) + elif "toolUse" in block: + # Standard ToolCallRequestPart + tool_use = block["toolUse"] + parts.append( + { + "type": "tool_call", + "name": tool_use["name"], + "id": tool_use["toolUseId"], + "arguments": tool_use["input"], + } + ) + elif "toolResult" in block: + # Standard ToolCallResponsePart + tool_result = block["toolResult"] + parts.append( + { + "type": "tool_call_response", + "id": tool_result["toolUseId"], + "response": tool_result["content"], + } + ) + else: + # For all other ContentBlock types, use the key as type and value as content + for key, value in block.items(): + parts.append({"type": key, "content": value}) + return parts + # Singleton instance for global access _tracer_instance = None diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2c240972b..f184f5e59 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -23,14 +23,17 @@ class Usage(TypedDict, total=False): cacheWriteInputTokens: int -class Metrics(TypedDict): +class Metrics(TypedDict, total=False): """Performance metrics for model interactions. Attributes: latencyMs (int): Latency of the model request in milliseconds. + timeToFirstByteMs (int): Latency from sending model request to first + content chunk (contentBlockDelta or contentBlockStart) from the model in milliseconds. 
""" - latencyMs: int + latencyMs: Required[int] + timeToFirstByteMs: int StopReason = Literal[ diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 1de957619..5afa0cb45 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -491,7 +491,7 @@ def test_extract_usage_metrics_with_cache_tokens(): "content": [], }, {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, + {"latencyMs": 0, "timeToFirstByteMs": 0}, ), }, ], @@ -781,7 +781,7 @@ async def test_stream_messages(agenerator, alist): "end_turn", {"role": "assistant", "content": [{"text": "test"}]}, {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, - {"latencyMs": 0}, + {"latencyMs": 0, "timeToFirstByteMs": 0}, ) }, ] diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 12db81908..e87277eed 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -109,6 +109,18 @@ def metrics(request): return Metrics(**params) +@pytest.fixture +def metrics_with_ttfb(request): + params = { + "latencyMs": 1, + "timeToFirstByteMs": 10, + } + if hasattr(request, "param"): + params.update(request.param) + + return Metrics(**params) + + @pytest.mark.parametrize("end_time", [None, 1]) @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") def test_trace_end(mock_time, end_time, trace): @@ -132,8 +144,8 @@ def mock_get_meter_provider(): mock_create_counter = mock.MagicMock() mock_meter.create_counter.return_value = mock_create_counter - mock_create_histogram = mock.MagicMock() - mock_meter.create_histogram.return_value = mock_create_histogram + # Create separate mock objects for each histogram call + mock_meter.create_histogram.side_effect = lambda *args, **kwargs: mock.MagicMock() meter_provider_mock.get_meter.return_value = mock_meter mock_get_meter_provider.return_value = meter_provider_mock @@ -326,9 +338,9 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met metrics_client.event_loop_cache_write_input_tokens.record.assert_called() -def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): +def test_event_loop_metrics_update_metrics(metrics_with_ttfb, event_loop_metrics, mock_get_meter_provider): for _ in range(3): - event_loop_metrics.update_metrics(metrics) + event_loop_metrics.update_metrics(metrics_with_ttfb) tru_metrics = event_loop_metrics.accumulated_metrics exp_metrics = Metrics( @@ -338,6 +350,7 @@ def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get assert tru_metrics == exp_metrics mock_get_meter_provider.return_value.get_meter.assert_called() event_loop_metrics._metrics_client.event_loop_latency.record.assert_called_with(1) + event_loop_metrics._metrics_client.model_time_to_first_token.record.assert_called_with(10) def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_get_meter_provider): diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 4e9872100..de677c2cc 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -11,7 +11,7 @@ from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.content import ContentBlock -from strands.types.streaming import StopReason, Usage +from strands.types.streaming import Metrics, StopReason, Usage @pytest.fixture(autouse=True) 
@@ -173,7 +173,15 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer): mock_span = mock.MagicMock() mock_tracer.start_span.return_value = mock_span - messages = [{"role": "user", "content": [{"text": "Hello"}]}] + messages = [ + {"role": "user", "content": [{"text": "Hello 2025-1993"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"input": '"expression": "2025-1993"', "name": "calculator", "toolUseId": "123"}} + ], + }, + ] model_id = "test-model" span = tracer.start_model_invoke_span(messages=messages, agent_name="TestAgent", model_id=model_id) @@ -191,8 +199,19 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer): [ { "role": messages[0]["role"], - "parts": [{"type": "text", "content": messages[0]["content"]}], - } + "parts": [{"type": "text", "content": "Hello 2025-1993"}], + }, + { + "role": messages[1]["role"], + "parts": [ + { + "type": "tool_call", + "name": "calculator", + "id": "123", + "arguments": '"expression": "2025-1993"', + } + ], + }, ] ) }, @@ -205,17 +224,18 @@ def test_end_model_invoke_span(mock_span): tracer = Tracer() message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) stop_reason: StopReason = "end_turn" - tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) + mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -231,17 +251,18 @@ def test_end_model_invoke_span_latest_conventions(mock_span): tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) stop_reason: StopReason = "end_turn" - tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) + mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) 
mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -249,7 +270,7 @@ def test_end_model_invoke_span_latest_conventions(mock_span): [ { "role": "assistant", - "parts": [{"type": "text", "content": message["content"]}], + "parts": [{"type": "text", "content": "Response"}], "finish_reason": "end_turn", } ] @@ -318,7 +339,7 @@ def test_start_tool_call_span_latest_conventions(mock_tracer): "type": "tool_call", "name": tool["name"], "id": tool["toolUseId"], - "arguments": [{"content": tool["input"]}], + "arguments": tool["input"], } ], } @@ -398,7 +419,7 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer) "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": [{"text": "Original Task: foo bar"}]}]}] + [{"role": "user", "parts": [{"type": "text", "content": "Original Task: foo bar"}]}] ) }, ) @@ -486,7 +507,7 @@ def test_end_tool_call_span_latest_conventions(mock_span): """Test ending a tool call span with the latest semantic conventions.""" tracer = Tracer() tracer.use_latest_genai_conventions = True - tool_result = {"status": "success", "content": [{"text": "Tool result"}]} + tool_result = {"status": "success", "content": [{"text": "Tool result"}, {"json": {"foo": "bar"}}]} tracer.end_tool_call_span(mock_span, tool_result) @@ -502,7 +523,7 @@ def test_end_tool_call_span_latest_conventions(mock_span): { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "result": tool_result.get("content"), + "response": tool_result.get("content"), } ], } @@ -558,9 +579,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ - "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": messages[0]["content"]}]}] - ) + "gen_ai.input.messages": serialize([{"role": "user", "parts": [{"type": "text", "content": "Hello"}]}]) }, ) assert span is not None @@ -570,7 +589,12 @@ def test_end_event_loop_cycle_span(mock_span): """Test ending an event loop cycle span.""" tracer = Tracer() message = {"role": "assistant", "content": [{"text": "Response"}]} - tool_result_message = {"role": "assistant", "content": [{"toolResult": {"response": "Success"}}]} + tool_result_message = { + "role": "assistant", + "content": [ + {"toolResult": {"toolUseId": "123", "status": "success", "content": [{"text": "Weather is sunny"}]}} + ], + } tracer.end_event_loop_cycle_span(mock_span, message, tool_result_message) @@ -590,7 +614,12 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span): tracer = Tracer() tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} - tool_result_message = {"role": "assistant", "content": [{"toolResult": {"response": "Success"}}]} + tool_result_message = { + "role": "assistant", + "content": [ + {"toolResult": {"toolUseId": "123", "status": "success", "content": [{"text": "Weather is sunny"}]}} + ], + } tracer.end_event_loop_cycle_span(mock_span, message, tool_result_message) @@ -601,7 +630,13 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span): [ { "role": "assistant", - "parts": [{"type": "text", "content": tool_result_message["content"]}], + "parts": [ + { + "type": "tool_call_response", + "id": "123", + "response": [{"text": "Weather is sunny"}], + } + ], } ] ) @@ -676,7 +711,7 @@ def 
test_start_agent_span_latest_conventions(mock_tracer): "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": [{"text": "test prompt"}]}]}] + [{"role": "user", "parts": [{"type": "text", "content": "test prompt"}]}] ) }, ) @@ -766,8 +801,9 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): cacheWriteInputTokens=3, ) stop_reason: StopReason = "end_turn" + metrics = Metrics(latencyMs=10, timeToFirstByteMs=5) - tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) @@ -776,6 +812,8 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 5) mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 3) + mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 10) + mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 5) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() From c3e5f6b8e7d6846395cad9dc5684508f7702c6d9 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Tue, 14 Oct 2025 09:32:06 -0400 Subject: [PATCH 11/26] chore(telemetry): added gen_ai.tool.description and gen_ai.tool.json_schema (#1027) --- src/strands/tools/executors/_executor.py | 10 ++- .../strands/tools/executors/test_executor.py | 87 +++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index f78861f81..6c1bd4eb4 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -13,7 +13,7 @@ from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace -from ...telemetry.tracer import get_tracer +from ...telemetry.tracer import get_tracer, serialize from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse @@ -59,6 +59,14 @@ async def _stream( tool_info = agent.tool_registry.dynamic_tools.get(tool_name) tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) + tool_spec = tool_func.tool_spec if tool_func is not None else None + + current_span = trace_api.get_current_span() + if current_span and tool_spec is not None: + current_span.set_attribute("gen_ai.tool.description", tool_spec["description"]) + input_schema = tool_spec["inputSchema"] + if "json" in input_schema: + current_span.set_attribute("gen_ai.tool.json_schema", serialize(input_schema["json"])) invocation_state.update( { diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 2a0a44e10..81be34969 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -250,3 +250,90 @@ def cancel_callback(event): tru_results = tool_results exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def 
test_executor_stream_sets_span_attributes( + executor, agent, tool_results, invocation_state, weather_tool, alist +): + """Test that span attributes are set correctly when tool_spec is available.""" + with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api: + mock_span = unittest.mock.MagicMock() + mock_trace_api.get_current_span.return_value = mock_span + + # Mock tool_spec with inputSchema containing json field + with unittest.mock.patch.object( + type(weather_tool), "tool_spec", new_callable=unittest.mock.PropertyMock + ) as mock_tool_spec: + mock_tool_spec.return_value = { + "name": "weather_tool", + "description": "Get weather information", + "inputSchema": {"json": {"type": "object", "properties": {}}, "type": "object"}, + } + + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + await alist(stream) + + # Verify set_attribute was called with correct values + calls = mock_span.set_attribute.call_args_list + assert len(calls) == 2 + + # Check description attribute + assert calls[0][0][0] == "gen_ai.tool.description" + assert calls[0][0][1] == "Get weather information" + + # Check json_schema attribute + assert calls[1][0][0] == "gen_ai.tool.json_schema" + # The serialize function should have been called on the json field + + +@pytest.mark.asyncio +async def test_executor_stream_handles_missing_json_in_input_schema( + executor, agent, tool_results, invocation_state, weather_tool, alist +): + """Test that span attributes handle inputSchema without json field gracefully.""" + with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api: + mock_span = unittest.mock.MagicMock() + mock_trace_api.get_current_span.return_value = mock_span + + # Mock tool_spec with inputSchema but no json field + with unittest.mock.patch.object( + type(weather_tool), "tool_spec", new_callable=unittest.mock.PropertyMock + ) as mock_tool_spec: + mock_tool_spec.return_value = { + "name": "weather_tool", + "description": "Get weather information", + "inputSchema": {"type": "object", "properties": {}}, + } + + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + # Should not raise an error - json_schema attribute just won't be set + await alist(stream) + + # Verify only description attribute was set (not json_schema) + calls = mock_span.set_attribute.call_args_list + assert len(calls) == 1 + assert calls[0][0][0] == "gen_ai.tool.description" + + +@pytest.mark.asyncio +async def test_executor_stream_no_span_attributes_when_no_tool_spec( + executor, agent, tool_results, invocation_state, alist +): + """Test that no span attributes are set when tool_spec is None.""" + with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api: + mock_span = unittest.mock.MagicMock() + mock_trace_api.get_current_span.return_value = mock_span + + # Use unknown tool which will have no tool_spec + tool_use: ToolUse = {"name": "unknown_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + await alist(stream) + + # Verify set_attribute was not called since tool_spec is None + mock_span.set_attribute.assert_not_called() From 6cf4f7ead0e1b922c41ed61de4ceb377106a8c52 Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:17:36 
+0400 Subject: [PATCH 12/26] fix(tool/decorator): validate ToolContext parameter name and raise clear error (#1028) --- src/strands/tools/decorator.py | 16 ++++++++++++++++ tests/strands/tools/test_decorator.py | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 99aa7e372..72109dbef 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -99,6 +99,8 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - self.type_hints = get_type_hints(func) self._context_param = context_param + self._validate_signature() + # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) @@ -111,6 +113,20 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _validate_signature(self) -> None: + """Verify that ToolContext is used correctly in the function signature.""" + for param in self.signature.parameters.values(): + if param.annotation is ToolContext: + if self._context_param is None: + raise ValueError("@tool(context) must be set if passing in ToolContext param") + + if param.name != self._context_param: + raise ValueError( + f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'" + ) + # Found the parameter, no need to check further + break + def _create_input_model(self) -> Type[BaseModel]: """Create a Pydantic model from function signature for input validation. diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 5b4b5cdda..658a34052 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1363,3 +1363,27 @@ async def async_generator() -> AsyncGenerator: ] assert act_results == exp_results + + +def test_function_tool_metadata_validate_signature_default_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'tool_context'"): + + @strands.tool(context=True) + def my_tool(context: ToolContext): + pass + + +def test_function_tool_metadata_validate_signature_custom_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'my_context'"): + + @strands.tool(context="my_context") + def my_tool(tool_context: ToolContext): + pass + + +def test_function_tool_metadata_validate_signature_missing_context_config(): + with pytest.raises(ValueError, match=r"@tool\(context\) must be set if passing in ToolContext param"): + + @strands.tool + def my_tool(tool_context: ToolContext): + pass From f7931c5dc230f81b085601fb31c5fdc1dc40b7a0 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 14 Oct 2025 15:35:16 -0400 Subject: [PATCH 13/26] integ tests - fix flaky structured output test (#1030) --- tests_integ/models/providers.py | 2 +- tests_integ/models/test_conformance.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index c1f442b2a..75cc58f74 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -131,7 +131,7 @@ def __init__(self): id="gemini", environment_variable="GOOGLE_API_KEY", factory=lambda: GeminiModel( - api_key=os.getenv("GOOGLE_API_KEY"), + client_args={"api_key": os.getenv("GOOGLE_API_KEY")}, model_id="gemini-2.5-flash", 
params={"temperature": 0.7}, ), diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index eaef1eb88..4df6dd69b 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -57,6 +57,4 @@ class Weather(BaseModel): agent = Agent(model) result = agent.structured_output(Weather, "How are you?") - - assert len(result.time) > 0 - assert len(result.weather) > 0 + assert isinstance(result, Weather) From dbf6200d104539217dddfc7bd729c53f46e2ec56 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 15 Oct 2025 14:58:24 -0400 Subject: [PATCH 14/26] hooks - before tool call event - interrupt (#987) --- src/strands/agent/agent.py | 47 +++++ src/strands/agent/agent_result.py | 5 +- src/strands/agent/interrupt.py | 59 ++++++ src/strands/event_loop/event_loop.py | 53 ++++- src/strands/hooks/events.py | 18 +- src/strands/hooks/registry.py | 27 ++- src/strands/interrupt.py | 33 +++ .../session/repository_session_manager.py | 2 + src/strands/tools/executors/_executor.py | 21 +- src/strands/tools/executors/sequential.py | 12 +- src/strands/types/_events.py | 29 ++- src/strands/types/agent.py | 3 +- src/strands/types/event_loop.py | 2 + src/strands/types/interrupt.py | 181 +++++++++++++++++ src/strands/types/session.py | 26 ++- tests/strands/agent/test_agent.py | 128 ++++++++++++ tests/strands/agent/test_agent_hooks.py | 15 +- tests/strands/agent/test_interrupt.py | 61 ++++++ tests/strands/event_loop/test_event_loop.py | 162 ++++++++++++++- tests/strands/hooks/__init__.py | 0 tests/strands/hooks/test_registry.py | 73 +++++++ .../test_repository_session_manager.py | 3 + tests/strands/test_interrupt.py | 24 +++ tests/strands/tools/executors/conftest.py | 2 + .../tools/executors/test_concurrent.py | 42 +++- .../strands/tools/executors/test_executor.py | 72 ++++++- .../tools/executors/test_sequential.py | 35 +++- tests/strands/types/__init__.py | 0 tests/strands/types/test_interrupt.py | 80 ++++++++ tests/strands/types/test_session.py | 38 ++++ tests_integ/test_interrupt.py | 192 ++++++++++++++++++ 31 files changed, 1401 insertions(+), 44 deletions(-) create mode 100644 src/strands/agent/interrupt.py create mode 100644 src/strands/interrupt.py create mode 100644 src/strands/types/interrupt.py create mode 100644 tests/strands/agent/test_interrupt.py create mode 100644 tests/strands/hooks/__init__.py create mode 100644 tests/strands/hooks/test_registry.py create mode 100644 tests/strands/test_interrupt.py create mode 100644 tests/strands/types/__init__.py create mode 100644 tests/strands/types/test_interrupt.py create mode 100644 tests_integ/test_interrupt.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8607a2601..f963f14e7 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -55,6 +55,7 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException +from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -62,6 +63,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -143,6 +145,9 @@ def caller( Raises: AttributeError: If the tool doesn't exist. 
""" + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + normalized_name = self._find_normalized_tool_name(name) # Create unique tool ID and set up the tool request @@ -338,6 +343,8 @@ def __init__( self.hooks = HookRegistry() + self._interrupt_state = InterruptState() + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -491,6 +498,9 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu Raises: ValueError: If no conversation history or prompt is provided. """ + if self._interrupt_state.activated: + raise RuntimeError("cannot call structured output during interrupt") + self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT @@ -573,6 +583,8 @@ async def stream_async( yield event["data"] ``` """ + self._resume_interrupt(prompt) + merged_state = {} if kwargs: warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) @@ -614,6 +626,38 @@ async def stream_async( self._end_agent_trace_span(error=e) raise + def _resume_interrupt(self, prompt: AgentInput) -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self._interrupt_state.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + for content in cast(list[InterruptResponseContent], prompt): + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self._interrupt_state.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self._interrupt_state.interrupts[interrupt_id].response = interrupt_response + async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -689,6 +733,9 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A yield event def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + if self._interrupt_state.activated: + return [] + messages: Messages | None = None if prompt is not None: if isinstance(prompt, str): diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f3758c8d2..eb9bc4dd9 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,8 +4,9 @@ """ from dataclasses import dataclass -from typing import Any +from typing import Any, Sequence +from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message from ..types.streaming import StopReason @@ -20,12 +21,14 @@ class AgentResult: message: The last message generated by the agent. metrics: Performance metrics collected during processing. 
state: Additional state information from the event loop. + interrupts: List of interrupts if raised by user. """ stop_reason: StopReason message: Message metrics: EventLoopMetrics state: Any + interrupts: Sequence[Interrupt] | None = None def __str__(self) -> str: """Get the agent's last message as a string. diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py new file mode 100644 index 000000000..3cec1541b --- /dev/null +++ b/src/strands/agent/interrupt.py @@ -0,0 +1,59 @@ +"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" + +from dataclasses import asdict, dataclass, field +from typing import Any + +from ..interrupt import Interrupt + + +@dataclass +class InterruptState: + """Track the state of interrupt events raised by the user. + + Note, interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. + """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deacitvate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InterruptState": + """Initiailize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method. + """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index feb6ac339..7a9c60c3b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -27,6 +27,7 @@ ModelStopReason, StartEvent, StartEventLoopEvent, + ToolInterruptEvent, ToolResultMessageEvent, TypedEvent, ) @@ -106,13 +107,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span - model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. 
+ if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + else: + model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) try: if stop_reason == "max_tokens": @@ -142,6 +149,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_span=cycle_span, cycle_start_time=cycle_start_time, invocation_state=invocation_state, + tracer=tracer, ) async for tool_event in tool_events: yield tool_event @@ -345,6 +353,7 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], + tracer: Tracer, ) -> AsyncGenerator[TypedEvent, None]: """Handles the execution of tools requested by the model during an event loop cycle. @@ -356,6 +365,7 @@ async def _handle_tool_execution( cycle_span: Span object for tracing the cycle (type may vary). cycle_start_time: Start time of the current cycle. invocation_state: Additional keyword arguments, including request state. + tracer: Tracer instance for span management. Yields: Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple @@ -375,15 +385,45 @@ async def _handle_tool_execution( yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return + if agent._interrupt_state.activated: + tool_results.extend(agent._interrupt_state.context["tool_results"]) + + # Filter to only the interrupted tools when resuming from interrupt (tool uses without results) + tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results} + tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] + + interrupts = [] tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) async for tool_event in tool_events: + if isinstance(tool_event, ToolInterruptEvent): + interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"]) + yield tool_event # Store parent cycle ID for the next cycle invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] + if interrupts: + # Session state stored on AfterInvocationEvent. 
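With interrupts in play, the event loop's terminal `"stop"` payload grows from four elements to five (see the `EventLoopStopEvent` change later in this patch and the updated test unpacking below). A small consumer-side sketch of that tuple shape:

```
from strands.interrupt import Interrupt

def unpack_stop(stop_event: dict) -> list[Interrupt]:
    """Unpack the five-element stop tuple and surface any interrupts to the caller."""
    stop_reason, message, metrics, request_state, interrupts = stop_event["stop"]
    if stop_reason == "interrupt":
        # Interrupts raised by hook callbacks; the caller resumes the agent with
        # one interruptResponse content block per interrupt id.
        return list(interrupts)
    return []
```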
+ agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results}) + + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "interrupt", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + interrupts, + ) + if cycle_span: + tracer.end_event_loop_cycle_span(span=cycle_span, message=message) + + return + + agent._interrupt_state.deactivate() + tool_result_message: Message = { "role": "user", "content": [{"toolResult": result} for result in tool_results], @@ -394,7 +434,6 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: - tracer = get_tracer() tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) if invocation_state["request_state"].get("stop_event_loop", False): diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8f611e4e2..de07002c5 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -3,10 +3,14 @@ This module defines the events that are emitted as Agents run through the lifecycle of a request. """ +import uuid from dataclasses import dataclass from typing import Any, Optional +from typing_extensions import override + from ..types.content import Message +from ..types.interrupt import InterruptHookEvent from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -84,7 +88,7 @@ class MessageAddedEvent(HookEvent): @dataclass -class BeforeToolCallEvent(HookEvent): +class BeforeToolCallEvent(HookEvent, InterruptHookEvent): """Event triggered before a tool is invoked. This event is fired just before the agent executes a tool, allowing hook @@ -110,6 +114,18 @@ class BeforeToolCallEvent(HookEvent): def _can_write(self, name: str) -> bool: return name in ["cancel_tool", "selected_tool", "tool_use"] + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + @dataclass class AfterToolCallEvent(HookEvent): diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index b8e7f82ab..1cfd5c63e 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -10,6 +10,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from ..interrupt import Interrupt, InterruptException + if TYPE_CHECKING: from ..agent import Agent @@ -184,7 +186,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: + def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: """Invoke all registered callbacks for the given event. This method finds all callbacks registered for the event's type and @@ -192,11 +194,16 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: callbacks are invoked in reverse registration order. Any exceptions raised by callback functions will propagate to the caller. + Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows. + Args: event: The event to dispatch to registered callbacks. Returns: - The event dispatched to registered callbacks. 
+ The event dispatched to registered callbacks and any interrupts raised by the user. + + Raises: + ValueError: If interrupt name is used more than once. Example: ```python @@ -204,10 +211,22 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: registry.invoke_callbacks(event) ``` """ + interrupts: dict[str, Interrupt] = {} + for callback in self.get_callbacks_for(event): - callback(event) + try: + callback(event) + except InterruptException as exception: + interrupt = exception.interrupt + if interrupt.name in interrupts: + raise ValueError( + f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + ) from exception + + # Each callback is allowed to raise their own interrupt. + interrupts[interrupt.name] = interrupt - return event + return event, list(interrupts.values()) def has_callbacks(self) -> bool: """Check if the registry has any registered callbacks. diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py new file mode 100644 index 000000000..f0ed52389 --- /dev/null +++ b/src/strands/interrupt.py @@ -0,0 +1,33 @@ +"""Human-in-the-loop interrupt system for agent workflows.""" + +from dataclasses import asdict, dataclass +from typing import Any + + +@dataclass +class Interrupt: + """Represents an interrupt that can pause agent execution for human-in-the-loop workflows. + + Attributes: + id: Unique identifier. + name: User defined name. + reason: User provided reason for raising the interrupt. + response: Human response provided when resuming the agent after an interrupt. + """ + + id: str + name: str + reason: Any = None + response: Any = None + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + +class InterruptException(Exception): + """Exception raised when human input is required.""" + + def __init__(self, interrupt: Interrupt) -> None: + """Set the interrupt.""" + self.interrupt = interrupt diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 75058b251..e5075de93 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -132,6 +132,8 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: ) agent.state = AgentState(session_agent.state) + session_agent.initialize_internal_state(agent) + # Restore the conversation manager to its previous state, and get the optional prepend messages prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 6c1bd4eb4..a4f43b149 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -14,7 +14,7 @@ from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer, serialize -from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent +from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse @@ -43,6 +43,7 @@ async def _stream( - Before/after hook execution - Tracing and metrics collection - Error handling and recovery + - Interrupt handling for human-in-the-loop workflows Args: agent: The agent for which the tool is being 
executed. @@ -80,7 +81,7 @@ async def _stream( } ) - before_event = agent.hooks.invoke_callbacks( + before_event, interrupts = agent.hooks.invoke_callbacks( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, @@ -89,6 +90,10 @@ async def _stream( ) ) + if interrupts: + yield ToolInterruptEvent(tool_use, interrupts) + return + if before_event.cancel_tool: cancel_message = ( before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" @@ -100,7 +105,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, tool_use=tool_use, @@ -138,7 +143,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -169,7 +174,7 @@ async def _stream( result = cast(ToolResult, event) - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -189,7 +194,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -238,6 +243,10 @@ async def _stream_with_trace( async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): yield event + if isinstance(event, ToolInterruptEvent): + tracer.end_tool_call_span(tool_call_span, tool_result=None) + return + result_event = cast(ToolResultEvent, event) result = result_event.tool_result diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 60e5c7fa7..adbd5a5d3 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -5,7 +5,7 @@ from typing_extensions import override from ...telemetry.metrics import Trace -from ...types._events import TypedEvent +from ...types._events import ToolInterruptEvent, TypedEvent from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor @@ -28,6 +28,8 @@ async def _execute( ) -> AsyncGenerator[TypedEvent, None]: """Execute tools sequentially. + Breaks early if an interrupt is raised by the user. + Args: agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. @@ -39,9 +41,17 @@ async def _execute( Yields: Events from the tool execution stream. """ + interrupted = False + for tool_use in tool_uses: events = ToolExecutor._stream_with_trace( agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state ) async for event in events: + if isinstance(event, ToolInterruptEvent): + interrupted = True + yield event + + if interrupted: + break diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index e20bf658a..13d4a98f9 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,10 +5,11 @@ agent lifecycle. 
""" -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Sequence, cast from typing_extensions import override +from ..interrupt import Interrupt from ..telemetry import EventLoopMetrics from .citations import Citation from .content import Message @@ -220,6 +221,7 @@ def __init__( message: Message, metrics: "EventLoopMetrics", request_state: Any, + interrupts: Sequence[Interrupt] | None = None, ) -> None: """Initialize with the final execution results. @@ -228,8 +230,9 @@ def __init__( message: Final message from the model metrics: Execution metrics and performance data request_state: Final state of the agent execution + interrupts: Interrupts raised by user during agent execution. """ - super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts)}) @property @override @@ -313,12 +316,30 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancelled_event")).get("tool_use")).get("toolUseId")) + return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) @property def message(self) -> str: """The tool cancellation message.""" - return cast(str, self["message"]) + return cast(str, self["tool_cancel_event"]["message"]) + + +class ToolInterruptEvent(TypedEvent): + """Event emitted when a tool is interrupted.""" + + def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: + """Set interrupt in the event payload.""" + super().__init__({"tool_interrupt_event": {"tool_use": tool_use, "interrupts": interrupts}}) + + @property + def tool_use_id(self) -> str: + """The id of the tool interrupted.""" + return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId")) + + @property + def interrupts(self) -> list[Interrupt]: + """The interrupt instances.""" + return cast(list[Interrupt], self["tool_interrupt_event"]["interrupts"]) class ModelMessageEvent(TypedEvent): diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index 151c88f89..a2a4c7dce 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -6,5 +6,6 @@ from typing import TypeAlias from .content import ContentBlock, Messages +from .interrupt import InterruptResponse -AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None +AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index f184f5e59..2a7ad344e 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -40,6 +40,7 @@ class Metrics(TypedDict, total=False): "content_filtered", "end_turn", "guardrail_intervened", + "interrupt", "max_tokens", "stop_sequence", "tool_use", @@ -49,6 +50,7 @@ class Metrics(TypedDict, total=False): - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened +- "interrupt": Agent was interrupted for human input - "max_tokens": Maximum token limit reached - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py new file mode 100644 index 000000000..4e9584a70 --- /dev/null +++ 
b/src/strands/types/interrupt.py @@ -0,0 +1,181 @@ +"""Interrupt related type definitions for human-in-the-loop workflows. + +Interrupt Flow: + ┌─────────────────┐ + │ Agent Invoke │ + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ + │ Hook Calls │ + | on Event | + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ No ┌─────────────────┐ + │ Interrupts │ ────────► │ Continue │ + │ Raised? │ │ Execution │ + └────────┬────────┘ └─────────────────┘ + │ Yes + ▼ + ┌─────────────────┐ + │ Stop Event Loop │◄───────────────────┐ + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ | + │ Return | | + | Interrupts │ | + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ | + │ Agent Invoke │ | + │ with Responses │ | + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ | + │ Hook Calls │ | + | on Event | | + | with Responses | | + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ Yes ┌────────┴────────┐ + │ New Interrupts │ ────────► │ Store State │ + │ Raised? │ │ │ + └────────┬────────┘ └─────────────────┘ + │ No + ▼ + ┌─────────────────┐ + │ Continue │ + │ Execution │ + └─────────────────┘ + +Example: + ``` + from typing import Any + + from strands import Agent, tool + from strands.hooks import BeforeToolCallEvent, HookProvider, HookRegistry + + + @tool + def delete_tool(key: str) -> bool: + print("DELETE_TOOL | deleting") + return True + + + class ToolInterruptHook(HookProvider): + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + registry.add_callback(BeforeToolCallEvent, self.approve) + + def approve(self, event: BeforeToolCallEvent) -> None: + if event.tool_use["name"] != "delete_tool": + return + + approval = event.interrupt("for_delete_tool", reason="APPROVAL") + if approval != "A": + event.cancel_tool = "approval was not granted" + + agent = Agent( + hooks=[ToolInterruptHook()], + tools=[delete_tool], + system_prompt="You delete objects given their keys.", + callback_handler=None, + ) + result = agent(f"delete object with key 'X'") + + if result.stop_reason == "interrupt": + responses = [] + for interrupt in result.interrupts: + if interrupt.name == "for_delete_tool": + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "A"}) + + result = agent(responses) + + ... + ``` + +Details: + + - User raises interrupt on their hook event by calling `event.interrupt()`. + - User can raise one interrupt per hook callback. + - Interrupts stop the agent event loop. + - Interrupts are returned to the user in AgentResult. + - User resumes by invoking agent with interrupt responses. + - Second call to `event.interrupt()` returns user response. + - Process repeats if user raises additional interrupts. + - Interrupts are session managed in-between return and user response. +""" + +from typing import TYPE_CHECKING, Any, Protocol, TypedDict + +from ..interrupt import Interrupt, InterruptException + +if TYPE_CHECKING: + from ..agent import Agent + + +class InterruptHookEvent(Protocol): + """Interface that adds interrupt support to hook events.""" + + agent: "Agent" + + def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any: + """Trigger the interrupt with a reason. + + Args: name: User defined name for the interrupt. + Must be unique across hook callbacks. + reason: User provided reason for the interrupt. + response: Preemptive response from user if available. + + Returns: + The response from a human user when resuming from an interrupt state. + + Raises: + InterruptException: If human input is required. 
+ """ + id = self._interrupt_id(name) + state = self.agent._interrupt_state + + interrupt_ = state.interrupts.setdefault(id, Interrupt(id, name, reason, response)) + if interrupt_.response: + return interrupt_.response + + raise InterruptException(interrupt_) + + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + reason: User provided reason for the interrupt. + + Returns: + Interrupt id. + """ + ... + + +class InterruptResponse(TypedDict): + """User response to an interrupt. + + Attributes: + interruptId: Unique identifier for the interrupt. + response: User response to the interrupt. + """ + + interruptId: str + response: Any + + +class InterruptResponseContent(TypedDict): + """Content block containing a user response to an interrupt. + + Attributes: + interruptResponse: User response to an interrupt event. + """ + + interruptResponse: InterruptResponse diff --git a/src/strands/types/session.py b/src/strands/types/session.py index e51816f74..926480f2c 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,8 +5,9 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional +from ..agent.interrupt import InterruptState from .content import Message if TYPE_CHECKING: @@ -104,11 +105,20 @@ def to_dict(self) -> dict[str, Any]: @dataclass class SessionAgent: - """Agent that belongs to a Session.""" + """Agent that belongs to a Session. + + Attributes: + agent_id: Unique id for the agent. + state: User managed state. + conversation_manager_state: State for conversation management. + created_at: Created at time. + updated_at: Updated at time. 
+ """ agent_id: str - state: Dict[str, Any] - conversation_manager_state: Dict[str, Any] + state: dict[str, Any] + conversation_manager_state: dict[str, Any] + _internal_state: dict[str, Any] = field(default_factory=dict) # Strands managed state created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -121,6 +131,9 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": agent_id=agent.agent_id, conversation_manager_state=agent.conversation_manager.get_state(), state=agent.state.get(), + _internal_state={ + "interrupt_state": agent._interrupt_state.to_dict(), + }, ) @classmethod @@ -132,6 +145,11 @@ def to_dict(self) -> dict[str, Any]: """Convert the SessionAgent to a dictionary representation.""" return asdict(self) + def initialize_internal_state(self, agent: "Agent") -> None: + """Initialize internal state of agent.""" + if "interrupt_state" in self._internal_state: + agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"]) + @dataclass class Session: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 200584115..ae2d8c7b5 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -17,6 +17,8 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize @@ -1933,3 +1935,129 @@ async def check_invocation_state(**kwargs): agent("hello!", invocation_state={"my": "state"}) assert len(captured_warnings) == 0 + + +def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): + tool_use_message = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_decorated", + "input": {"random_string": "test input"}, + } + }, + ], + } + agent = Agent( + messages=[tool_use_message], + model=mock_model, + tools=[tool_decorated], + ) + + interrupt = Interrupt( + id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": []}) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + mock_model.mock_stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "resumed"}}}, + {"contentBlockStop": {}}, + ] + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test response", + } + } + ] + agent(prompt) + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "test input"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + tru_response = 
interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state + + +def test_agent__call__resume_interrupt_invalid_prompt(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"prompt_type= \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + agent("invalid") + + +def test_agent__call__resume_interrupt_invalid_content(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + agent([{"text": "invalid"}]) + + +def test_agent__call__resume_interrupt_invalid_id(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"interrupt_id= \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + agent([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) + + +def test_agent_structured_output_interrupt(user): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot call structured output during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.structured_output(type(user), "invalid") + + +def test_agent_tool_caller_interrupt(user): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot directly call tool during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.tool.test_tool() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 6c5625e0b..32266c3eb 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -124,7 +124,10 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): assert length == 6 assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -170,7 +173,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -231,7 +237,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py new file mode 100644 index 000000000..e248c29a6 --- /dev/null +++ b/tests/strands/agent/test_interrupt.py @@ -0,0 +1,61 @@ +import pytest + +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt + + 
+@pytest.fixture +def interrupt(): + return Interrupt(id="test_id", name="test_name", reason="test reason") + + +def test_interrupt_activate(): + interrupt_state = InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_deactivate(): + interrupt_state = InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(interrupt): + interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = InterruptState.from_dict(data) + exp_state = InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2b71f3502..89ef477fa 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,12 +6,15 @@ import strands import strands.telemetry +from strands.agent.interrupt import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, + BeforeToolCallEvent, HookRegistry, MessageAddedEvent, ) +from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -138,6 +141,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor + mock._interrupt_state = InterruptState() return mock @@ -169,7 +173,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -201,7 +205,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -239,7 +243,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] # 
Verify the final response assert tru_stop_reason == "end_turn" @@ -330,7 +334,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -445,7 +449,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -747,7 +751,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -759,7 +763,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -862,3 +866,147 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert next(events) == MessageAddedEvent( agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} ) + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt(agent, model, tool_stream, agenerator, alist): + def interrupt_callback(event): + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator(tool_stream)] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, tru_interrupts = events[-1]["stop"] + exp_stop_reason = "interrupt" + exp_interrupts = [ + Interrupt( + id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ), + ] + + assert tru_stop_reason == exp_stop_reason and tru_interrupts == exp_interrupts + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": True, + "context": { + "tool_results": [], + "tool_use_message": { + "content": [ + { + "toolUse": { + "input": {"random_string": "abcdEfghI123"}, + "name": "tool_for_testing", + "toolUseId": "t1", + }, + }, + ], + "role": "assistant", + }, + }, + "interrupts": { + "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9": { + "id": "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + "name": "test_name", + "reason": "test reason", + "response": None, + }, + }, + } + assert tru_state == exp_state + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_2, agenerator, alist): + interrupt = Interrupt( + id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + + tool_use_message = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_for_testing", + "input": {"random_string": "test input"}, + } + }, + { + "toolUse": { + "toolUseId": "t2", + "name": "tool_times_2", + "input": {}, + } 
+ }, + ], + } + tool_results = [ + { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + ] + + agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": tool_results}) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator([{"contentBlockStop": {}}])] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, _ = events[-1]["stop"] + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + }, + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "test input"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state diff --git a/tests/strands/hooks/__init__.py b/tests/strands/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py new file mode 100644 index 000000000..807011869 --- /dev/null +++ b/tests/strands/hooks/test_registry.py @@ -0,0 +1,73 @@ +import unittest.mock + +import pytest + +from strands.agent.interrupt import InterruptState +from strands.hooks import BeforeToolCallEvent, HookRegistry +from strands.interrupt import Interrupt + + +@pytest.fixture +def registry(): + return HookRegistry() + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = InterruptState() + return instance + + +def test_hook_registry_invoke_callbacks_interrupt(registry, agent): + event = BeforeToolCallEvent( + agent=agent, + selected_tool=None, + tool_use={"toolUseId": "test_tool_id", "name": "test_tool_name", "input": {}}, + invocation_state={}, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name_1", "test reason 1")) + callback2 = unittest.mock.Mock() + callback3 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name_2", "test reason 2")) + + registry.add_callback(BeforeToolCallEvent, callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + registry.add_callback(BeforeToolCallEvent, callback3) + + _, tru_interrupts = registry.invoke_callbacks(event) + exp_interrupts = [ + Interrupt( + id="v1:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", + name="test_name_1", + reason="test reason 1", + ), + Interrupt( + id="v1:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4", + name="test_name_2", + reason="test reason 2", + ), + ] + assert tru_interrupts == exp_interrupts + + callback1.assert_called_once_with(event) + callback2.assert_called_once_with(event) + callback3.assert_called_once_with(event) + + +def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): + event = BeforeToolCallEvent( + agent=agent, + 
selected_tool=None, + tool_use={"toolUseId": "test_tool_id", "name": "test_tool_name", "input": {}}, + invocation_state={}, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name", "test reason 1")) + callback2 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name", "test reason 2")) + + registry.add_callback(BeforeToolCallEvent, callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + + with pytest.raises(ValueError, match="interrupt_name= | interrupt name used more than once"): + registry.invoke_callbacks(event) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 2c25fcc38..923b13daa 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,6 +5,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.interrupt import InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -95,6 +96,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): agent_id="existing-agent", state={"key": "value"}, conversation_manager_state=SlidingWindowConversationManager().get_state(), + _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, ) session_manager.session_repository.create_agent("test-session", session_agent) @@ -116,6 +118,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" + assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py new file mode 100644 index 000000000..8ce972103 --- /dev/null +++ b/tests/strands/test_interrupt.py @@ -0,0 +1,24 @@ +import pytest + +from strands.interrupt import Interrupt + + +@pytest.fixture +def interrupt(): + return Interrupt( + id="test_id:test_name", + name="test_name", + reason={"reason": "test"}, + response={"response": "test"}, + ) + + +def test_interrupt_to_dict(interrupt): + tru_dict = interrupt.to_dict() + exp_dict = { + "id": "test_id:test_name", + "name": "test_name", + "reason": {"reason": "test"}, + "response": {"response": "test"}, + } + assert tru_dict == exp_dict diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index be90226f6..fa8ce10af 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,6 +4,7 @@ import pytest import strands +from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.tools.registry import ToolRegistry @@ -92,6 +93,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry + mock_agent._interrupt_state = 
InterruptState() return mock_agent diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index f7fc64b25..7264c8e58 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,8 +1,9 @@ import pytest +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolUse +from strands.types._events import ToolInterruptEvent, ToolResultEvent @pytest.fixture @@ -14,7 +15,7 @@ def executor(): async def test_concurrent_executor_execute( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): - tool_uses: list[ToolUse] = [ + tool_uses = [ {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] @@ -30,3 +31,38 @@ async def test_concurrent_executor_execute( tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_concurrent_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + interrupt = Interrupt( + id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + if event.tool_use["name"] == "weather_tool": + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "test_tool_id_1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, + ] + + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) + exp_events = [ + ToolInterruptEvent(tool_uses[0], [interrupt]), + ToolResultEvent({"toolUseId": "test_tool_id_2", "status": "success", "content": [{"text": "75F"}]}), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[1].tool_result] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 81be34969..fd15c9747 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -5,9 +5,10 @@ import strands from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor -from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import ToolUse @@ -36,6 +37,7 @@ async def test_executor_stream_yields_result( executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist ): tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) @@ -337,3 +339,71 @@ async def test_executor_stream_no_span_attributes_when_no_tool_spec( # Verify 
set_attribute was not called since tool_spec is None mock_span.set_attribute.assert_not_called() + + +@pytest.mark.asyncio +async def test_executor_stream_interrupt(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + event.interrupt("test_name", reason="test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_interrupt_resume(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", reason="test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "sunny"}], + }, + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index 37e098142..c1db3cd55 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,9 @@ import pytest +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent +from strands.types._events import ToolInterruptEvent, ToolResultEvent @pytest.fixture @@ -29,3 +31,34 @@ async def test_sequential_executor_execute( tru_results = tool_results exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_sequential_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + interrupt = Interrupt( + id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "test_tool_id_1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, + ] + + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, 
invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_uses[0], [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results diff --git a/tests/strands/types/__init__.py b/tests/strands/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py new file mode 100644 index 000000000..3b970a00a --- /dev/null +++ b/tests/strands/types/test_interrupt.py @@ -0,0 +1,80 @@ +import unittest.mock + +import pytest + +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt, InterruptException +from strands.types.interrupt import InterruptHookEvent + + +@pytest.fixture +def interrupt(): + return Interrupt( + id="test_id:test_name", + name="test_name", + reason={"reason": "test"}, + response={"response": "test"}, + ) + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = InterruptState() + return instance + + +@pytest.fixture +def interrupt_hook_event(agent): + class Event(InterruptHookEvent): + def __init__(self): + self.agent = agent + + def _interrupt_id(self, name): + return f"test_id:{name}" + + return Event() + + +def test_interrupt_hook_event_interrupt(interrupt_hook_event): + with pytest.raises(InterruptException) as exception: + interrupt_hook_event.interrupt("custom_test_name", "custom test reason") + + tru_interrupt = exception.value.interrupt + exp_interrupt = Interrupt( + id="test_id:custom_test_name", + name="custom_test_name", + reason="custom test reason", + ) + assert tru_interrupt == exp_interrupt + + +def test_interrupt_hook_event_interrupt_state(agent, interrupt_hook_event): + with pytest.raises(InterruptException): + interrupt_hook_event.interrupt("custom_test_name", "custom test reason") + + exp_interrupt = Interrupt( + id="test_id:custom_test_name", + name="custom_test_name", + reason="custom test reason", + ) + assert exp_interrupt.id in agent._interrupt_state.interrupts + + tru_interrupt = agent._interrupt_state.interrupts[exp_interrupt.id] + assert tru_interrupt == exp_interrupt + + +def test_interrupt_hook_event_interrupt_response(interrupt, agent, interrupt_hook_event): + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + tru_response = interrupt_hook_event.interrupt("test_name") + exp_response = {"response": "test"} + assert tru_response == exp_response + + +def test_interrupt_hook_event_interrupt_response_empty(interrupt, agent, interrupt_hook_event): + interrupt.response = None + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + with pytest.raises(InterruptException): + interrupt_hook_event.interrupt("test_name") diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index c39615c32..26d4062e4 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -1,7 +1,10 @@ import json +import unittest.mock from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager +from strands.agent.interrupt import InterruptState +from strands.agent.state import AgentState from strands.types.session import ( Session, SessionAgent, @@ -91,3 +94,38 @@ def test_session_message_with_bytes(): assert original_message["role"] == message["role"] assert original_message["content"][0]["text"] == message["content"][0]["text"] assert 
original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"] + + +def test_session_agent_from_agent(): + agent = unittest.mock.Mock() + agent.agent_id = "a1" + agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) + agent.state = AgentState({"test": "state"}) + agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + + tru_session_agent = SessionAgent.from_agent(agent) + exp_session_agent = SessionAgent( + agent_id="a1", + conversation_manager_state={"test": "conversation"}, + state={"test": "state"}, + _internal_state={"interrupt_state": {"interrupts": {}, "context": {}, "activated": False}}, + created_at=unittest.mock.ANY, + updated_at=unittest.mock.ANY, + ) + assert tru_session_agent == exp_session_agent + + +def test_session_agent_initialize_internal_state(): + agent = unittest.mock.Mock() + session_agent = SessionAgent( + agent_id="a1", + conversation_manager_state={}, + state={}, + _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, + ) + + session_agent.initialize_internal_state(agent) + + tru_interrupt_state = agent._interrupt_state + exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert tru_interrupt_state == exp_interrupt_state diff --git a/tests_integ/test_interrupt.py b/tests_integ/test_interrupt.py new file mode 100644 index 000000000..164dfdede --- /dev/null +++ b/tests_integ/test_interrupt.py @@ -0,0 +1,192 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.interrupt import Interrupt +from strands.session import FileSessionManager + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] == "weather_tool": + return + + response = event.interrupt("test_interrupt", "need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, weather_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) + + +@pytest.mark.asyncio +def test_interrupt(agent): + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"]) + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": 
ANY, + "status": "success", + "content": [ + {"text": "sunny"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "12:00"}, + ], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message + + +@pytest.mark.asyncio +def test_interrupt_reject(agent): + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "REJECT", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [{"text": "sunny"}], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "error", + "content": [{"text": "tool rejected"}], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message + + +@pytest.mark.asyncio +def test_interrupt_session(interrupt_hook, time_tool, weather_tool, tmpdir): + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"]) From 61e41da96ab41f3557f6ed6a94bffadc696607de Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 16 Oct 2025 09:45:38 -0400 Subject: [PATCH 15/26] multiagents - temporarily raise exception when interrupted (#1038) --- src/strands/hooks/registry.py | 9 ++++++--- src/strands/multiagent/graph.py | 8 ++++++++ src/strands/multiagent/swarm.py | 6 ++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 1cfd5c63e..564be85cb 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,6 +7,7 @@ via hook provider objects. 
""" +import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar @@ -15,6 +16,8 @@ if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) + @dataclass class BaseHookEvent: @@ -219,9 +222,9 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte except InterruptException as exception: interrupt = exception.interrupt if interrupt.name in interrupts: - raise ValueError( - f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" - ) from exception + message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + logger.error(message) + raise ValueError(message) from exception # Each callback is allowed to raise their own interrupt. interrupts[interrupt.name] = interrupt diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 60299c1b5..1dbbfc3af 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -578,6 +578,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) else: agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + if agent_response.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError( + "user raised interrupt from agent | interrupts are not yet supported in graphs" + ) + # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics = Metrics(latencyMs=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 42efd5742..7542b1b85 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -637,6 +637,12 @@ async def _execute_node( node.reset_executor_state() result = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + if result.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in swarms") + execution_time = round((time.time() - start_time) * 1000) # Create NodeResult From 7cd10b91ee9bbda36c70f569aa0ededa72940e84 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:01:32 +0100 Subject: [PATCH 16/26] feat: Support adding exception notes for Python 3.10 (#1034) When add_note is not available (3.10) enhance the default error message with the added notes. In PR #290 we started using add_note to provide the bedrock model and region in exceptions to better clarify to customers what model & region were active. The implementation used add_note which is only supported in 3.11+; however, we've had enough customers on 3.10 where they're not seeing the error message that it makes sense to add a shim to do something similar for 3.10. 
--------- Co-authored-by: Mackenzie Zastrow --- src/strands/_exception_notes.py | 21 +++++++++++ src/strands/models/bedrock.py | 47 ++++++++++++------------ tests/strands/models/test_bedrock.py | 19 ++++++++++ tests/strands/test_exception_notes.py | 51 +++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 23 deletions(-) create mode 100644 src/strands/_exception_notes.py create mode 100644 tests/strands/test_exception_notes.py diff --git a/src/strands/_exception_notes.py b/src/strands/_exception_notes.py new file mode 100644 index 000000000..019b9cde4 --- /dev/null +++ b/src/strands/_exception_notes.py @@ -0,0 +1,21 @@ +"""Exception note utilities for Python 3.10+ compatibility.""" + +# add_note was added in 3.11 - we hoist to a constant to facilitate testing +supports_add_note = hasattr(Exception, "add_note") + + +def add_exception_note(exception: Exception, note: str) -> None: + """Add a note to an exception, compatible with Python 3.10+. + + Uses add_note() if it's available (Python 3.11+) or modifies the exception message if it is not. + """ + if supports_add_note: + # we ignore the mypy error because the version-check for add_note is extracted into a constant up above and + # mypy doesn't detect that + exception.add_note(note) # type: ignore + else: + # For Python 3.10, append note to the exception message + if hasattr(exception, "args") and exception.args: + exception.args = (f"{exception.args[0]}\n{note}",) + exception.args[1:] + else: + exception.args = (note,) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c6a500597..c465a2f38 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -16,6 +16,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from .._exception_notes import add_exception_note from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages @@ -716,29 +717,29 @@ def _stream( region = self.client.meta.region_name - # add_note added in Python 3.11 - if hasattr(e, "add_note"): - # Aid in debugging by adding more information - e.add_note(f"└ Bedrock region: {region}") - e.add_note(f"└ Model id: {self.config.get('model_id')}") - - if ( - e.response["Error"]["Code"] == "AccessDeniedException" - and "You don't have access to the model" in error_message - ): - e.add_note( - "└ For more information see " - "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue" - ) - - if ( - e.response["Error"]["Code"] == "ValidationException" - and "with on-demand throughput isn’t supported" in error_message - ): - e.add_note( - "└ For more information see " - "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported" - ) + # Aid in debugging by adding more information + add_exception_note(e, f"└ Bedrock region: {region}") + add_exception_note(e, f"└ Model id: {self.config.get('model_id')}") + + if ( + e.response["Error"]["Code"] == "AccessDeniedException" + and "You don't have access to the model" in error_message + ): + add_exception_note( + e, + "└ For more information see " + "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue", + ) + + if ( + e.response["Error"]["Code"] == "ValidationException" + and "with on-demand throughput isn’t supported" in error_message + ): + add_exception_note( + e, + "└ For more information see " + 
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported", + ) raise e diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 96fee67fa..f6251943d 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,5 +1,6 @@ import os import sys +import traceback import unittest.mock from unittest.mock import ANY @@ -10,6 +11,7 @@ from botocore.exceptions import ClientError, EventStreamError import strands +from strands import _exception_notes from strands.models import BedrockModel from strands.models.bedrock import ( _DEFAULT_BEDROCK_MODEL_ID, @@ -1209,6 +1211,23 @@ async def test_add_note_on_client_error(bedrock_client, model, alist, messages): assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] +@pytest.mark.asyncio +async def test_add_note_on_client_error_without_add_notes(bedrock_client, model, alist, messages): + """Test that when add_note is not used, the region & model are still included in the error output.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + # Mock the client error response + error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream") + + # Call the stream method which should catch and add notes to the exception + with pytest.raises(ClientError) as err: + await alist(model.stream(messages)) + + error_str = "".join(traceback.format_exception(err.value)) + assert "└ Bedrock region: us-west-2" in error_str + assert "└ Model id: m1" in error_str + + @pytest.mark.asyncio async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" diff --git a/tests/strands/test_exception_notes.py b/tests/strands/test_exception_notes.py new file mode 100644 index 000000000..936cf0848 --- /dev/null +++ b/tests/strands/test_exception_notes.py @@ -0,0 +1,51 @@ +"""Tests for exception note utilities.""" + +import sys +import traceback +import unittest.mock + +import pytest + +from strands import _exception_notes +from strands._exception_notes import add_exception_note + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") +def test_add_exception_note_python_311_plus(): + """Test add_exception_note uses add_note in Python 3.11+.""" + exception = ValueError("original message") + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: original message\n", "test note\n"] + + +def test_add_exception_note_python_310(): + """Test add_exception_note modifies args in Python 3.10.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + exception = ValueError("original message") + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: original message\ntest note\n"] + + +def test_add_exception_note_python_310_no_args(): + """Test add_exception_note handles exception with no args in Python 3.10.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + exception = ValueError() + exception.args = () + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: 
test note\n"] + + +def test_add_exception_note_python_310_multiple_args(): + """Test add_exception_note preserves additional args in Python 3.10.""" + with unittest.mock.patch.object(_exception_notes, "supports_add_note", False): + exception = ValueError("original message", "second arg") + + add_exception_note(exception, "test note") + + assert traceback.format_exception(exception) == ["ValueError: ('original message\\ntest note', 'second arg')\n"] From 26862e4741af92f580371828cec2ab516195a139 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 17 Oct 2025 14:21:56 -0400 Subject: [PATCH 17/26] interrupts - decorated tools (#1041) --- src/strands/hooks/events.py | 6 +- src/strands/tools/decorator.py | 7 +- src/strands/tools/executors/_executor.py | 7 +- src/strands/types/interrupt.py | 4 +- src/strands/types/tools.py | 15 +- tests/strands/agent/test_agent.py | 2 +- tests/strands/event_loop/test_event_loop.py | 8 +- tests/strands/hooks/test_registry.py | 4 +- tests/strands/tools/executors/conftest.py | 17 +- .../tools/executors/test_concurrent.py | 2 +- .../strands/tools/executors/test_executor.py | 60 ++++++- .../tools/executors/test_sequential.py | 2 +- tests/strands/tools/test_decorator.py | 65 ++++++- tests/strands/types/test_interrupt.py | 4 +- tests_integ/interrupts/__init__.py | 0 .../test_hook.py} | 35 +--- tests_integ/interrupts/test_session.py | 79 +++++++++ tests_integ/interrupts/test_tool.py | 163 ++++++++++++++++++ 18 files changed, 419 insertions(+), 61 deletions(-) create mode 100644 tests_integ/interrupts/__init__.py rename tests_integ/{test_interrupt.py => interrupts/test_hook.py} (74%) create mode 100644 tests_integ/interrupts/test_session.py create mode 100644 tests_integ/interrupts/test_tool.py diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index de07002c5..05be255f6 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -10,7 +10,7 @@ from typing_extensions import override from ..types.content import Message -from ..types.interrupt import InterruptHookEvent +from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -88,7 +88,7 @@ class MessageAddedEvent(HookEvent): @dataclass -class BeforeToolCallEvent(HookEvent, InterruptHookEvent): +class BeforeToolCallEvent(HookEvent, _Interruptible): """Event triggered before a tool is invoked. This event is fired just before the agent executes a tool, allowing hook @@ -124,7 +124,7 @@ def _interrupt_id(self, name: str) -> str: Returns: Interrupt id. 
""" - return f"v1:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + return f"v1:before_tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" @dataclass diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 72109dbef..5c49f4b58 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -62,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types._events import ToolResultEvent, ToolStreamEvent +from ..interrupt import InterruptException +from ..types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -493,6 +494,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore yield self._wrap_tool_result(tool_use_id, result) + except InterruptException as e: + yield ToolInterruptEvent(tool_use, [e.interrupt]) + return + except ValueError as e: # Special handling for validation errors error_msg = str(e) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index a4f43b149..44c2dc36a 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -163,11 +163,16 @@ async def _stream( # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in # ToolStreamEvent and the last event is just the result. + if isinstance(event, ToolInterruptEvent): + yield event + return + if isinstance(event, ToolResultEvent): # below the last "event" must point to the tool_result event = event.tool_result break - elif isinstance(event, ToolStreamEvent): + + if isinstance(event, ToolStreamEvent): yield event else: yield ToolStreamEvent(tool_use, event) diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py index 4e9584a70..2968ed219 100644 --- a/src/strands/types/interrupt.py +++ b/src/strands/types/interrupt.py @@ -118,8 +118,8 @@ def approve(self, event: BeforeToolCallEvent) -> None: from ..agent import Agent -class InterruptHookEvent(Protocol): - """Interface that adds interrupt support to hook events.""" +class _Interruptible(Protocol): + """Interface that adds interrupt support to hook events and tools.""" agent: "Agent" diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 18c7013ee..8343647b2 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -5,12 +5,14 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union from typing_extensions import NotRequired, TypedDict +from .interrupt import _Interruptible from .media import DocumentContent, ImageContent if TYPE_CHECKING: @@ -126,7 +128,7 @@ class ToolChoiceTool(TypedDict): @dataclass -class ToolContext: +class ToolContext(_Interruptible): """Context object containing framework-provided data for decorated tools. 
This object provides access to framework-level information that may be useful @@ -148,6 +150,17 @@ class ToolContext: agent: "Agent" invocation_state: dict[str, Any] + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:tool_call:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + # Individual ToolChoice type aliases ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ae2d8c7b5..b58e5f3fd 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1957,7 +1957,7 @@ def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): ) interrupt = Interrupt( - id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", ) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 89ef477fa..0a694bf1d 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -884,7 +884,7 @@ def interrupt_callback(event): exp_stop_reason = "interrupt" exp_interrupts = [ Interrupt( - id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", ), @@ -911,8 +911,8 @@ def interrupt_callback(event): }, }, "interrupts": { - "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9": { - "id": "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + "v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9": { + "id": "v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", "name": "test_name", "reason": "test reason", "response": None, @@ -925,7 +925,7 @@ def interrupt_callback(event): @pytest.mark.asyncio async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_2, agenerator, alist): interrupt = Interrupt( - id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:t1:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", response="test response", diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 807011869..6918bd2ee 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -38,12 +38,12 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): _, tru_interrupts = registry.invoke_callbacks(event) exp_interrupts = [ Interrupt( - id="v1:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", + id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", name="test_name_1", reason="test reason 1", ), Interrupt( - id="v1:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4", + id="v1:before_tool_call:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4", name="test_name_2", reason="test reason 2", ), diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index fa8ce10af..d25cf14bd 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -7,6 +7,7 @@ from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.tools.registry import ToolRegistry +from strands.types.tools import ToolContext @pytest.fixture @@ -79,12 +80,22 @@ def 
func(): @pytest.fixture -def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool): +def interrupt_tool(): + @strands.tool(name="interrupt_tool", context=True) + def func(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_name", reason="test reason") + + return func + + +@pytest.fixture +def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool): registry = ToolRegistry() registry.register_tool(weather_tool) registry.register_tool(temperature_tool) registry.register_tool(exception_tool) registry.register_tool(thread_tool) + registry.register_tool(interrupt_tool) return registry @@ -113,5 +124,5 @@ def cycle_span(): @pytest.fixture -def invocation_state(): - return {} +def invocation_state(agent): + return {"agent": agent} diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 7264c8e58..4b62a8a9a 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -38,7 +38,7 @@ async def test_concurrent_executor_interrupt( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): interrupt = Interrupt( - id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", ) diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index fd15c9747..a11e2eab2 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -342,11 +342,11 @@ async def test_executor_stream_no_span_attributes_when_no_tool_spec( @pytest.mark.asyncio -async def test_executor_stream_interrupt(executor, agent, tool_results, invocation_state, alist): +async def test_executor_stream_hook_interrupt(executor, agent, tool_results, invocation_state, alist): tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} interrupt = Interrupt( - id="v1:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", ) @@ -368,11 +368,11 @@ def interrupt_callback(event): @pytest.mark.asyncio -async def test_executor_stream_interrupt_resume(executor, agent, tool_results, invocation_state, alist): +async def test_executor_stream_hook_interrupt_resume(executor, agent, tool_results, invocation_state, alist): tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} interrupt = Interrupt( - id="v1:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", response="test response", @@ -407,3 +407,55 @@ def interrupt_callback(event): tru_response = interrupt_response["response"] exp_response = "test response" assert tru_response == exp_response + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, 
[interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "test response"}], + }, + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index c1db3cd55..a6c2c2277 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -38,7 +38,7 @@ async def test_sequential_executor_interrupt( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): interrupt = Interrupt( - id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + id="v1:before_tool_call:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", name="test_name", reason="test reason", ) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 658a34052..25f9bc39e 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -10,7 +10,9 @@ import strands from strands import Agent -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt +from strands.types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -138,6 +140,67 @@ def identity(a: int, agent: dict = None): assert tru_events == exp_events +@pytest.mark.asyncio +async def test_stream_interrupt(alist): + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + tool_use = {"toolUseId": "test_tool_id"} + + mock_agent = MagicMock() + mock_agent._interrupt_state = InterruptState() + + invocation_state = {"agent": mock_agent} + + @strands.tool(context=True) + def interrupt_tool(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_name", reason="test reason") + + stream = interrupt_tool.stream(tool_use, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, [interrupt])] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume(alist): + interrupt = Interrupt( + id="v1:tool_call:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + + tool_use = {"toolUseId": "test_tool_id"} + + mock_agent = MagicMock() + mock_agent._interrupt_state = InterruptState(interrupts={interrupt.id: interrupt}) + + invocation_state = {"agent": mock_agent} + + @strands.tool(context=True) + def interrupt_tool(tool_context: ToolContext) -> str: 
+ return tool_context.interrupt("test_name", reason="test reason") + + stream = interrupt_tool.stream(tool_use, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "test response"}], + }, + ), + ] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_basic_tool_creation(alist): """Test basic tool decorator functionality.""" diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py index 3b970a00a..ade0fa5e8 100644 --- a/tests/strands/types/test_interrupt.py +++ b/tests/strands/types/test_interrupt.py @@ -4,7 +4,7 @@ from strands.agent.interrupt import InterruptState from strands.interrupt import Interrupt, InterruptException -from strands.types.interrupt import InterruptHookEvent +from strands.types.interrupt import _Interruptible @pytest.fixture @@ -26,7 +26,7 @@ def agent(): @pytest.fixture def interrupt_hook_event(agent): - class Event(InterruptHookEvent): + class Event(_Interruptible): def __init__(self): self.agent = agent diff --git a/tests_integ/interrupts/__init__.py b/tests_integ/interrupts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/test_interrupt.py b/tests_integ/interrupts/test_hook.py similarity index 74% rename from tests_integ/test_interrupt.py rename to tests_integ/interrupts/test_hook.py index 164dfdede..836d7d415 100644 --- a/tests_integ/test_interrupt.py +++ b/tests_integ/interrupts/test_hook.py @@ -6,7 +6,6 @@ from strands import Agent, tool from strands.hooks import BeforeToolCallEvent, HookProvider from strands.interrupt import Interrupt -from strands.session import FileSessionManager @pytest.fixture @@ -19,7 +18,7 @@ def interrupt(self, event): if event.tool_use["name"] == "weather_tool": return - response = event.interrupt("test_interrupt", "need approval") + response = event.interrupt("test_interrupt", reason="need approval") if response != "APPROVE": event.cancel_tool = "tool rejected" @@ -158,35 +157,3 @@ def test_interrupt_reject(agent): ], } assert tru_tool_result_message == exp_tool_result_message - - -@pytest.mark.asyncio -def test_interrupt_session(interrupt_hook, time_tool, weather_tool, tmpdir): - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) - agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) - result = agent("What is the time and weather?") - - tru_stop_reason = result.stop_reason - exp_stop_reason = "interrupt" - assert tru_stop_reason == exp_stop_reason - - interrupt = result.interrupts[0] - - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) - agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) - responses = [ - { - "interruptResponse": { - "interruptId": interrupt.id, - "response": "APPROVE", - }, - }, - ] - result = agent(responses) - - tru_stop_reason = result.stop_reason - exp_stop_reason = "end_turn" - assert tru_stop_reason == exp_stop_reason - - result_message = json.dumps(result.message).lower() - assert all(string in result_message for string in ["12:00", "sunny"]) diff --git a/tests_integ/interrupts/test_session.py b/tests_integ/interrupts/test_session.py new file mode 100644 index 000000000..83d2cc73d --- /dev/null +++ b/tests_integ/interrupts/test_session.py @@ -0,0 +1,79 @@ +import json + +import pytest + +from strands import Agent, 
tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.session import FileSessionManager + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] == "weather_tool": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, weather_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) + + +@pytest.mark.asyncio +def test_interrupt_session(interrupt_hook, time_tool, weather_tool, tmpdir): + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"]) diff --git a/tests_integ/interrupts/test_tool.py b/tests_integ/interrupts/test_tool.py new file mode 100644 index 000000000..00dbfcc90 --- /dev/null +++ b/tests_integ/interrupts/test_tool.py @@ -0,0 +1,163 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.interrupt import Interrupt +from strands.types.tools import ToolContext + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] != "time_tool": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool", context=True) + def func(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_interrupt", reason="need time") + + return func + + +@pytest.fixture +def day_tool(): + @tool(name="day_tool", context=True) + def func(tool_context: ToolContext) -> str: + return tool_context.interrupt("test_interrupt", reason="need day") + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func() -> str: + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, day_tool, weather_tool): + return Agent(hooks=[interrupt_hook], 
tools=[time_tool, day_tool, weather_tool]) + + +@pytest.mark.asyncio +def test_interrupt(agent): + result = agent("What is the time, day, and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = sorted(result.interrupts, key=lambda interrupt: interrupt.reason) + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="test_interrupt", + reason="need day", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt_approval, interrupt_day = result.interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt_approval.id, + "response": "APPROVE", + }, + }, + { + "interruptResponse": { + "interruptId": interrupt_day.id, + "response": "monday", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need time", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt_time = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt_time.id, + "response": "12:01", + }, + }, + ] + result = agent(responses) + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:01", "monday", "sunny"]) + + tru_tool_results = agent.messages[-2]["content"] + tru_tool_results.sort(key=lambda content: content["toolResult"]["content"][0]["text"]) + + exp_tool_results = [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "12:01"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "monday"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "sunny"}, + ], + }, + }, + ] + assert tru_tool_results == exp_tool_results From 3a7af77c4c0bfe7538a8c2a02825186a54620938 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 21 Oct 2025 15:11:43 -0400 Subject: [PATCH 18/26] models - litellm - start and stop reasoning (#947) --- src/strands/models/litellm.py | 46 +++++++++++++---- tests/strands/models/test_litellm.py | 63 +++++++++++++++++++----- tests_integ/models/test_model_litellm.py | 16 ++++++ 3 files changed, 104 insertions(+), 21 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 486f67bf8..f1cbf01a2 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -111,6 +111,26 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] return super().format_request_message_content(content) + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: + """Handle switching to a new content stream. + + Args: + data_type: The next content data type. + prev_data_type: The previous content data type. + + Returns: + Tuple containing: + - Stop block for previous content and the start block for the next content. + - Next content data type. 
+ """ + chunks = [] + if data_type != prev_data_type: + if prev_data_type is not None: + chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type})) + chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": data_type})) + + return chunks, data_type + @override async def stream( self, @@ -146,9 +166,9 @@ async def stream( logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) - yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None async for event in response: # Defensive: skip events with empty or missing choices @@ -156,28 +176,36 @@ async def stream( continue choice = event.choices[0] - if choice.delta.content: - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - ) - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield chunk + yield self.format_chunk( { "chunk_type": "content_delta", - "data_type": "reasoning_content", + "data_type": data_type, "data": choice.delta.reasoning_content, } ) + if choice.delta.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield chunk + + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content} + ) + for tool_call in choice.delta.tool_calls or []: tool_calls.setdefault(tool_call.index, []).append(tool_call) if choice.finish_reason: + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) break - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - for tool_deltas in tool_calls.values(): yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 82023cae3..3a427f759 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -142,39 +142,71 @@ def test_format_request_message_content(content, exp_result): @pytest.mark.asyncio async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) mock_delta_1 = unittest.mock.Mock( reasoning_content="", content=None, tool_calls=None, ) + mock_delta_2 = unittest.mock.Mock( reasoning_content="\nI'm thinking", content=None, tool_calls=None, ) mock_delta_3 = unittest.mock.Mock( + reasoning_content=None, + content="One second", + tool_calls=None, + ) + mock_delta_4 = unittest.mock.Mock( + reasoning_content="\nI'm think", + content=None, + tool_calls=None, + ) + mock_delta_5 = unittest.mock.Mock( + reasoning_content="ing again", + content=None, + tool_calls=None, + ) + + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_6 = unittest.mock.Mock( content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None ) mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_4 = unittest.mock.Mock( + mock_delta_7 = unittest.mock.Mock( content="that for you", 
tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None ) - mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) + mock_delta_8 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)]) mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)]) - mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)]) - mock_event_6 = unittest.mock.Mock() + mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_5)]) + mock_event_6 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_6)]) + mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)]) + mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)]) + mock_event_9 = unittest.mock.Mock() litellm_acompletion.side_effect = unittest.mock.AsyncMock( - return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]) + return_value=agenerator( + [ + mock_event_1, + mock_event_2, + mock_event_3, + mock_event_4, + mock_event_5, + mock_event_6, + mock_event_7, + mock_event_8, + mock_event_9, + ] + ) ) messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}] @@ -184,6 +216,15 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm thinking"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "One second"}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm think"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "ing again"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, {"contentBlockDelta": {"delta": {"text": "I'll calculate"}}}, {"contentBlockDelta": {"delta": {"text": "that for you"}}}, {"contentBlockStop": {}}, @@ -211,9 +252,9 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, { "metadata": { "usage": { - "inputTokens": mock_event_6.usage.prompt_tokens, - "outputTokens": mock_event_6.usage.completion_tokens, - "totalTokens": mock_event_6.usage.total_tokens, + "inputTokens": mock_event_9.usage.prompt_tokens, + "outputTokens": mock_event_9.usage.completion_tokens, + "totalTokens": mock_event_9.usage.total_tokens, }, "metrics": {"latencyMs": 0}, } @@ -253,8 +294,6 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene tru_events = await alist(response) exp_events = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, ] diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index c5a09e3e9..b348c29f4 100644 --- a/tests_integ/models/test_model_litellm.py +++ 
b/tests_integ/models/test_model_litellm.py @@ -121,6 +121,22 @@ async def test_agent_stream_async(agent): assert all(string in text for string in ["12:00", "sunny"]) +def test_agent_invoke_reasoning(agent, model): + model.update_config( + params={ + "thinking": { + "budget_tokens": 1024, + "type": "enabled", + }, + }, + ) + + result = agent("Please reason about the equation 2+2.") + + assert "reasoningContent" in result.message["content"][0] + assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] + + def test_structured_output(agent, weather): tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather From b69478b9c16703702a3e163c662d4930128aed21 Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:18:25 -0400 Subject: [PATCH 19/26] feat: add experimental AgentConfig with comprehensive tool management (#935) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add experimental AgentConfig with comprehensive tool management - Add AgentConfig class for declarative agent configuration via JSON/dict - Support file:// prefix for loading configurations from JSON files - Implement ToolRegistry integration with automatic default tool loading - Add raise_exception_on_missing_tool parameter for flexible error handling - Support tool selection from registry via tool names in config - Add comprehensive test coverage for all configuration scenarios - Move hook events from experimental to production with updated names - Add OpenAI model provider enhancements and Gemini model improvements - Update event loop and tool executors to use production hook events 🤖 Assisted by Amazon Q Developer * fix: remove AgentConfig import from experimental/__init__.py - Reset experimental/__init__.py to not import AgentConfig by default - This may resolve import issues in CI environments - AgentConfig can still be imported directly from strands.experimental.agent_config 🤖 Assisted by Amazon Q Developer * fix: remove strands-agents-tools test dependency - Reset pyproject.toml to not include strands-agents-tools as test dependency - Tests handle missing strands_tools gracefully with mocking - This should resolve CI dependency issues 🤖 Assisted by Amazon Q Developer * test: remove test that depends on strands_tools availability - Remove test_agent_config_loads_from_default_tools_without_tool_registry - This test assumes strands_tools is available which causes CI failures - Other tests adequately cover AgentConfig functionality 🤖 Assisted by Amazon Q Developer * test: add back tests with proper mocking for strands_tools - Add back test_agent_config_tools_without_tool_registry_error with mocking - Add back test_agent_config_loads_from_default_tools_without_tool_registry with mocking - Mock _create_default_tool_registry to avoid dependency on strands_tools - Add tool import for creating mock tools in tests - All 15 tests now pass without external dependencies 🤖 Assisted by Amazon Q Developer * test: fix Windows compatibility for file prefix test - Use platform-specific tempfile handling in test_agent_config_file_prefix_valid - Use mkstemp() with explicit cleanup on Windows for better permission handling - Keep NamedTemporaryFile on non-Windows platforms for simplicity - Should resolve permission errors on Windows GitHub runners 🤖 Assisted by Amazon Q Developer * refactor: replace AgentConfig class with config_to_agent function BREAKING 
CHANGE: Replace class-based AgentConfig with function-based config_to_agent - Replace AgentConfig class with config_to_agent function for simpler interface - Remove ToolRegistry dependency - let Agent handle tool loading internally - Remove DEFAULT_TOOLS concept and raise_exception_on_missing_tool parameter - Support both file paths and dictionary inputs with file:// prefix handling - Only pass non-None config values to Agent constructor (use Agent defaults) - Update experimental module exports to expose config_to_agent function - Rewrite all tests to use new function-based interface - Simplify tool handling by delegating to Agent class New interface: from strands.experimental import config_to_agent agent = config_to_agent('/path/to/config.json') Previous interface (removed): from strands.experimental.agent_config import AgentConfig config = AgentConfig('/path/to/config.json') agent = config.to_agent() 🤖 Assisted by Amazon Q Developer * feat: limit config_to_agent to core configuration keys - Remove support for advanced Agent parameters in config_to_agent - Only support: model, prompt, tools, name in configuration - Advanced parameters can still be passed via kwargs - Remove agent_id test and update function mapping - Keep interface simple and focused on basic agent configuration 🤖 Assisted by Amazon Q Developer * fix: use native Python typing instead of typing module - Replace Union[str, Dict[str, Any]] with str | dict[str, any] - Remove typing module imports - Use modern Python 3.10+ native typing syntax 🤖 Assisted by Amazon Q Developer * test: simplify file prefix test with proper context manager - Use NamedTemporaryFile with delete=True for automatic cleanup - Remove manual os.unlink call and try/finally block - Keep file operation within single context manager scope - Add f.flush() to ensure data is written before reading 🤖 Assisted by Amazon Q Developer * feat: add JSON schema validation to config_to_agent - Add jsonschema dependency for configuration validation - Implement JSON schema based on supported configuration keys - Provide detailed validation error messages with field paths - Add validation tests for invalid fields, types, and tool items - Support null values for optional fields (model, prompt, name) - Reject additional properties not in the schema - All 14 tests passing including new validation tests 🤖 Assisted by Amazon Q Developer * refactor: move JSON schema to separate file - Extract agent configuration schema to schemas/agent-config-v1.json - Add _load_schema() function to load schema from file at runtime - Improve code readability by separating schema from Python logic - Enable schema reuse by other tools and documentation - Maintain all existing validation functionality and tests 🤖 Assisted by Amazon Q Developer * perf: use pre-compiled JSON schema validator - Create Draft7Validator instance at module level for better performance - Avoid loading and compiling schema on every validation call - Schema is loaded once at import time and validator is reused - Maintains all existing validation functionality and error messages - Standard best practice for jsonschema validation performance 🤖 Assisted by Amazon Q Developer * feat: add tool validation and clarify limitations - Move JSON schema back to inline variable for simplicity - Add comprehensive tool validation with helpful error messages - Validate tools can be loaded as files, modules, or @tool functions - Add clear documentation about code-based instantiation limitations - Update module docstring and function 
comments with usage patterns - Add test for tool validation error messages - Remove schemas directory (no longer needed) 🤖 Assisted by Amazon Q Developer * fix: improve tool validation error messages and add comprehensive tests - Fix error message for missing modules to be more descriptive - Remove redundant 'to properly import this tool' text from error messages - Add specific error messages for missing modules vs missing functions - Add unit tests for each error case: - Invalid tool (not file/module/@tool) - Missing module (module doesn't exist) - Missing function (function not found in existing module) - All 17 tests passing with better error coverage 🤖 Assisted by Amazon Q Developer * fix: reference module instead of tool in error message - Change error message from 'Tool X not found' to 'Module X not found' - More accurate since we're trying to import it as a module at this point - Maintains existing test compatibility and error handling logic 🤖 Assisted by Amazon Q Developer * revert: change error message back to reference tool - Revert previous change from 'Module X not found' back to 'Tool X not found' - Keep original error message format as requested 🤖 Assisted by Amazon Q Developer * feat: use agent tool loading logic * fix: address pr comments --------- Co-authored-by: Matt Lee Co-authored-by: Nicholas Clegg --- .gitignore | 3 +- pyproject.toml | 1 + src/strands/experimental/__init__.py | 4 + src/strands/experimental/agent_config.py | 138 ++++++++++++++ .../strands/experimental/test_agent_config.py | 172 ++++++++++++++++++ tests_integ/fixtures/say_tool.py | 7 + tests_integ/fixtures/test_agent.json | 6 + tests_integ/test_agent_json.py | 13 ++ 8 files changed, 343 insertions(+), 1 deletion(-) create mode 100644 src/strands/experimental/agent_config.py create mode 100644 tests/strands/experimental/test_agent_config.py create mode 100644 tests_integ/fixtures/say_tool.py create mode 100644 tests_integ/fixtures/test_agent.json create mode 100644 tests_integ/test_agent_json.py diff --git a/.gitignore b/.gitignore index 888a96bbc..e92a233f8 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ __pycache__* .vscode dist repl_state -.kiro \ No newline at end of file +.kiro +uv.lock diff --git a/pyproject.toml b/pyproject.toml index af8e45ffc..b542c7481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", + "jsonschema>=4.0.0,<5.0.0", "mcp>=1.11.0,<2.0.0", "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index c40d0fcec..86618c153 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -2,3 +2,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ + +from .agent_config import config_to_agent + +__all__ = ["config_to_agent"] diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py new file mode 100644 index 000000000..d08f89cf9 --- /dev/null +++ b/src/strands/experimental/agent_config.py @@ -0,0 +1,138 @@ +"""Experimental agent configuration utilities. + +This module provides utilities for creating agents from configuration files or dictionaries. + +Note: Configuration-based agent setup only works for tools that don't require code-based +instantiation. 
For tools that need constructor arguments or complex setup, use the +programmatic approach after creating the agent: + + agent = config_to_agent("config.json") + # Add tools that need code-based instantiation + agent.tool_registry.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) +""" + +import json +from pathlib import Path +from typing import Any + +import jsonschema +from jsonschema import ValidationError + +from ..agent import Agent + +# JSON Schema for agent configuration +AGENT_CONFIG_SCHEMA = { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Agent Configuration", + "description": "Configuration schema for creating agents", + "type": "object", + "properties": { + "name": {"description": "Name of the agent", "type": ["string", "null"], "default": None}, + "model": { + "description": "The model ID to use for this agent. If not specified, uses the default model.", + "type": ["string", "null"], + "default": None, + }, + "prompt": { + "description": "The system prompt for the agent. Provides high level context to the agent.", + "type": ["string", "null"], + "default": None, + }, + "tools": { + "description": "List of tools the agent can use. Can be file paths, " + "Python module names, or @tool annotated functions in files.", + "type": "array", + "items": {"type": "string"}, + "default": [], + }, + }, + "additionalProperties": False, +} + +# Pre-compile validator for better performance +_VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) + + +def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Agent: + """Create an Agent from a configuration file or dictionary. + + This function supports tools that can be loaded declaratively (file paths, module names, + or @tool annotated functions). 
For tools requiring code-based instantiation with constructor + arguments, add them programmatically after creating the agent: + + agent = config_to_agent("config.json") + agent.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) + + Args: + config: Either a file path (with optional file:// prefix) or a configuration dictionary + **kwargs: Additional keyword arguments to pass to the Agent constructor + + Returns: + Agent: A configured Agent instance + + Raises: + FileNotFoundError: If the configuration file doesn't exist + json.JSONDecodeError: If the configuration file contains invalid JSON + ValueError: If the configuration is invalid or tools cannot be loaded + + Examples: + Create agent from file: + >>> agent = config_to_agent("/path/to/config.json") + + Create agent from file with file:// prefix: + >>> agent = config_to_agent("file:///path/to/config.json") + + Create agent from dictionary: + >>> config = {"model": "anthropic.claude-3-5-sonnet-20241022-v2:0", "tools": ["calculator"]} + >>> agent = config_to_agent(config) + """ + # Parse configuration + if isinstance(config, str): + # Handle file path + file_path = config + + # Remove file:// prefix if present + if file_path.startswith("file://"): + file_path = file_path[7:] + + # Load JSON from file + config_path = Path(file_path) + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {file_path}") + + with open(config_path, "r") as f: + config_dict = json.load(f) + elif isinstance(config, dict): + config_dict = config.copy() + else: + raise ValueError("Config must be a file path string or dictionary") + + # Validate configuration against schema + try: + _VALIDATOR.validate(config_dict) + except ValidationError as e: + # Provide more detailed error message + error_path = " -> ".join(str(p) for p in e.absolute_path) if e.absolute_path else "root" + raise ValueError(f"Configuration validation error at {error_path}: {e.message}") from e + + # Prepare Agent constructor arguments + agent_kwargs = {} + + # Map configuration keys to Agent constructor parameters + config_mapping = { + "model": "model", + "prompt": "system_prompt", + "tools": "tools", + "name": "name", + } + + # Only include non-None values from config + for config_key, agent_param in config_mapping.items(): + if config_key in config_dict and config_dict[config_key] is not None: + agent_kwargs[agent_param] = config_dict[config_key] + + # Override with any additional kwargs provided + agent_kwargs.update(kwargs) + + # Create and return Agent + return Agent(**agent_kwargs) diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py new file mode 100644 index 000000000..e6188079b --- /dev/null +++ b/tests/strands/experimental/test_agent_config.py @@ -0,0 +1,172 @@ +"""Tests for experimental config_to_agent function.""" + +import json +import os +import tempfile + +import pytest + +from strands.experimental import config_to_agent + + +def test_config_to_agent_with_dict(): + """Test config_to_agent can be created with dict config.""" + config = {"model": "test-model"} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "test-model" + + +def test_config_to_agent_with_system_prompt(): + """Test config_to_agent handles system prompt correctly.""" + config = {"model": "test-model", "prompt": "Test prompt"} + agent = config_to_agent(config) + assert agent.system_prompt == "Test prompt" + + +def test_config_to_agent_with_tools_list(): + """Test config_to_agent handles 
tools list without failing.""" + # Use a simple test that doesn't require actual tool loading + config = {"model": "test-model", "tools": []} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "test-model" + + +def test_config_to_agent_with_kwargs_override(): + """Test that kwargs can override config values.""" + config = {"model": "test-model", "prompt": "Config prompt"} + agent = config_to_agent(config, system_prompt="Override prompt") + assert agent.system_prompt == "Override prompt" + + +def test_config_to_agent_file_prefix_required(): + """Test that file paths without file:// prefix work.""" + import json + import tempfile + + config_data = {"model": "test-model"} + temp_path = "" + + # We need to create files like this for windows compatibility + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + agent = config_to_agent(temp_path) + assert agent.model.config["model_id"] == "test-model" + finally: + # Clean up the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_config_to_agent_file_prefix_valid(): + """Test that file:// prefix is properly handled.""" + config_data = {"model": "test-model", "prompt": "Test prompt"} + temp_path = "" + + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + agent = config_to_agent(f"file://{temp_path}") + assert agent.model.config["model_id"] == "test-model" + assert agent.system_prompt == "Test prompt" + finally: + # Clean up the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_config_to_agent_file_not_found(): + """Test that FileNotFoundError is raised for missing files.""" + with pytest.raises(FileNotFoundError, match="Configuration file not found"): + config_to_agent("/nonexistent/path/config.json") + + +def test_config_to_agent_invalid_json(): + """Test that JSONDecodeError is raised for invalid JSON.""" + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + f.write("invalid json content") + temp_path = f.name + + with pytest.raises(json.JSONDecodeError): + config_to_agent(temp_path) + finally: + # Clean up the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def test_config_to_agent_invalid_config_type(): + """Test that ValueError is raised for invalid config types.""" + with pytest.raises(ValueError, match="Config must be a file path string or dictionary"): + config_to_agent(123) + + +def test_config_to_agent_with_name(): + """Test config_to_agent handles agent name.""" + config = {"model": "test-model", "name": "TestAgent"} + agent = config_to_agent(config) + assert agent.name == "TestAgent" + + +def test_config_to_agent_ignores_none_values(): + """Test that None values in config are ignored.""" + config = {"model": "test-model", "prompt": None, "name": None} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "test-model" + # Agent should use its defaults for None values + + +def test_config_to_agent_validation_error_invalid_field(): + """Test that invalid fields raise validation errors.""" + config = {"model": "test-model", "invalid_field": "value"} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_wrong_type(): + """Test that wrong field types raise validation errors.""" + 
config = {"model": "test-model", "tools": "not-a-list"} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_invalid_tool_item(): + """Test that invalid tool items raise validation errors.""" + config = {"model": "test-model", "tools": ["valid-tool", 123]} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_invalid_tool(): + """Test that invalid tools raise helpful error messages.""" + config = {"model": "test-model", "tools": ["nonexistent_tool"]} + with pytest.raises(ValueError, match="Failed to load tool nonexistent_tool"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_missing_module(): + """Test that missing modules raise helpful error messages.""" + config = {"model": "test-model", "tools": ["nonexistent.module.tool"]} + with pytest.raises(ValueError, match="Failed to load tool nonexistent.module.tool"): + config_to_agent(config) + + +def test_config_to_agent_validation_error_missing_function(): + """Test that missing functions in existing modules raise helpful error messages.""" + config = {"model": "test-model", "tools": ["json.nonexistent_function"]} + with pytest.raises(ValueError, match="Failed to load tool json.nonexistent_function"): + config_to_agent(config) + + +def test_config_to_agent_with_tool(): + """Test that missing functions in existing modules raise helpful error messages.""" + config = {"model": "test-model", "tools": ["tests.fixtures.say_tool:say"]} + agent = config_to_agent(config) + assert "say" in agent.tool_names diff --git a/tests_integ/fixtures/say_tool.py b/tests_integ/fixtures/say_tool.py new file mode 100644 index 000000000..454f28240 --- /dev/null +++ b/tests_integ/fixtures/say_tool.py @@ -0,0 +1,7 @@ +from strands import tool + + +@tool +def say(input: str) -> str: + """Say the input""" + return f"Said: {input}" diff --git a/tests_integ/fixtures/test_agent.json b/tests_integ/fixtures/test_agent.json new file mode 100644 index 000000000..e1ffad249 --- /dev/null +++ b/tests_integ/fixtures/test_agent.json @@ -0,0 +1,6 @@ +{ + "model": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "tools": ["tests_integ.fixtures.say_tool:say"], + "prompt": "You use the say tool to communicate", + "name": "Sayer" +} \ No newline at end of file diff --git a/tests_integ/test_agent_json.py b/tests_integ/test_agent_json.py new file mode 100644 index 000000000..387cfd172 --- /dev/null +++ b/tests_integ/test_agent_json.py @@ -0,0 +1,13 @@ +from strands.experimental import config_to_agent + + +def test_load_agent_from_config(): + agent = config_to_agent("file://tests_integ/fixtures/test_agent.json") + + result = agent("Say hello") + + assert "Sayer" == agent.name + assert "You use the say tool to communicate" == agent.system_prompt + assert agent.tool_names[0] == "say" + assert agent.model.get_config().get("model_id") == "global.anthropic.claude-sonnet-4-5-20250929-v1:0" + assert "hello" in str(result).lower() From 78c59b95ffa2b50a8e1dc93e3cdd172772b0b791 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Tue, 21 Oct 2025 15:29:38 -0400 Subject: [PATCH 20/26] fix(telemetry): make strands agent invoke_agent span as INTERNAL spanKind (#1055) * fix(telemetry): make strands agent invoke_agent and chat span as INTERNAL spanKind --- src/strands/telemetry/tracer.py | 4 ++-- tests/strands/telemetry/test_tracer.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) 
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 907fd454a..9cefc6911 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -293,7 +293,7 @@ def start_model_invoke_span( # Add additional kwargs as attributes attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) - span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) + span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL) self._add_event_messages(span, messages) return span @@ -588,7 +588,7 @@ def start_agent_span( attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) span = self._start_span( - f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT + f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL ) self._add_event_messages(span, messages) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index de677c2cc..05dbe387f 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -153,7 +153,7 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" - assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.CLIENT + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) @@ -188,7 +188,7 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" - assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.CLIENT + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) @@ -670,6 +670,7 @@ def test_start_agent_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) From 8a89d91ec1b769d2d2752d61da8e583ac45d13c5 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Thu, 23 Oct 2025 02:35:51 +0800 Subject: [PATCH 21/26] feat: add multiagent hooks, add serialize & deserialize function to multiagent base & agent result (#1070) * feat: add multiagent hooks, add serialize & deserialize function to multiagent base & agent result * Delete __init__.py --- src/strands/agent/agent_result.py | 33 ++++- .../experimental/hooks/multiagent/__init__.py | 20 +++ .../experimental/hooks/multiagent/events.py | 93 ++++++++++++++ src/strands/multiagent/base.py | 114 ++++++++++++++++++ .../fixtures/mock_multiagent_hook_provider.py | 41 +++++++ 
tests/strands/agent/test_agent_result.py | 45 +++++++ .../experimental/hooks/multiagent/__init__.py | 0 .../hooks/multiagent/test_events.py | 107 ++++++++++++++++ tests/strands/multiagent/test_base.py | 65 ++++++++++ 9 files changed, 517 insertions(+), 1 deletion(-) create mode 100644 src/strands/experimental/hooks/multiagent/__init__.py create mode 100644 src/strands/experimental/hooks/multiagent/events.py create mode 100644 tests/fixtures/mock_multiagent_hook_provider.py create mode 100644 tests/strands/experimental/hooks/multiagent/__init__.py create mode 100644 tests/strands/experimental/hooks/multiagent/test_events.py diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index eb9bc4dd9..12c1f8376 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,7 +4,7 @@ """ from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any, Sequence, cast from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics @@ -46,3 +46,34 @@ def __str__(self) -> str: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentResult": + """Rehydrate an AgentResult from persisted JSON. + + Args: + data: Dictionary containing the serialized AgentResult data + Returns: + AgentResult instance + Raises: + TypeError: If the data format is invalid@ + """ + if data.get("type") != "agent_result": + raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") + + message = cast(Message, data.get("message")) + stop_reason = cast(StopReason, data.get("stop_reason")) + + return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + + def to_dict(self) -> dict[str, Any]: + """Convert this AgentResult to JSON-serializable dictionary. + + Returns: + Dictionary containing serialized AgentResult data + """ + return { + "type": "agent_result", + "message": self.message, + "stop_reason": self.stop_reason, + } diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..d059d0da5 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -0,0 +1,20 @@ +"""Multi-agent hook events and utilities. + +Provides event classes for hooking into multi-agent orchestrator lifecycle. +""" + +from .events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py new file mode 100644 index 000000000..9e54296a4 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -0,0 +1,93 @@ +"""Multi-agent execution lifecycle events for hook system integration. + +These events are fired by orchestrators (Graph/Swarm) at key points so +hooks can persist, monitor, or debug execution. No intermediate state model +is used—hooks read from the orchestrator directly. 
+""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from ....hooks import BaseHookEvent + +if TYPE_CHECKING: + from ....multiagent.base import MultiAgentBase + + +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator initialized. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent): + """Event triggered before individual node execution starts. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node about to execute + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. 
+ + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0dbd85d81..07e63577d 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -4,6 +4,7 @@ """ import asyncio +import logging import warnings from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -15,6 +16,8 @@ from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage +logger = logging.getLogger(__name__) + class Status(Enum): """Execution status for both graphs and nodes.""" @@ -59,6 +62,54 @@ def get_agent_results(self) -> list[AgentResult]: flattened.extend(nested_node_result.get_agent_results()) return flattened + def to_dict(self) -> dict[str, Any]: + """Convert NodeResult to JSON-serializable dict, ignoring state field.""" + if isinstance(self.result, Exception): + result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} + elif isinstance(self.result, AgentResult): + result_data = self.result.to_dict() + else: + # MultiAgentResult case + result_data = self.result.to_dict() + + return { + "result": result_data, + "execution_time": self.execution_time, + "status": self.status.value, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeResult": + """Rehydrate a NodeResult from persisted JSON.""" + if "result" not in data: + raise TypeError("NodeResult.from_dict: missing 'result'") + raw = data["result"] + + result: Union[AgentResult, "MultiAgentResult", Exception] + if isinstance(raw, dict) and raw.get("type") == "agent_result": + result = AgentResult.from_dict(raw) + elif isinstance(raw, dict) and raw.get("type") == "exception": + result = Exception(str(raw.get("message", "node failed"))) + elif isinstance(raw, dict) and raw.get("type") == "multiagent_result": + result = MultiAgentResult.from_dict(raw) + else: + raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") + + usage = _parse_usage(data.get("accumulated_usage", {})) + metrics = _parse_metrics(data.get("accumulated_metrics", {})) + + return cls( + result=result, + execution_time=int(data.get("execution_time", 0)), + status=Status(data.get("status", "pending")), + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + ) + @dataclass class MultiAgentResult: @@ -76,6 +127,38 @@ class MultiAgentResult: execution_count: int = 0 execution_time: int = 0 + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": + """Rehydrate a MultiAgentResult from persisted JSON.""" + if data.get("type") != "multiagent_result": + raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}") + + results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} + usage = _parse_usage(data.get("accumulated_usage", {})) + metrics = _parse_metrics(data.get("accumulated_metrics", {})) + + multiagent_result = cls( + status=Status(data.get("status", Status.PENDING.value)), + results=results, + accumulated_usage=usage, + 
accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + execution_time=int(data.get("execution_time", 0)), + ) + return multiagent_result + + def to_dict(self) -> dict[str, Any]: + """Convert MultiAgentResult to JSON-serializable dict.""" + return { + "type": "multiagent_result", + "status": self.status.value, + "results": {k: v.to_dict() for k, v in self.results.items()}, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + "execution_time": self.execution_time, + } + class MultiAgentBase(ABC): """Base class for multi-agent helpers. @@ -122,3 +205,34 @@ def execute() -> MultiAgentResult: with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() + + def serialize_state(self) -> dict[str, Any]: + """Return a JSON-serializable snapshot of the orchestrator state.""" + raise NotImplementedError + + def deserialize_state(self, payload: dict[str, Any]) -> None: + """Restore orchestrator state from a session dict.""" + raise NotImplementedError + + +# Private helper function to avoid duplicate code + + +def _parse_usage(usage_data: dict[str, Any]) -> Usage: + """Parse Usage from dict data.""" + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + return usage + + +def _parse_metrics(metrics_data: dict[str, Any]) -> Metrics: + """Parse Metrics from dict data.""" + return Metrics(latencyMs=metrics_data.get("latencyMs", 0)) diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py new file mode 100644 index 000000000..727d28a48 --- /dev/null +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -0,0 +1,41 @@ +from typing import Iterator, Literal, Tuple, Type + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import ( + HookEvent, + HookProvider, + HookRegistry, +) + + +class MockMultiAgentHookProvider(HookProvider): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + MultiAgentInitializedEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, + AfterMultiAgentInvocationEvent, + ] + + self.events_received = [] + self.events_types = event_types + + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + return len(self.events_received), iter(self.events_received) + + def register_hooks(self, registry: HookRegistry) -> None: + for event_type in self.events_types: + registry.add_callback(event_type, self.add_event) + + def add_event(self, event: HookEvent) -> None: + self.events_received.append(event) diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 409b08a2d..67a7f2458 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -95,3 +95,48 @@ def test__str__non_dict_content(mock_metrics): message_string = str(result) assert 
message_string == "Valid text\nMore valid text\n" + + +def test_to_dict(mock_metrics, simple_message: Message): + """Test that to_dict serializes AgentResult correctly.""" + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={"key": "value"}) + + data = result.to_dict() + + assert data == { + "type": "agent_result", + "message": simple_message, + "stop_reason": "end_turn", + } + + +def test_from_dict(): + """Test that from_dict works with valid data.""" + data = { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + result = AgentResult.from_dict(data) + + assert result.message == data["message"] + assert result.stop_reason == data["stop_reason"] + assert isinstance(result.metrics, EventLoopMetrics) + assert result.state == {} + + +def test_roundtrip_serialization(mock_metrics, complex_message: Message): + """Test that to_dict() and from_dict() work together correctly.""" + original = AgentResult( + stop_reason="max_tokens", message=complex_message, metrics=mock_metrics, state={"test": "data"} + ) + + # Serialize and deserialize + data = original.to_dict() + restored = AgentResult.from_dict(data) + + assert restored.message == original.message + assert restored.stop_reason == original.stop_reason + assert isinstance(restored.metrics, EventLoopMetrics) + assert restored.state == {} # State is not serialized diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/multiagent/test_events.py b/tests/strands/experimental/hooks/multiagent/test_events.py new file mode 100644 index 000000000..6c4d7c4e7 --- /dev/null +++ b/tests/strands/experimental/hooks/multiagent/test_events.py @@ -0,0 +1,107 @@ +"""Tests for multi-agent execution lifecycle events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import BaseHookEvent + + +@pytest.fixture +def orchestrator(): + """Mock orchestrator for testing.""" + return Mock() + + +def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): + """Test MultiAgentInitializedEvent creation with orchestrator only.""" + event = MultiAgentInitializedEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_multi_agent_initialization_event_with_invocation_state(orchestrator): + """Test MultiAgentInitializedEvent creation with invocation state.""" + invocation_state = {"key": "value"} + event = MultiAgentInitializedEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_after_node_invocation_event_with_required_fields(orchestrator): + """Test AfterNodeCallEvent creation with required fields.""" + node_id = "node_1" + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_node_invocation_event_with_invocation_state(orchestrator): + 
"""Test AfterNodeCallEvent creation with invocation state.""" + node_id = "node_2" + invocation_state = {"result": "success"} + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state == invocation_state + + +def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with orchestrator only.""" + event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_multi_agent_invocation_event_with_invocation_state(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with invocation state.""" + invocation_state = {"final_state": "completed"} + event = AfterMultiAgentInvocationEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_before_node_call_event(orchestrator): + """Test BeforeNodeCallEvent creation.""" + node_id = "node_1" + event = BeforeNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_before_multi_agent_invocation_event(orchestrator): + """Test BeforeMultiAgentInvocationEvent creation.""" + event = BeforeMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_events_should_reverse_callbacks(orchestrator): + """Test that After events have should_reverse_callbacks property set to True.""" + after_node_event = AfterNodeCallEvent(source=orchestrator, node_id="test") + after_invocation_event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert after_node_event.should_reverse_callbacks is True + assert after_invocation_event.should_reverse_callbacks is True diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index ab55b2c84..4e8a5dd06 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -28,6 +28,9 @@ def test_node_result_initialization_and_properties(agent_result): assert node_result.accumulated_metrics == {"latencyMs": 0.0} assert node_result.execution_count == 0 + default_node = NodeResult(result=agent_result) + assert default_node.status == Status.PENDING + # With custom metrics custom_usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300} custom_metrics = {"latencyMs": 250.0} @@ -95,6 +98,7 @@ def test_multi_agent_result_initialization(agent_result): assert result.accumulated_metrics == {"latencyMs": 0.0} assert result.execution_count == 0 assert result.execution_time == 0 + assert result.status == Status.PENDING # Custom values`` node_result = NodeResult(result=agent_result) @@ -141,6 +145,12 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) @@ -164,6 +174,12 @@ 
async def invoke_async(self, task, invocation_state, **kwargs): status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + agent = TestMultiAgent() # Test with string task @@ -174,3 +190,52 @@ async def invoke_async(self, task, invocation_state, **kwargs): assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"} assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED + + +def test_node_result_to_dict(agent_result): + """Test NodeResult to_dict method.""" + node_result = NodeResult(result=agent_result, execution_time=100, status=Status.COMPLETED) + result_dict = node_result.to_dict() + + assert result_dict["execution_time"] == 100 + assert result_dict["status"] == "completed" + assert result_dict["result"]["type"] == "agent_result" + assert result_dict["result"]["stop_reason"] == agent_result.stop_reason + assert result_dict["result"]["message"] == agent_result.message + + exception_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + result_dict = exception_result.to_dict() + + assert result_dict["result"]["type"] == "exception" + assert result_dict["result"]["message"] == "Test error" + assert result_dict["status"] == "failed" + + +def test_multi_agent_result_to_dict(agent_result): + """Test MultiAgentResult to_dict method.""" + node_result = NodeResult(result=agent_result) + multi_result = MultiAgentResult(status=Status.COMPLETED, results={"test_node": node_result}, execution_time=200) + + result_dict = multi_result.to_dict() + + assert result_dict["status"] == "completed" + assert result_dict["execution_time"] == 200 + assert "test_node" in result_dict["results"] + assert result_dict["results"]["test_node"]["result"]["type"] == "agent_result" + + +def test_serialize_node_result_for_persist(agent_result): + """Test serialize_node_result_for_persist method.""" + + node_result = NodeResult(result=agent_result) + serialized = node_result.to_dict() + + assert "result" in serialized + assert "execution_time" in serialized + assert "status" in serialized + + exception_node_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + serialized_exception = exception_node_result.to_dict() + assert "result" in serialized_exception + assert serialized_exception["result"]["type"] == "exception" + assert serialized_exception["result"]["message"] == "Test error" From 648af228aed534b7fee46d7c0bb485fd4b2fb520 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Wed, 22 Oct 2025 16:47:47 -0400 Subject: [PATCH 22/26] feat: Add Structured Output as part of the agent loop (#943) feat: Add Structured Output as part of the agent loop (#943) Add comprehensive structured output functionality allowing agents to return Pydantic models in the AgentResult. Includes support for validation, retry logic, streaming, and async operations. 
- Add structured_output_model parameter to Agent constructor and invocation methods - Implement StructuredOutputTool for handling Pydantic model validation - Add structured output context management and retry mechanisms - Extend event system with StructuredOutputEvent and reasoning events - Add structured_output field to AgentResult for accessing parsed models - Support structured output in streaming and async operations - Add comprehensive test coverage for all structured output scenarios - Add integration tests for real-world usage patterns --- src/strands/__init__.py | 10 +- src/strands/agent/agent.py | 92 +++- src/strands/agent/agent_result.py | 4 + .../summarizing_conversation_manager.py | 15 +- src/strands/event_loop/event_loop.py | 89 +++- src/strands/event_loop/streaming.py | 4 +- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 7 + src/strands/tools/_tool_helpers.py | 15 + src/strands/tools/executors/_executor.py | 16 +- src/strands/tools/executors/concurrent.py | 8 +- src/strands/tools/executors/sequential.py | 5 +- src/strands/tools/registry.py | 15 + .../tools/structured_output/__init__.py | 5 + .../_structured_output_context.py | 143 ++++++ .../structured_output_tool.py | 158 ++++++ .../structured_output_utils.py} | 2 +- src/strands/types/_events.py | 17 +- src/strands/types/exceptions.py | 13 + tests/fixtures/mocked_model_provider.py | 6 +- tests/strands/agent/test_agent.py | 4 + tests/strands/agent/test_agent_result.py | 63 ++- .../agent/test_agent_structured_output.py | 414 ++++++++++++++++ tests/strands/event_loop/test_event_loop.py | 19 +- .../test_event_loop_structured_output.py | 439 ++++++++++++++++ tests/strands/event_loop/test_streaming.py | 1 + .../test_streaming_structured_output.py | 157 ++++++ tests/strands/models/test_model.py | 45 ++ .../tools/executors/test_concurrent.py | 18 +- .../tools/executors/test_sequential.py | 87 +++- .../test_structured_output_context.py | 245 +++++++++ .../test_structured_output_tool.py | 307 ++++++++++++ tests/strands/types/test__events.py | 467 ++++++++++++++++++ tests/strands/types/test_exceptions.py | 387 +++++++++++++++ tests_integ/models/test_conformance.py | 17 + .../test_structured_output_agent_loop.py | 330 +++++++++++++ 36 files changed, 3562 insertions(+), 64 deletions(-) create mode 100644 src/strands/tools/_tool_helpers.py create mode 100644 src/strands/tools/structured_output/__init__.py create mode 100644 src/strands/tools/structured_output/_structured_output_context.py create mode 100644 src/strands/tools/structured_output/structured_output_tool.py rename src/strands/tools/{structured_output.py => structured_output/structured_output_utils.py} (99%) create mode 100644 tests/strands/agent/test_agent_structured_output.py create mode 100644 tests/strands/event_loop/test_event_loop_structured_output.py create mode 100644 tests/strands/event_loop/test_streaming_structured_output.py create mode 100644 tests/strands/tools/structured_output/test_structured_output_context.py create mode 100644 tests/strands/tools/structured_output/test_structured_output_tool.py create mode 100644 tests/strands/types/test__events.py create mode 100644 tests/strands/types/test_exceptions.py create mode 100644 tests_integ/test_structured_output_agent_loop.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index ae784a58f..3718a29c5 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -5,4 +5,12 @@ from .tools.decorator import tool from .types.tools import ToolContext -__all__ = ["Agent", 
"agent", "models", "tool", "types", "telemetry", "ToolContext"] +__all__ = [ + "Agent", + "agent", + "models", + "tool", + "ToolContext", + "types", + "telemetry", +] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f963f14e7..1de75cfd2 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -50,6 +50,7 @@ from ..tools.executors import ConcurrentToolExecutor from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry +from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput @@ -216,6 +217,7 @@ def __init__( messages: Optional[Messages] = None, tools: Optional[list[Union[str, dict[str, str], Any]]] = None, system_prompt: Optional[str] = None, + structured_output_model: Optional[Type[BaseModel]] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] ] = _DEFAULT_CALLBACK_HANDLER, @@ -251,6 +253,10 @@ def __init__( If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. If None, the model will behave according to its default settings. + structured_output_model: Pydantic model type(s) for structured output. + When specified, all agent calls will attempt to return structured output of this type. + This can be overridden on the agent invocation. + Defaults to None (no structured output). callback_handler: Callback for processing events as they happen during agent execution. If not provided (using the default), a new PrintingCallbackHandler instance is created. If explicitly set to None, null_callback_handler is used. @@ -280,8 +286,8 @@ def __init__( """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] - self.system_prompt = system_prompt + self._default_structured_output_model = structured_output_model self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description @@ -383,7 +389,12 @@ def tool_names(self) -> list[str]: return list(all_tools.keys()) def __call__( - self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -400,6 +411,7 @@ def __call__( - list[Message]: Complete messages with roles - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. + structured_output_model: Pydantic model type(s) for structured output (overrides agent default). 
**kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -409,17 +421,27 @@ def __call__( - message: The final message from the model - metrics: Performance metrics from the event loop - state: The final state of the event loop + - structured_output: Parsed structured output when structured_output_model was specified """ def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs)) + return asyncio.run( + self.invoke_async( + prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + ) + ) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() async def invoke_async( - self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -436,6 +458,7 @@ async def invoke_async( - list[Message]: Complete messages with roles - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. + structured_output_model: Pydantic model type(s) for structured output (overrides agent default). **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -446,7 +469,9 @@ async def invoke_async( - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs) + events = self.stream_async( + prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + ) async for event in events: _ = event @@ -473,6 +498,13 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> Raises: ValueError: If no conversation history or prompt is provided. """ + warnings.warn( + "Agent.structured_output method is deprecated." + " You should pass in `structured_output_model` directly into the agent invocation." + " see: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/structured-output/", + category=DeprecationWarning, + stacklevel=2, + ) def execute() -> T: return asyncio.run(self.structured_output_async(output_model, prompt)) @@ -501,6 +533,13 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu if self._interrupt_state.activated: raise RuntimeError("cannot call structured output during interrupt") + warnings.warn( + "Agent.structured_output_async method is deprecated." + " You should pass in `structured_output_model` directly into the agent invocation." 
+ " see: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/structured-output/", + category=DeprecationWarning, + stacklevel=2, + ) self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT @@ -545,7 +584,12 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async( - self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -562,6 +606,7 @@ async def stream_async( - list[Message]: Complete messages with roles - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. + structured_output_model: Pydantic model type(s) for structured output (overrides agent default). **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -606,7 +651,7 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, invocation_state=merged_state) + events = self._run_loop(messages, merged_state, structured_output_model) async for event in events: event.prepare(invocation_state=merged_state) @@ -658,12 +703,18 @@ def _resume_interrupt(self, prompt: AgentInput) -> None: self._interrupt_state.interrupts[interrupt_id].response = interrupt_response - async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + async def _run_loop( + self, + messages: Messages, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, + ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. Args: messages: The input messages to add to the conversation. invocation_state: Additional parameters to pass to the event loop. + structured_output_model: Optional Pydantic model type for structured output. Yields: Events from the event loop cycle. @@ -676,8 +727,12 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) for message in messages: self._append_message(message) + structured_output_context = StructuredOutputContext( + structured_output_model or self._default_structured_output_model + ) + # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state) + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) async for event in events: # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. 
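For reference, the migration implied by the deprecation warnings above, as a minimal sketch (UserProfile is an illustrative Pydantic model, not part of this patch):

    from pydantic import BaseModel
    from strands import Agent

    class UserProfile(BaseModel):
        name: str
        age: int

    agent = Agent()

    # Deprecated path: still works, but now emits a DeprecationWarning.
    profile = agent.structured_output(UserProfile, "Extract the user from: Jane, 25")

    # Preferred path after this change: pass the model into the invocation itself.
    result = agent("Extract the user from: Jane, 25", structured_output_model=UserProfile)
    profile = result.structured_output
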
@@ -698,24 +753,33 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) self.conversation_manager.apply_management(self) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + async def _execute_event_loop_cycle( + self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None + ) -> AsyncGenerator[TypedEvent, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements retry logic for handling context window overflow exceptions by reducing the conversation context and retrying. + Args: + invocation_state: Additional parameters to pass to the event loop. + structured_output_context: Optional structured output context for this invocation. + Yields: Events of the loop cycle. """ # Add `Agent` to invocation_state to keep backwards-compatibility invocation_state["agent"] = self + if structured_output_context: + structured_output_context.register_tool(self.tool_registry) + try: - # Execute the main event loop cycle events = event_loop_cycle( agent=self, invocation_state=invocation_state, + structured_output_context=structured_output_context, ) async for event in events: yield event @@ -728,10 +792,14 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A if self._session_manager: self._session_manager.sync_agent(self) - events = self._execute_event_loop_cycle(invocation_state) + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) async for event in events: yield event + finally: + if structured_output_context: + structured_output_context.cleanup(self.tool_registry) + def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if self._interrupt_state.activated: return [] diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index 12c1f8376..076a94d7a 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any, Sequence, cast +from pydantic import BaseModel + from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message @@ -22,6 +24,7 @@ class AgentResult: metrics: Performance metrics collected during processing. state: Additional state information from the event loop. interrupts: List of interrupts if raised by user. + structured_output: Parsed structured output when structured_output_model was specified. """ stop_reason: StopReason @@ -29,6 +32,7 @@ class AgentResult: metrics: EventLoopMetrics state: Any interrupts: Sequence[Interrupt] | None = None + structured_output: BaseModel | None = None def __str__(self) -> str: """Get the agent's last message as a string. 
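The agent_result.py change above is what exposes the parsed model to callers. A minimal sketch of how the new field behaves, assuming an illustrative Invoice model that is not part of this patch:

    from pydantic import BaseModel
    from strands import Agent

    class Invoice(BaseModel):
        vendor: str
        total: float

    # Constructor-level default: every call attempts structured output of this type,
    # unless overridden per invocation.
    agent = Agent(structured_output_model=Invoice)
    result = agent("Extract the invoice from: ACME Corp, $41.50")

    if result.structured_output is not None:
        assert isinstance(result.structured_output, Invoice)
        print(result.structured_output.vendor, result.structured_output.total)

    # Without a structured_output_model, the field stays None and str(result) is unchanged.
    assert Agent()("Say hi").structured_output is None
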
diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 117626fbe..12185c286 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -5,10 +5,11 @@ from typing_extensions import override -from ...tools import tool +from ...tools._tool_helpers import noop_tool from ...tools.registry import ToolRegistry from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException +from ...types.tools import AgentTool from .conversation_manager import ConversationManager if TYPE_CHECKING: @@ -208,7 +209,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: # Add no-op tool if agent has no tools to satisfy tool spec requirement if not summarization_agent.tool_names: tool_registry = ToolRegistry() - tool_registry.register_tool(self._noop_tool) + tool_registry.register_tool(cast(AgentTool, noop_tool)) summarization_agent.tool_registry = tool_registry summarization_agent.messages = messages @@ -264,13 +265,3 @@ def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_poin raise ContextWindowOverflowException("Unable to trim conversation context!") return split_point - - @tool(name="noop", description="MUST NOT call or summarize") - def _noop_tool(self) -> None: - """No-op tool to satisfy tool spec requirement when tool messages are present. - - Some model provides (e.g., Bedrock) will return an error response if tool uses and tool results are present in - messages without any tool specs configured. Consequently, if the summarization agent has no registered tools, - summarization will fail. As a workaround, we register the no-op tool. - """ - pass diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 7a9c60c3b..116f7956d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -19,6 +19,7 @@ from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools +from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..types._events import ( EventLoopStopEvent, EventLoopThrottleEvent, @@ -27,6 +28,7 @@ ModelStopReason, StartEvent, StartEventLoopEvent, + StructuredOutputEvent, ToolInterruptEvent, ToolResultMessageEvent, TypedEvent, @@ -37,6 +39,7 @@ EventLoopException, MaxTokensReachedException, ModelThrottledException, + StructuredOutputException, ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse @@ -53,7 +56,11 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: +async def event_loop_cycle( + agent: "Agent", + invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, +) -> AsyncGenerator[TypedEvent, None]: """Execute a single cycle of the event loop. 
This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -74,6 +81,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> - request_state: State maintained across cycles - event_loop_cycle_id: Unique ID for this cycle - event_loop_cycle_span: Current tracing Span for this cycle + structured_output_context: Optional context for structured output management. Yields: Model and tool stream events. The last event is a tuple containing: @@ -87,6 +95,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> EventLoopException: If an error occurs during execution ContextWindowOverflowException: If the input is too large for the model """ + structured_output_context = structured_output_context or StructuredOutputContext() + # Initialize cycle state invocation_state["event_loop_cycle_id"] = uuid.uuid4() @@ -113,7 +123,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> message = agent._interrupt_state.context["tool_use_message"] else: - model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) async for model_event in model_events: if not isinstance(model_event, ModelStopReason): yield model_event @@ -138,7 +150,6 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) ) - # If the model is requesting to use tools if stop_reason == "tool_use": # Handle tool execution tool_events = _handle_tool_execution( @@ -150,6 +161,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time=cycle_start_time, invocation_state=invocation_state, tracer=tracer, + structured_output_context=structured_output_context, ) async for tool_event in tool_events: yield tool_event @@ -184,10 +196,33 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." + ) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + agent._append_message( + {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} + ) + + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) + async for typed_event in events: + yield typed_event + return + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: +async def recurse_event_loop( + agent: "Agent", + invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, +) -> AsyncGenerator[TypedEvent, None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. 
@@ -195,7 +230,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - Args: agent: Agent for which the recursive call is being made. invocation_state: Arguments to pass through event_loop_cycle - + structured_output_context: Optional context for structured output management. Yields: Results from event_loop_cycle where the last result contains: @@ -213,7 +248,9 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - yield StartEvent() - events = event_loop_cycle(agent=agent, invocation_state=invocation_state) + events = event_loop_cycle( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) async for event in events: yield event @@ -226,6 +263,7 @@ async def _handle_model_execution( cycle_trace: Trace, invocation_state: dict[str, Any], tracer: Tracer, + structured_output_context: StructuredOutputContext, ) -> AsyncGenerator[TypedEvent, None]: """Handle model execution with retry logic for throttling exceptions. @@ -238,6 +276,7 @@ async def _handle_model_execution( cycle_trace: Trace object for the current event loop cycle. invocation_state: State maintained across cycles. tracer: Tracer instance for span management. + structured_output_context: Context for structured output management. Yields: Model stream events and throttle events during retries. @@ -266,10 +305,15 @@ async def _handle_model_execution( ) ) - tool_specs = agent.tool_registry.get_all_tool_specs() - + if structured_output_context.forced_mode: + tool_spec = structured_output_context.get_tool_spec() + tool_specs = [tool_spec] if tool_spec else [] + else: + tool_specs = agent.tool_registry.get_all_tool_specs() try: - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): + async for event in stream_messages( + agent.model, agent.system_prompt, agent.messages, tool_specs, structured_output_context.tool_choice + ): yield event stop_reason, message, usage, metrics = event["stop"] @@ -354,6 +398,7 @@ async def _handle_tool_execution( cycle_start_time: float, invocation_state: dict[str, Any], tracer: Tracer, + structured_output_context: StructuredOutputContext, ) -> AsyncGenerator[TypedEvent, None]: """Handles the execution of tools requested by the model during an event loop cycle. @@ -366,6 +411,7 @@ async def _handle_tool_execution( cycle_start_time: Start time of the current cycle. invocation_state: Additional keyword arguments, including request state. tracer: Tracer instance for span management. + structured_output_context: Optional context for structured output management. Yields: Tool stream events along with events yielded from a recursive call to the event loop. 
The last event is a tuple @@ -394,7 +440,7 @@ async def _handle_tool_execution( interrupts = [] tool_events = agent.tool_executor._execute( - agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) async for tool_event in tool_events: if isinstance(tool_event, ToolInterruptEvent): @@ -402,7 +448,12 @@ async def _handle_tool_execution( yield tool_event - # Store parent cycle ID for the next cycle + structured_output_result = None + if structured_output_context.is_enabled: + if structured_output_result := structured_output_context.extract_result(tool_uses): + yield StructuredOutputEvent(structured_output=structured_output_result) + structured_output_context.stop_loop = True + invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] if interrupts: @@ -416,6 +467,7 @@ async def _handle_tool_execution( agent.event_loop_metrics, invocation_state["request_state"], interrupts, + structured_output=structured_output_result, ) if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) @@ -431,16 +483,25 @@ async def _handle_tool_execution( agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) + yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) - if invocation_state["request_state"].get("stop_event_loop", False): + if invocation_state["request_state"].get("stop_event_loop", False) or structured_output_context.stop_loop: agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + yield EventLoopStopEvent( + stop_reason, + message, + agent.event_loop_metrics, + invocation_state["request_state"], + structured_output=structured_output_result, + ) return - events = recurse_event_loop(agent=agent, invocation_state=invocation_state) + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) async for event in events: yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 73f38de8a..6d847f8af 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -346,6 +346,7 @@ async def stream_messages( system_prompt: Optional[str], messages: Messages, tool_specs: list[ToolSpec], + tool_choice: Optional[Any] = None, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -354,6 +355,7 @@ async def stream_messages( system_prompt: The system prompt to send. messages: List of messages to send. tool_specs: The list of tool specs. + tool_choice: Optional tool choice constraint for forcing specific tool usage. 
Yields: The reason for stopping, the final message, and the usage metrics @@ -362,7 +364,7 @@ async def stream_messages( messages = remove_blank_messages_content_text(messages) start_time = time.time() - chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt) + chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice) async for event in process_stream(chunks, start_time): yield event diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index a95b0d027..48351da19 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -14,7 +14,7 @@ from typing_extensions import Required, Unpack, override from ..event_loop.streaming import process_stream -from ..tools import convert_pydantic_to_tool_spec +from ..tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c465a2f38..43a3a3ed4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -19,6 +19,7 @@ from .._exception_notes import add_exception_note from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec +from ..tools._tool_helpers import noop_tool from ..types.content import ContentBlock, Messages from ..types.exceptions import ( ContextWindowOverflowException, @@ -204,6 +205,12 @@ def format_request( Returns: A Bedrock converse stream request. """ + if not tool_specs: + has_tool_content = any( + any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages + ) + if has_tool_content: + tool_specs = [noop_tool.tool_spec] return { "modelId": self.config["model_id"], "messages": self._format_bedrock_messages(messages), diff --git a/src/strands/tools/_tool_helpers.py b/src/strands/tools/_tool_helpers.py new file mode 100644 index 000000000..d640f23b8 --- /dev/null +++ b/src/strands/tools/_tool_helpers.py @@ -0,0 +1,15 @@ +"""Helpers for tools.""" + +from strands.tools.decorator import tool + + +# https://github.com/strands-agents/sdk-python/issues/998 +@tool(name="noop", description="This is a fake tool that MUST be completely ignored.") +def noop_tool() -> None: + """No-op tool to satisfy tool spec requirement when tool messages are present. + + Some model providers (e.g., Bedrock) will return an error response if tool uses and tool results are present in + messages without any tool specs configured. Consequently, if the summarization agent has no registered tools, + summarization will fail. As a workaround, we register the no-op tool. 
+ """ + pass diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 44c2dc36a..81a594488 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -17,6 +17,7 @@ from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse +from ..structured_output._structured_output_context import StructuredOutputContext if TYPE_CHECKING: # pragma: no cover from ...agent import Agent @@ -33,6 +34,7 @@ async def _stream( tool_use: ToolUse, tool_results: list[ToolResult], invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Stream tool events. @@ -50,6 +52,7 @@ async def _stream( tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output management. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -57,6 +60,7 @@ async def _stream( """ logger.debug("tool_use=<%s> | streaming", tool_use) tool_name = tool_use["name"] + structured_output_context = structured_output_context or StructuredOutputContext() tool_info = agent.tool_registry.dynamic_tools.get(tool_name) tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) @@ -155,7 +159,8 @@ async def _stream( yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return - + if structured_output_context.is_enabled: + kwargs["structured_output_context"] = structured_output_context async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. @@ -220,6 +225,7 @@ async def _stream_with_trace( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: StructuredOutputContext | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Execute tool with tracing and metrics collection. @@ -231,12 +237,14 @@ async def _stream_with_trace( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output management. **kwargs: Additional keyword arguments for future extensibility. Yields: Tool events with the last being the tool result. 
""" tool_name = tool_use["name"] + structured_output_context = structured_output_context or StructuredOutputContext() tracer = get_tracer() @@ -245,7 +253,9 @@ async def _stream_with_trace( tool_start_time = time.time() with trace_api.use_span(tool_call_span): - async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): + async for event in ToolExecutor._stream( + agent, tool_use, tool_results, invocation_state, structured_output_context, **kwargs + ): yield event if isinstance(event, ToolInterruptEvent): @@ -273,6 +283,7 @@ def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: "StructuredOutputContext", ) -> AsyncGenerator[TypedEvent, None]: """Execute the given tools according to this executor's strategy. @@ -283,6 +294,7 @@ def _execute( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output management. Yields: Events from the tool execution stream. diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 8ef8a8b65..bf78d6f6a 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ..structured_output._structured_output_context import StructuredOutputContext class ConcurrentToolExecutor(ToolExecutor): @@ -26,6 +27,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: "StructuredOutputContext", ) -> AsyncGenerator[TypedEvent, None]: """Execute tools concurrently. @@ -36,6 +38,7 @@ async def _execute( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output handling. Yields: Events from the tool execution stream. @@ -57,6 +60,7 @@ async def _execute( task_queue, task_events[task_id], stop_event, + structured_output_context, ) ) for task_id, tool_use in enumerate(tool_uses) @@ -84,6 +88,7 @@ async def _task( task_queue: asyncio.Queue, task_event: asyncio.Event, stop_event: object, + structured_output_context: "StructuredOutputContext", ) -> None: """Execute a single tool and put results in the task queue. @@ -98,10 +103,11 @@ async def _task( task_queue: Queue to put tool events into. task_event: Event to signal when task can continue. stop_event: Sentinel object to signal task completion. + structured_output_context: Context for structured output handling. 
""" try: events = ToolExecutor._stream_with_trace( - agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) async for event in events: task_queue.put_nowait((task_id, event)) diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index adbd5a5d3..74024455a 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ..structured_output._structured_output_context import StructuredOutputContext class SequentialToolExecutor(ToolExecutor): @@ -25,6 +26,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], + structured_output_context: "StructuredOutputContext", ) -> AsyncGenerator[TypedEvent, None]: """Execute tools sequentially. @@ -37,6 +39,7 @@ async def _execute( cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle. invocation_state: Context for the tool invocation. + structured_output_context: Context for structured output handling. Yields: Events from the tool execution stream. @@ -45,7 +48,7 @@ async def _execute( for tool_use in tool_uses: events = ToolExecutor._stream_with_trace( - agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) async for event in events: if isinstance(event, ToolInterruptEvent): diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 3631c9dee..4f85d1168 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -524,6 +524,21 @@ def get_all_tool_specs(self) -> list[ToolSpec]: tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] return tools + def register_dynamic_tool(self, tool: AgentTool) -> None: + """Register a tool dynamically for temporary use. + + Args: + tool: The tool to register dynamically + + Raises: + ValueError: If a tool with this name already exists + """ + if tool.tool_name in self.registry or tool.tool_name in self.dynamic_tools: + raise ValueError(f"Tool '{tool.tool_name}' already exists") + + self.dynamic_tools[tool.tool_name] = tool + logger.debug("Registered dynamic tool: %s", tool.tool_name) + def validate_tool_spec(self, tool_spec: ToolSpec) -> None: """Validate tool specification against required schema. 
diff --git a/src/strands/tools/structured_output/__init__.py b/src/strands/tools/structured_output/__init__.py new file mode 100644 index 000000000..777d5d846 --- /dev/null +++ b/src/strands/tools/structured_output/__init__.py @@ -0,0 +1,5 @@ +"""Structured output tools for the Strands Agents framework.""" + +from .structured_output_utils import convert_pydantic_to_tool_spec + +__all__ = ["convert_pydantic_to_tool_spec"] diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py new file mode 100644 index 000000000..f33a06915 --- /dev/null +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -0,0 +1,143 @@ +"""Context management for structured output in the event loop.""" + +import logging +from typing import TYPE_CHECKING, Optional, Type + +from pydantic import BaseModel + +from ...types.tools import ToolChoice, ToolSpec, ToolUse +from .structured_output_tool import StructuredOutputTool + +if TYPE_CHECKING: + from ..registry import ToolRegistry + +logger = logging.getLogger(__name__) + + +class StructuredOutputContext: + """Per-invocation context for structured output execution.""" + + def __init__(self, structured_output_model: Type[BaseModel] | None = None): + """Initialize a new structured output context. + + Args: + structured_output_model: Optional Pydantic model type for structured output. + """ + self.results: dict[str, BaseModel] = {} + self.structured_output_model: Type[BaseModel] | None = structured_output_model + self.structured_output_tool: StructuredOutputTool | None = None + self.forced_mode: bool = False + self.force_attempted: bool = False + self.tool_choice: ToolChoice | None = None + self.stop_loop: bool = False + self.expected_tool_name: Optional[str] = None + + if structured_output_model: + self.structured_output_tool = StructuredOutputTool(structured_output_model) + self.expected_tool_name = self.structured_output_tool.tool_name + + @property + def is_enabled(self) -> bool: + """Check if structured output is enabled for this context. + + Returns: + True if a structured output model is configured, False otherwise. + """ + return self.structured_output_model is not None + + def store_result(self, tool_use_id: str, result: BaseModel) -> None: + """Store a validated structured output result. + + Args: + tool_use_id: Unique identifier for the tool use. + result: Validated Pydantic model instance. + """ + self.results[tool_use_id] = result + + def get_result(self, tool_use_id: str) -> BaseModel | None: + """Retrieve a stored structured output result. + + Args: + tool_use_id: Unique identifier for the tool use. + + Returns: + The validated Pydantic model instance, or None if not found. + """ + return self.results.get(tool_use_id) + + def set_forced_mode(self, tool_choice: dict | None = None) -> None: + """Mark this context as being in forced structured output mode. + + Args: + tool_choice: Optional tool choice configuration. + """ + if not self.is_enabled: + return + self.forced_mode = True + self.force_attempted = True + self.tool_choice = tool_choice or {"any": {}} + + def has_structured_output_tool(self, tool_uses: list[ToolUse]) -> bool: + """Check if any tool uses are for the structured output tool. + + Args: + tool_uses: List of tool use dictionaries to check. + + Returns: + True if any tool use matches the expected structured output tool name, + False if no structured output tool is present or expected. 
+ """ + if not self.expected_tool_name: + return False + return any(tool_use.get("name") == self.expected_tool_name for tool_use in tool_uses) + + def get_tool_spec(self) -> Optional[ToolSpec]: + """Get the tool specification for structured output. + + Returns: + Tool specification, or None if no structured output model. + """ + if self.structured_output_tool: + return self.structured_output_tool.tool_spec + return None + + def extract_result(self, tool_uses: list[ToolUse]) -> BaseModel | None: + """Extract and remove structured output result from stored results. + + Args: + tool_uses: List of tool use dictionaries from the current execution cycle. + + Returns: + The structured output result if found, or None if no result available. + """ + if not self.has_structured_output_tool(tool_uses): + return None + + for tool_use in tool_uses: + if tool_use.get("name") == self.expected_tool_name: + tool_use_id = str(tool_use.get("toolUseId", "")) + result = self.results.pop(tool_use_id, None) + if result is not None: + logger.debug("Extracted structured output for %s", tool_use.get("name")) + return result + return None + + def register_tool(self, registry: "ToolRegistry") -> None: + """Register the structured output tool with the registry. + + Args: + registry: The tool registry to register the tool with. + """ + if self.structured_output_tool and self.structured_output_tool.tool_name not in registry.dynamic_tools: + registry.register_dynamic_tool(self.structured_output_tool) + logger.debug("Registered structured output tool: %s", self.structured_output_tool.tool_name) + + def cleanup(self, registry: "ToolRegistry") -> None: + """Clean up the registered structured output tool from the registry. + + Args: + registry: The tool registry to clean up the tool from. + """ + if self.structured_output_tool and self.structured_output_tool.tool_name in registry.dynamic_tools: + del registry.dynamic_tools[self.structured_output_tool.tool_name] + logger.debug("Cleaned up structured output tool: %s", self.structured_output_tool.tool_name) diff --git a/src/strands/tools/structured_output/structured_output_tool.py b/src/strands/tools/structured_output/structured_output_tool.py new file mode 100644 index 000000000..25173d048 --- /dev/null +++ b/src/strands/tools/structured_output/structured_output_tool.py @@ -0,0 +1,158 @@ +"""Structured output tool implementation. + +This module provides a real tool implementation for structured output that integrates +with the existing tool execution and error handling infrastructure. +""" + +import logging +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Type + +from pydantic import BaseModel, ValidationError +from typing_extensions import override + +from ...types._events import ToolResultEvent +from ...types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse +from .structured_output_utils import convert_pydantic_to_tool_spec + +logger = logging.getLogger(__name__) + +_TOOL_SPEC_CACHE: dict[Type[BaseModel], ToolSpec] = {} + +if TYPE_CHECKING: + from ._structured_output_context import StructuredOutputContext + + +class StructuredOutputTool(AgentTool): + """Tool implementation for structured output validation.""" + + def __init__(self, structured_output_model: Type[BaseModel]) -> None: + """Initialize a structured output tool. + + Args: + structured_output_model: The Pydantic model class that defines the expected output structure. 
+ """ + super().__init__() + self._structured_output_type = structured_output_model + self._tool_spec = self._get_tool_spec(structured_output_model) + self._tool_spec["description"] = ( + "IMPORTANT: This StructuredOutputTool should only be invoked as the last and final tool " + f"before returning the completed result to the caller. " + f"{self._tool_spec.get('description', '')}" + ) + self._tool_name = self._tool_spec.get("name", "StructuredOutputTool") + + @classmethod + def _get_tool_spec(cls, structured_output_model: Type[BaseModel]) -> ToolSpec: + """Get a cached tool spec for the given output type. + + Args: + structured_output_model: The Pydantic model class that defines the expected output structure. + + Returns: + Cached tool specification for the output type. + """ + if structured_output_model not in _TOOL_SPEC_CACHE: + _TOOL_SPEC_CACHE[structured_output_model] = convert_pydantic_to_tool_spec(structured_output_model) + return deepcopy(_TOOL_SPEC_CACHE[structured_output_model]) + + @property + def tool_name(self) -> str: + """Get the name of the tool. + + Returns: + The name of the tool (same as the Pydantic model class name). + """ + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification for this structured output tool. + + Returns: + The tool specification generated from the Pydantic model. + """ + return self._tool_spec + + @property + def tool_type(self) -> str: + """Identifies this as a structured output tool implementation. + + Returns: + "structured_output". + """ + return "structured_output" + + @property + def structured_output_model(self) -> Type[BaseModel]: + """Get the Pydantic model type for this tool. + + Returns: + The Pydantic model class. + """ + return self._structured_output_type + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Validate the structured output and return appropriate result. + + Args: + tool_use: The tool use request containing the data to validate. + invocation_state: Context for the tool invocation (kept for compatibility). + **kwargs: Additional keyword arguments, including structured_output_context. + + Yields: + Tool events with the last being the tool result (success or error). + """ + tool_input: dict[str, Any] = tool_use.get("input", {}) + tool_use_id = str(tool_use.get("toolUseId", "")) + + context: StructuredOutputContext = kwargs.get("structured_output_context") # type: ignore + try: + validated_object = self._structured_output_type(**tool_input) + logger.debug("tool_name=<%s> | structured output validated", self._tool_name) + context.store_result(tool_use_id, validated_object) + + result: ToolResult = { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": f"Successfully validated {self._tool_name} structured output"}], + } + + yield ToolResultEvent(result) + + except ValidationError as e: + error_details = [] + for error in e.errors(): + field_path = " -> ".join(str(loc) for loc in error["loc"]) if error["loc"] else "root" + error_details.append(f"Field '{field_path}': {error['msg']}") + + error_message = f"Validation failed for {self._tool_name}. 
Please fix the following errors:\n" + "\n".join( + f"- {detail}" for detail in error_details + ) + logger.error( + "tool_name=<%s> | structured output validation failed | error_message=<%s>", + self._tool_name, + error_message, + ) + + # Create error result that will be sent back to the LLM so it can decide if it needs to retry + validation_error_result: ToolResult = { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": error_message}], + } + + yield ToolResultEvent(validation_error_result) + + except Exception as e: + error_message = f"Unexpected error validating {self._tool_name}: {str(e)}" + logger.exception(error_message) + + exception_result: ToolResult = { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": error_message}], + } + + yield ToolResultEvent(exception_result) diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output/structured_output_utils.py similarity index 99% rename from src/strands/tools/structured_output.py rename to src/strands/tools/structured_output/structured_output_utils.py index 2c5922925..093d67f7c 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output/structured_output_utils.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from ..types.tools import ToolSpec +from ...types.tools import ToolSpec def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 13d4a98f9..36977e90f 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Sequence, cast +from pydantic import BaseModel from typing_extensions import override from ..interrupt import Interrupt @@ -222,6 +223,7 @@ def __init__( metrics: "EventLoopMetrics", request_state: Any, interrupts: Sequence[Interrupt] | None = None, + structured_output: BaseModel | None = None, ) -> None: """Initialize with the final execution results. @@ -231,8 +233,9 @@ def __init__( metrics: Execution metrics and performance data request_state: Final state of the agent execution interrupts: Interrupts raised by user during agent execution. + structured_output: Optional structured output result """ - super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts)}) + super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts, structured_output)}) @property @override @@ -240,6 +243,18 @@ def is_callback_event(self) -> bool: return False +class StructuredOutputEvent(TypedEvent): + """Event emitted when structured output is detected and processed.""" + + def __init__(self, structured_output: BaseModel) -> None: + """Initialize with the structured output result. 
+ + Args: + structured_output: The parsed structured output instance + """ + super().__init__({"structured_output": structured_output}) + + class EventLoopThrottleEvent(TypedEvent): """Event emitted when the event loop is throttled due to rate limiting.""" diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 90f2b8d7f..5b17ba6e7 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -75,3 +75,16 @@ class SessionException(Exception): """Exception raised when session operations fail.""" pass + + +class StructuredOutputException(Exception): + """Exception raised when structured output validation fails after maximum retry attempts.""" + + def __init__(self, message: str): + """Initialize the exception with details about the failure. + + Args: + message: The error message describing the structured output failure + """ + self.message = message + super().__init__(message) diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index c05089f34..4523a8352 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -53,7 +53,11 @@ async def structured_output( pass async def stream( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: Optional[Any] = None, ) -> AsyncGenerator[Any, None]: events = self.map_agent_message_to_events(self.agent_responses[self.index]) for event in events: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index b58e5f3fd..9d490c0de 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -329,6 +329,7 @@ def test_agent__call__( ], [tool.tool_spec], system_prompt, + tool_choice=None, ), unittest.mock.call( [ @@ -365,6 +366,7 @@ def test_agent__call__( ], [tool.tool_spec], system_prompt, + tool_choice=None, ), ], ) @@ -484,6 +486,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener expected_messages, unittest.mock.ANY, unittest.mock.ANY, + tool_choice=None, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -627,6 +630,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene expected_messages, unittest.mock.ANY, unittest.mock.ANY, + tool_choice=None, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 67a7f2458..3a3a3f5f7 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -1,7 +1,8 @@ import unittest.mock -from typing import cast +from typing import Optional, cast import pytest +from pydantic import BaseModel from strands.agent.agent_result import AgentResult from strands.telemetry.metrics import EventLoopMetrics @@ -48,6 +49,7 @@ def test__init__(mock_metrics, simple_message: Message): assert result.message == simple_message assert result.metrics == mock_metrics assert result.state == state + assert result.structured_output is None def test__str__simple(mock_metrics, simple_message: Message): @@ -140,3 +142,62 @@ def test_roundtrip_serialization(mock_metrics, complex_message: Message): assert restored.stop_reason == original.stop_reason assert isinstance(restored.metrics, EventLoopMetrics) assert restored.state == {} # State is not serialized + 
+ +# Tests for structured output functionality +class StructuredOutputModel(BaseModel): + """Test model for structured output.""" + + name: str + value: int + optional_field: Optional[str] = None + + +def test__init__with_structured_output(mock_metrics, simple_message: Message): + """Test that AgentResult can be initialized with structured_output.""" + stop_reason: StopReason = "end_turn" + state = {"key": "value"} + structured_output = StructuredOutputModel(name="test", value=42) + + result = AgentResult( + stop_reason=stop_reason, + message=simple_message, + metrics=mock_metrics, + state=state, + structured_output=structured_output, + ) + + assert result.stop_reason == stop_reason + assert result.message == simple_message + assert result.metrics == mock_metrics + assert result.state == state + assert result.structured_output == structured_output + assert isinstance(result.structured_output, StructuredOutputModel) + assert result.structured_output.name == "test" + assert result.structured_output.value == 42 + + +def test__init__structured_output_defaults_to_none(mock_metrics, simple_message: Message): + """Test that structured_output defaults to None when not provided.""" + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + + assert result.structured_output is None + + +def test__str__with_structured_output(mock_metrics, simple_message: Message): + """Test that str() is not affected by structured_output.""" + structured_output = StructuredOutputModel(name="test", value=42) + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + structured_output=structured_output, + ) + + # The string representation should only include the message text, not structured output + message_string = str(result) + assert message_string == "Hello world!\n" + assert "test" not in message_string + assert "42" not in message_string diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py new file mode 100644 index 000000000..b679faed0 --- /dev/null +++ b/tests/strands/agent/test_agent_structured_output.py @@ -0,0 +1,414 @@ +"""Tests for Agent structured output functionality.""" + +from typing import Optional +from unittest import mock +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from strands import Agent +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output.structured_output_tool import StructuredOutputTool +from strands.types._events import EventLoopStopEvent +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +class UserModel(BaseModel): + """Test user model for structured output.""" + + name: str + age: int + email: str + + +class ProductModel(BaseModel): + """Test product model for structured output.""" + + title: str + price: float + description: Optional[str] = None + + +@pytest.fixture +def mock_model(): + """Create a mock model.""" + model = Mock() + + async def mock_stream(*args, **kwargs): + yield {"contentBlockDelta": {"delta": {"text": "test response"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + model.stream.side_effect = lambda *args, **kwargs: mock_stream(*args, **kwargs) + return model + + +@pytest.fixture +def mock_metrics(): + return mock.Mock(spec=EventLoopMetrics) + + 
+@pytest.fixture +def user_model(): + """Return the test user model class.""" + return UserModel + + +@pytest.fixture +def product_model(): + """Return the test product model class.""" + return ProductModel + + +class TestAgentStructuredOutputInit: + """Test Agent initialization with structured output model.""" + + def test_agent_init_with_structured_output_model(self, user_model): + """Test that Agent can be initialized with a structured_output_model.""" + agent = Agent(structured_output_model=user_model) + + assert agent._default_structured_output_model == user_model + assert agent.model is not None + + def test_agent_init_without_structured_output_model(self): + """Test that Agent can be initialized without structured_output_model.""" + agent = Agent() + + assert agent._default_structured_output_model is None + assert agent.model is not None + + +class TestAgentStructuredOutputInvocation: + """Test Agent invocation with structured output.""" + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_structured_output_model(self, mock_event_loop, user_model, mock_model, mock_metrics): + """Test Agent.__call__ with structured_output_model parameter.""" + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == user_model + + # Return a successful result + test_user = UserModel(name="John", age=30, email="john@example.com") + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=test_user, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent and call with structured_output_model + agent = Agent(model=mock_model) + agent("Extract user info", structured_output_model=user_model) + + # Verify event_loop_cycle was called with correct context + mock_event_loop.assert_called_once() + call_kwargs = mock_event_loop.call_args[1] + assert "structured_output_context" in call_kwargs + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_default_structured_output_model( + self, mock_event_loop, product_model, mock_model, mock_metrics + ): + """Test Agent.__call__ uses default structured_output_model when not specified.""" + + # Setup mock event loop + pm = ProductModel(title="Widget", price=9.99) + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == product_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=pm, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default structured_output_model + agent = Agent(model=mock_model, structured_output_model=product_model) + result = agent("Get product info") + + # Verify result uses default model + assert result.structured_output is pm + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_override_default_structured_output_model( + self, mock_event_loop, user_model, product_model, mock_model, mock_metrics + ): + """Test that invocation-level structured_output_model overrides default.""" + + # Setup mock event loop + um = UserModel(name="Jane", age=25, 
email="jane@example.com") + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + # Should use user_model, not the default product_model + assert structured_output_context.structured_output_model == user_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=um, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default product_model, but override with user_model + agent = Agent(model=mock_model, structured_output_model=product_model) + result = agent("Get user info", structured_output_model=user_model) + + # Verify result uses override model + assert result.structured_output is um + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_invoke_async_with_structured_output( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.invoke_async with structured_output_model.""" + + # Setup mock event loop + um = UserModel(name="Alice", age=28, email="alice@example.com") + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == user_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=um, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent and call async + agent = Agent(model=mock_model) + result = await agent.invoke_async("Get user", structured_output_model=user_model) + + # Verify result + assert result.structured_output is um + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_stream_async_with_structured_output( + self, mock_event_loop, product_model, mock_model, mock_metrics + ): + """Test Agent.stream_async with structured_output_model.""" + + # Setup mock event loop + pm = ProductModel(title="Gadget", price=19.99, description="Cool gadget") + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model == product_model + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=pm, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent and stream async + agent = Agent(model=mock_model) + events = [] + async for event in agent.stream_async("Get product", structured_output_model=product_model): + events.append(event) + + # Verify we got result event + assert len(events) > 0 + result_event = events[-1] + assert "result" in result_event + result = result_event["result"] + assert result.structured_output is pm + + +class TestAgentStructuredOutputContext: + """Test StructuredOutputContext integration with Agent.""" + + @patch("strands.agent.agent.event_loop_cycle") + def test_structured_output_context_created_with_model(self, mock_event_loop, user_model, mock_model, mock_metrics): + """Test that StructuredOutputContext is created when structured_output_model is provided.""" + context = None + + async def mock_cycle(*args, **kwargs): + nonlocal context + context 
= kwargs.get("structured_output_context") + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + agent("Test", structured_output_model=user_model) + + # Verify context was created and passed + assert context is not None + assert isinstance(context, StructuredOutputContext) + assert context.structured_output_model == user_model + assert context.is_enabled is True + + @patch("strands.agent.agent.event_loop_cycle") + def test_structured_output_context_none_without_model(self, mock_event_loop, mock_model, mock_metrics): + """Test that StructuredOutputContext is created with None when no model provided.""" + context = None + + async def mock_cycle(*args, **kwargs): + nonlocal context + context = kwargs.get("structured_output_context") + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + agent("Test") # No structured_output_model + + # Verify context was created but disabled + assert context is not None + assert isinstance(context, StructuredOutputContext) + assert context.structured_output_model is None + assert context.is_enabled is False + + @patch("strands.tools.registry.ToolRegistry.register_dynamic_tool") + @patch("strands.agent.agent.event_loop_cycle") + def test_structured_output_tool_registered_dynamically( + self, mock_event_loop, mock_register, user_model, mock_model, mock_metrics + ): + """Test that StructuredOutputTool is registered dynamically when structured output is used.""" + captured_tool = None + + def capture_tool(tool): + nonlocal captured_tool + captured_tool = tool + + mock_register.side_effect = capture_tool + + async def mock_cycle(*args, **kwargs): + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + agent("Test", structured_output_model=user_model) + + # Verify tool was registered + mock_register.assert_called_once() + assert captured_tool is not None + assert isinstance(captured_tool, StructuredOutputTool) + assert captured_tool.structured_output_model == user_model + + +class TestAgentStructuredOutputEdgeCases: + """Test edge cases for structured output in Agent.""" + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_with_no_structured_output(self, mock_event_loop, mock_model, mock_metrics): + """Test that agent works normally when no structured output is specified.""" + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_model is None + assert structured_output_context.is_enabled is False + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Normal response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model) + result = agent("Normal query") + + # Result should not have structured output + assert result.structured_output is None + assert 
result.message["content"][0]["text"] == "Normal response" + + def test_agent_multiple_structured_output_models(self, user_model, product_model, mock_metrics): + """Test that agent can switch between different structured output models.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "User response"}]}, + {"role": "assistant", "content": [{"text": "Product response"}]}, + ] + ) + + agent = Agent(model=model) + + # First call with user model + with patch("strands.agent.agent.event_loop_cycle") as mock_event_loop: + um = UserModel(name="Bob", age=40, email="bob@example.com") + + async def mock_user_cycle(*args, **kwargs): + ctx = kwargs.get("structured_output_context") + assert ctx.structured_output_model == user_model + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "User response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=um, + ) + + mock_event_loop.side_effect = mock_user_cycle + result1 = agent("Get user", structured_output_model=user_model) + assert result1.structured_output is um + + # Second call with product model + with patch("strands.agent.agent.event_loop_cycle") as mock_event_loop: + pm = ProductModel(title="Item", price=5.99) + + async def mock_product_cycle(*args, **kwargs): + ctx = kwargs.get("structured_output_context") + assert ctx.structured_output_model == product_model + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Product response"}]}, + metrics=mock_metrics, + request_state={}, + structured_output=pm, + ) + + mock_event_loop.side_effect = mock_product_cycle + result2 = agent("Get product", structured_output_model=product_model) + assert result2.structured_output is pm diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0a694bf1d..2d9af1741 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -173,7 +173,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -205,7 +205,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -243,7 +243,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -334,7 +334,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -376,6 +376,7 @@ async def 
test_event_loop_cycle_tool_result( ], tool_registry.get_all_tool_specs(), "p1", + tool_choice=None, ) @@ -449,7 +450,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -751,7 +752,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state, _ = events[-1]["stop"] + _, _, _, tru_request_state, _, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -763,7 +764,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state, _ = events[-1]["stop"] + _, _, _, tru_request_state, _, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -880,7 +881,7 @@ def interrupt_callback(event): stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) events = await alist(stream) - tru_stop_reason, _, _, _, tru_interrupts = events[-1]["stop"] + tru_stop_reason, _, _, _, tru_interrupts, _ = events[-1]["stop"] exp_stop_reason = "interrupt" exp_interrupts = [ Interrupt( @@ -973,7 +974,7 @@ def interrupt_callback(event): stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) events = await alist(stream) - tru_stop_reason, _, _, _, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" assert tru_stop_reason == exp_stop_reason diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py new file mode 100644 index 000000000..6d3e3a9b5 --- /dev/null +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -0,0 +1,439 @@ +"""Tests for structured output integration in the event loop.""" + +from unittest.mock import Mock, patch + +import pytest +from pydantic import BaseModel + +from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.registry import ToolRegistry +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.types._events import EventLoopStopEvent, StructuredOutputEvent + + +class UserModel(BaseModel): + """Test model for structured output.""" + + name: str + age: int + email: str + + +class ProductModel(BaseModel): + """Another test model.""" + + title: str + price: float + in_stock: bool + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with required attributes.""" + agent = Mock(name="agent") + agent.model = Mock() + agent.system_prompt = "Test system prompt" + agent.messages = [] + agent.tool_registry = ToolRegistry() + agent.event_loop_metrics = EventLoopMetrics() + agent.hooks = Mock() + agent.hooks.invoke_callbacks = Mock() + agent.trace_span = None + agent.tool_executor = Mock() + agent._append_message = Mock() + + # Set up _interrupt_state properly + agent._interrupt_state = Mock() + agent._interrupt_state.activated = False + agent._interrupt_state.context = {} + + return agent + + +@pytest.fixture +def structured_output_context(): + """Create a 
structured output context with a test model.""" + return StructuredOutputContext(structured_output_model=UserModel) + + +@pytest.fixture +def agenerator(): + """Helper to create async generators.""" + + def _agenerator(items): + async def gen(): + for item in items: + yield item + + return gen() + + return _agenerator + + +@pytest.fixture +def alist(): + """Helper to consume async generators.""" + + async def _alist(async_gen): + items = [] + async for item in async_gen: + items.append(item) + return items + + return _alist + + +@pytest.mark.asyncio +async def test_event_loop_cycle_with_structured_output_context(mock_agent, agenerator, alist): + """Test event_loop_cycle with structured output context passed but not enabled.""" + # Create a context that's not enabled (no model) + structured_output_context = StructuredOutputContext() + + # Setup model to return a text response + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user data"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + # Run event loop cycle with structured output context + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should have received events + assert len(events) > 0 + + # The context should be passed through but not enabled + assert not structured_output_context.is_enabled + + +@pytest.mark.asyncio +async def test_event_loop_forces_structured_output_on_end_turn( + mock_agent, structured_output_context, agenerator, alist +): + """Test that event loop forces structured output tool when model returns end_turn.""" + # First call returns end_turn without using structured output tool + mock_agent.model.stream.side_effect = [ + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), + # Second call (forced) uses the structured output tool + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "UserModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"name": "John", "age": 30, "email": "john@example.com"}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), + ] + + # Mock tool executor to handle the structured output tool + mock_agent.tool_executor._execute = Mock( + return_value=agenerator( + [ + # Tool execution events would go here + ] + ) + ) + + # Mock recurse_event_loop to return final result + with patch("strands.event_loop.event_loop.recurse_event_loop") as mock_recurse: + # Create a mock EventLoopStopEvent with the expected structure + mock_stop_event = Mock() + mock_stop_event.stop = ( + "end_turn", + {"role": "assistant", "content": [{"text": "Done"}]}, + mock_agent.event_loop_metrics, + {}, + None, + UserModel(name="John", age=30, email="john@example.com"), + ) + mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key] + + mock_recurse.return_value = agenerator([mock_stop_event]) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + # Should have appended a message to force structured output + mock_agent._append_message.assert_called_once() + args = mock_agent._append_message.call_args[0][0] + assert args["role"] == "user" + + # Should 
have called recurse_event_loop with the context + mock_recurse.assert_called_once() + call_kwargs = mock_recurse.call_args[1] + assert call_kwargs["structured_output_context"] == structured_output_context + + +@pytest.mark.asyncio +async def test_structured_output_tool_execution_extracts_result( + mock_agent, structured_output_context, agenerator, alist +): + """Test that structured output result is extracted from tool execution.""" + # Model uses the structured output tool + mock_agent.model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "UserModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"name": "Alice", "age": 25, "email": "alice@test.com"}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + # Mock the tool executor to return an async generator + mock_agent.tool_executor._execute = Mock(return_value=agenerator([])) + + # Mock extract_result to return a model instance + test_result = UserModel(name="Alice", age=25, email="alice@test.com") + structured_output_context.extract_result = Mock(return_value=test_result) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should yield StructuredOutputEvent + structured_output_events = [e for e in events if isinstance(e, StructuredOutputEvent)] + assert len(structured_output_events) == 1 + assert structured_output_events[0]["structured_output"] == test_result + + # Extract_result should have been called + structured_output_context.extract_result.assert_called_once() + + +@pytest.mark.asyncio +async def test_structured_output_context_not_enabled(mock_agent, agenerator, alist): + """Test event loop with structured output context that's not enabled.""" + # Create a context that's not enabled (no model) + structured_output_context = StructuredOutputContext() + assert not structured_output_context.is_enabled + + # Model returns end_turn + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Regular response"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should complete normally without forcing structured output + stop_events = [e for e in events if isinstance(e, EventLoopStopEvent)] + assert len(stop_events) == 1 + assert stop_events[0]["stop"][-1] is None + + +@pytest.mark.asyncio +async def test_structured_output_forced_mode(mock_agent, agenerator, alist): + """Test event loop with structured output in forced mode.""" + # Create context in forced mode + structured_output_context = StructuredOutputContext(structured_output_model=ProductModel) + structured_output_context.set_forced_mode(tool_choice={"tool": {"name": "ProductModel"}}) + + # Model should be called with only the structured output tool spec + mock_agent.model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "ProductModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"title": "Book", "price": 19.99, "in_stock": true}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + # Mock tool executor + 
mock_agent.tool_executor._execute = Mock(return_value=agenerator([])) + + # Mock extract_result + test_result = ProductModel(title="Book", price=19.99, in_stock=True) + structured_output_context.extract_result = Mock(return_value=test_result) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + # Verify model.stream was called with the forced tool spec + mock_agent.model.stream.assert_called_once() + call_args = mock_agent.model.stream.call_args + + # The model.stream method signature (from streaming.py) is: + # model.stream(messages, tool_specs, system_prompt, tool_choice=tool_choice) + tool_specs = call_args.args[1] if len(call_args.args) > 1 else None + + # In forced mode, only the structured output tool spec should be passed + assert tool_specs is not None, "Expected tool_specs to be provided" + assert isinstance(tool_specs, list), f"Expected tool_specs to be a list, got {type(tool_specs)}" + assert len(tool_specs) == 1 + assert tool_specs[0]["name"] == "ProductModel" + + +@pytest.mark.asyncio +async def test_recurse_event_loop_with_structured_output(mock_agent, structured_output_context, agenerator, alist): + """Test recurse_event_loop preserves structured output context.""" + invocation_state = { + "event_loop_cycle_trace": Mock(), + "request_state": {}, + } + + # Mock event_loop_cycle to verify it receives the context + with patch("strands.event_loop.event_loop.event_loop_cycle") as mock_cycle: + # Create a mock EventLoopStopEvent with the expected structure + mock_stop_event = Mock(spec=EventLoopStopEvent) + mock_stop_event.stop = ( + "end_turn", + {"role": "assistant", "content": [{"text": "Done"}]}, + mock_agent.event_loop_metrics, + {}, + None, + UserModel(name="Test", age=20, email="test@example.com"), + ) + mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key] + + mock_cycle.return_value = agenerator([mock_stop_event]) + + stream = recurse_event_loop( + agent=mock_agent, + invocation_state=invocation_state, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Verify event_loop_cycle was called with the context + mock_cycle.assert_called_once() + call_kwargs = mock_cycle.call_args[1] + assert call_kwargs["structured_output_context"] == structured_output_context + + # Verify the result includes structured output + stop_events = [ + e for e in events if isinstance(e, EventLoopStopEvent) or (hasattr(e, "stop") and hasattr(e, "__getitem__")) + ] + assert len(stop_events) == 1 + stop_event = stop_events[0] + if hasattr(stop_event, "__getitem__"): + assert stop_event["stop"][5].name == "Test" + else: + assert stop_event.stop[5].name == "Test" + + +@pytest.mark.asyncio +async def test_structured_output_stops_loop_after_extraction(mock_agent, structured_output_context, agenerator, alist): + """Test that loop stops after structured output is extracted.""" + # Model uses the structured output tool + mock_agent.model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "UserModel", + } + } + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"name": "Bob", "age": 35, "email": "bob@test.com"}'}} + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ) + + # Mock tool executor + mock_agent.tool_executor._execute = Mock(return_value=agenerator([])) + + # Mock extract_result to return a result and set 
stop_loop + test_result = UserModel(name="Bob", age=35, email="bob@test.com") + + def mock_extract(tool_uses): + structured_output_context.stop_loop = True + return test_result + + structured_output_context.extract_result = Mock(side_effect=mock_extract) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + events = await alist(stream) + + # Should have a StructuredOutputEvent with the result + structured_output_events = [e for e in events if isinstance(e, StructuredOutputEvent)] + assert len(structured_output_events) == 1 + assert structured_output_events[0]["structured_output"] == test_result + + # Verify stop_loop was set + assert structured_output_context.stop_loop + + # Extract_result should have been called + structured_output_context.extract_result.assert_called_once() diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 5afa0cb45..92bf0de96 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -791,6 +791,7 @@ async def test_stream_messages(agenerator, alist): [{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}], None, "test prompt", + tool_choice=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py new file mode 100644 index 000000000..e17044527 --- /dev/null +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -0,0 +1,157 @@ +"""Tests for streaming.py with structured output support.""" + +import unittest.mock + +import pytest +from pydantic import BaseModel + +import strands.event_loop.streaming +from strands.tools.structured_output.structured_output_tool import StructuredOutputTool +from strands.types._events import TypedEvent + + +class SampleModel(BaseModel): + """Sample model for structured output.""" + + name: str + age: int + + +@pytest.fixture(autouse=True) +def moto_autouse(moto_env, moto_mock_aws): + _ = moto_env + _ = moto_mock_aws + + +@pytest.mark.asyncio +async def test_stream_messages_with_tool_choice(agenerator, alist): + """Test stream_messages with tool_choice parameter for structured output.""" + mock_model = unittest.mock.MagicMock() + mock_model.stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "test-123", "name": "SampleModel"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "test", "age": 25}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + }, + ] + ) + + # Create a structured output tool and get its spec + structured_tool = StructuredOutputTool(SampleModel) + tool_spec = structured_tool.tool_spec + tool_choice = {"tool": {"name": "SampleModel"}} + + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=[{"role": "user", "content": [{"text": "Generate a test model"}]}], + tool_specs=[tool_spec], + tool_choice=tool_choice, + ) + + tru_events = await alist(stream) + + # Verify the model.stream was called with tool_choice + mock_model.stream.assert_called_with( + [{"role": "user", "content": [{"text": "Generate a test model"}]}], + [tool_spec], + 
"test prompt", + tool_choice=tool_choice, + ) + + # Verify we get the expected events + assert len(tru_events) > 0 + + # Find the stop event + stop_event = None + for event in tru_events: + if isinstance(event, dict) and "stop" in event: + stop_event = event + break + + assert stop_event is not None + assert stop_event["stop"][0] == "tool_use" + + # Ensure that we're getting typed events + non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] + assert non_typed_events == [] + + +@pytest.mark.asyncio +async def test_stream_messages_with_forced_structured_output(agenerator, alist): + """Test stream_messages with forced structured output tool.""" + mock_model = unittest.mock.MagicMock() + + # Simulate a response with tool use + mock_model.stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "SampleModel"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"name": "Alice", "age": 30}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + "metrics": {"latencyMs": 150}, + } + }, + ] + ) + + # Create a structured output tool and get its spec + structured_tool = StructuredOutputTool(SampleModel) + tool_spec = structured_tool.tool_spec + tool_choice = {"any": {}} + + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="Extract user information", + messages=[{"role": "user", "content": [{"text": "Alice is 30 years old"}]}], + tool_specs=[tool_spec], + tool_choice=tool_choice, + ) + + tru_events = await alist(stream) + + # Verify the model.stream was called with the forced tool choice + mock_model.stream.assert_called_with( + [{"role": "user", "content": [{"text": "Alice is 30 years old"}]}], + [tool_spec], + "Extract user information", + tool_choice=tool_choice, + ) + + assert len(tru_events) > 0 + + # Find the stop event and verify it contains the extracted data + stop_event = None + for event in tru_events: + if isinstance(event, dict) and "stop" in event: + stop_event = event + break + + assert stop_event is not None + stop_reason, message, usage, metrics = stop_event["stop"] + + assert stop_reason == "tool_use" + assert message["role"] == "assistant" + assert len(message["content"]) > 0 + + # Check that the tool use contains the expected data + tool_use_content = None + for content in message["content"]: + if "toolUse" in content: + tool_use_content = content["toolUse"] + break + + assert tool_use_content is not None + assert tool_use_content["name"] == "SampleModel" + assert tool_use_content["input"] == {"name": "Alice", "age": 30} diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 219561025..b8249f504 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -128,3 +128,48 @@ async def stream(self, messages, tool_specs=None, system_prompt=None): assert len(events) == 3 assert events[1]["contentBlockDelta"]["delta"]["text"] == "Legacy model works" + + +@pytest.mark.asyncio +async def test_stream_with_tool_choice_parameter(messages, tool_specs, system_prompt, alist): + """Test that model can accept tool_choice parameter.""" + + class ModernModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + async def structured_output(self, output_model, prompt=None, 
system_prompt=None, **kwargs): + yield {"output": output_model(name="test", age=20)} + + async def stream(self, messages, tool_specs=None, system_prompt=None, *, tool_choice=None, **kwargs): + yield {"messageStart": {"role": "assistant"}} + if tool_choice: + yield {"contentBlockDelta": {"delta": {"text": f"Tool choice: {tool_choice}"}}} + else: + yield {"contentBlockDelta": {"delta": {"text": "No tool choice"}}} + yield {"messageStop": {"stopReason": "end_turn"}} + + model = ModernModel() + + # Test with tool_choice="auto" + response = model.stream(messages, tool_specs, system_prompt, tool_choice="auto") + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Tool choice: auto" + + # Test with tool_choice="any" + response = model.stream(messages, tool_specs, system_prompt, tool_choice="any") + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Tool choice: any" + + # Test with tool_choice={"tool": {"name": "SampleModel"}} + response = model.stream(messages, tool_specs, system_prompt, tool_choice={"tool": {"name": "SampleModel"}}) + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "Tool choice: {'tool': {'name': 'SampleModel'}}" + + # Test without tool_choice + response = model.stream(messages, tool_specs, system_prompt) + events = await alist(response) + assert events[1]["contentBlockDelta"]["delta"]["text"] == "No tool choice" diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 4b62a8a9a..ce07ee4ce 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -3,6 +3,7 @@ from strands.hooks import BeforeToolCallEvent from strands.interrupt import Interrupt from strands.tools.executors import ConcurrentToolExecutor +from strands.tools.structured_output._structured_output_context import StructuredOutputContext from strands.types._events import ToolInterruptEvent, ToolResultEvent @@ -11,15 +12,22 @@ def executor(): return ConcurrentToolExecutor() +@pytest.fixture +def structured_output_context(): + return StructuredOutputContext(structured_output_model=None) + + @pytest.mark.asyncio async def test_concurrent_executor_execute( - executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist ): tool_uses = [ {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] - stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ @@ -35,7 +43,7 @@ async def test_concurrent_executor_execute( @pytest.mark.asyncio async def test_concurrent_executor_interrupt( - executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist ): interrupt = Interrupt( id="v1:before_tool_call:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", @@ -54,7 +62,9 @@ def interrupt_callback(event): {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, ] - stream = 
executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index a6c2c2277..10e3ad484 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,9 +1,20 @@ import pytest +from pydantic import BaseModel from strands.hooks import BeforeToolCallEvent from strands.interrupt import Interrupt +from strands.tools.decorator import tool from strands.tools.executors import SequentialToolExecutor +from strands.tools.structured_output._structured_output_context import StructuredOutputContext from strands.types._events import ToolInterruptEvent, ToolResultEvent +from strands.types.tools import ToolUse + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str + age: int @pytest.fixture @@ -11,6 +22,34 @@ def executor(): return SequentialToolExecutor() +@pytest.fixture +def structured_output_context(): + """Create a structured output context with SampleModel.""" + return StructuredOutputContext(structured_output_model=SampleModel) + + +@pytest.fixture +def capture_tool(): + """Create a tool that captures kwargs passed to it.""" + captured_kwargs = {} + + @tool(name="capture_tool") + def func(): + return "captured" + + # Override the stream method to capture kwargs + original_stream = func.stream + + async def capturing_stream(tool_use, invocation_state, **kwargs): + captured_kwargs.update(kwargs) + async for event in original_stream(tool_use, invocation_state, **kwargs): + yield event + + func.stream = capturing_stream + func.captured_kwargs = captured_kwargs + return func + + @pytest.mark.asyncio async def test_sequential_executor_execute( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist @@ -19,7 +58,10 @@ async def test_sequential_executor_execute( {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] - stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + structured_output_context = StructuredOutputContext(None) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) tru_events = await alist(stream) exp_events = [ @@ -53,7 +95,10 @@ def interrupt_callback(event): {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, ] - stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + structured_output_context = StructuredOutputContext(None) + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) tru_events = await alist(stream) exp_events = [ToolInterruptEvent(tool_uses[0], [interrupt])] @@ -62,3 +107,41 @@ def interrupt_callback(event): tru_results = tool_results exp_results = [] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_sequential_executor_passes_structured_output_context( + executor, + agent, + tool_results, + cycle_trace, + cycle_span, + invocation_state, + structured_output_context, + capture_tool, + alist, +): + """Test that sequential 
executor properly passes structured output context to tools.""" + # Register the capture tool + agent.tool_registry.register_tool(capture_tool) + + # Set up tool uses + tool_uses: list[ToolUse] = [ + {"name": "capture_tool", "toolUseId": "1", "input": {}}, + ] + + # Execute tools with structured output context + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context + ) + + # Collect events + events = await alist(stream) + + # Verify the structured_output_context was passed to the tool + assert "structured_output_context" in capture_tool.captured_kwargs + assert capture_tool.captured_kwargs["structured_output_context"] is structured_output_context + + # Verify event was generated + assert len(events) == 1 + assert events[0].tool_use_id == "1" diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py new file mode 100644 index 000000000..a7eb27ca5 --- /dev/null +++ b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -0,0 +1,245 @@ +"""Tests for StructuredOutputContext class.""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output.structured_output_tool import StructuredOutputTool + + +class SampleModel(BaseModel): + """Test Pydantic model for testing.""" + + name: str = Field(..., description="Name field") + age: int = Field(..., description="Age field", ge=0) + email: Optional[str] = Field(None, description="Optional email field") + + +class AnotherSampleModel(BaseModel): + """Another test Pydantic model.""" + + value: str + count: int + + +class TestStructuredOutputContext: + """Test suite for StructuredOutputContext.""" + + def test_initialization_with_structured_output_model(self): + """Test initialization with a structured output model.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + assert context.structured_output_model == SampleModel + assert isinstance(context.structured_output_tool, StructuredOutputTool) + assert context.expected_tool_name == "SampleModel" + assert context.results == {} + assert context.forced_mode is False + assert context.tool_choice is None + assert context.stop_loop is False + + def test_initialization_without_structured_output_model(self): + """Test initialization without a structured output model.""" + context = StructuredOutputContext(structured_output_model=None) + + assert context.structured_output_model is None + assert context.structured_output_tool is None + assert context.expected_tool_name is None + assert context.results == {} + assert context.forced_mode is False + assert context.tool_choice is None + assert context.stop_loop is False + + def test_is_enabled_property(self): + """Test the is_enabled property.""" + # Test with model + context_with_model = StructuredOutputContext(structured_output_model=SampleModel) + assert context_with_model.is_enabled is True + + # Test without model + context_without_model = StructuredOutputContext(structured_output_model=None) + assert context_without_model.is_enabled is False + + def test_store_result_and_get_result(self): + """Test storing and retrieving results.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Create test result + test_result = SampleModel(name="John Doe", age=30, 
email="john@example.com") + tool_use_id = "test_tool_use_123" + + # Store result + context.store_result(tool_use_id, test_result) + assert tool_use_id in context.results + assert context.results[tool_use_id] == test_result + + # Retrieve result + retrieved_result = context.get_result(tool_use_id) + assert retrieved_result == test_result + + # Test retrieving non-existent result + non_existent = context.get_result("non_existent_id") + assert non_existent is None + + def test_set_forced_mode_with_tool_choice(self): + """Test set_forced_mode with custom tool_choice.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + custom_tool_choice = {"specific": {"tool": "SampleModel"}} + context.set_forced_mode(tool_choice=custom_tool_choice) + + assert context.forced_mode is True + assert context.tool_choice == custom_tool_choice + + def test_set_forced_mode_without_tool_choice(self): + """Test set_forced_mode without tool_choice (default).""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + context.set_forced_mode() + + assert context.forced_mode is True + assert context.tool_choice == {"any": {}} + + def test_set_forced_mode_when_disabled(self): + """Test set_forced_mode when context is not enabled.""" + context = StructuredOutputContext(structured_output_model=None) + + # Should not change state when not enabled + context.set_forced_mode(tool_choice={"test": "value"}) + + assert context.forced_mode is False + assert context.tool_choice is None + + def test_has_structured_output_tool(self): + """Test has_structured_output_tool method.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Create tool uses with the expected tool + tool_uses_with_output = [ + {"name": "SampleModel", "toolUseId": "123", "input": {}}, + {"name": "OtherTool", "toolUseId": "456", "input": {}}, + ] + + # Should find the structured output tool + assert context.has_structured_output_tool(tool_uses_with_output) is True + + # Create tool uses without the expected tool + tool_uses_without_output = [ + {"name": "OtherTool", "toolUseId": "456", "input": {}}, + {"name": "AnotherTool", "toolUseId": "789", "input": {}}, + ] + + # Should not find the structured output tool + assert context.has_structured_output_tool(tool_uses_without_output) is False + + # Test with empty list + assert context.has_structured_output_tool([]) is False + + def test_has_structured_output_tool_when_disabled(self): + """Test has_structured_output_tool when no expected tool name.""" + context = StructuredOutputContext(structured_output_model=None) + + tool_uses = [ + {"name": "SampleModel", "toolUseId": "123", "input": {}}, + ] + + # Should return False when no expected tool name + assert context.has_structured_output_tool(tool_uses) is False + + def test_get_tool_spec(self): + """Test get_tool_spec method.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + tool_spec = context.get_tool_spec() + assert tool_spec is not None + assert isinstance(tool_spec, dict) + assert "name" in tool_spec + assert tool_spec["name"] == "SampleModel" + assert "description" in tool_spec + assert "inputSchema" in tool_spec + + def test_get_tool_spec_when_disabled(self): + """Test get_tool_spec when no structured output tool.""" + context = StructuredOutputContext(structured_output_model=None) + + tool_spec = context.get_tool_spec() + assert tool_spec is None + + def test_extract_result(self): + """Test extract_result method.""" + context = 
StructuredOutputContext(structured_output_model=SampleModel) + + # Store some results + result1 = SampleModel(name="Alice", age=25) + result2 = SampleModel(name="Bob", age=30) + context.store_result("tool_use_1", result1) + context.store_result("tool_use_2", result2) + + # Create tool uses with matching tool + tool_uses = [ + {"name": "SampleModel", "toolUseId": "tool_use_1", "input": {}}, + {"name": "OtherTool", "toolUseId": "tool_use_3", "input": {}}, + ] + + # Extract result should return and remove the first matching result + extracted = context.extract_result(tool_uses) + assert extracted == result1 + assert "tool_use_1" not in context.results + assert "tool_use_2" in context.results # Other result should remain + + def test_extract_result_no_matching_tool(self): + """Test extract_result when no matching tool in tool_uses.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + result = SampleModel(name="Alice", age=25) + context.store_result("tool_use_1", result) + + # Tool uses without the expected tool name + tool_uses = [ + {"name": "OtherTool", "toolUseId": "tool_use_1", "input": {}}, + ] + + # Should return None + extracted = context.extract_result(tool_uses) + assert extracted is None + assert "tool_use_1" in context.results # Result should remain + + def test_extract_result_no_stored_result(self): + """Test extract_result when no stored result for tool use.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Tool uses with matching tool but no stored result + tool_uses = [ + {"name": "SampleModel", "toolUseId": "tool_use_1", "input": {}}, + ] + + # Should return None + extracted = context.extract_result(tool_uses) + assert extracted is None + + def test_extract_result_multiple_matching_tools(self): + """Test extract_result with multiple matching tool uses.""" + context = StructuredOutputContext(structured_output_model=SampleModel) + + # Store multiple results + result1 = SampleModel(name="Alice", age=25) + result2 = SampleModel(name="Bob", age=30) + context.store_result("tool_use_1", result1) + context.store_result("tool_use_2", result2) + + # Multiple matching tool uses + tool_uses = [ + {"name": "SampleModel", "toolUseId": "tool_use_1", "input": {}}, + {"name": "SampleModel", "toolUseId": "tool_use_2", "input": {}}, + ] + + # Should extract the first matching result + extracted = context.extract_result(tool_uses) + assert extracted == result1 + assert "tool_use_1" not in context.results + assert "tool_use_2" in context.results + + # Extract again for the second result + extracted2 = context.extract_result(tool_uses) + assert extracted2 == result2 + assert "tool_use_2" not in context.results diff --git a/tests/strands/tools/structured_output/test_structured_output_tool.py b/tests/strands/tools/structured_output/test_structured_output_tool.py new file mode 100644 index 000000000..66f1d465d --- /dev/null +++ b/tests/strands/tools/structured_output/test_structured_output_tool.py @@ -0,0 +1,307 @@ +"""Tests for StructuredOutputTool class.""" + +from typing import List, Optional +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output.structured_output_tool import _TOOL_SPEC_CACHE, StructuredOutputTool +from strands.types._events import ToolResultEvent + + +class SimpleModel(BaseModel): + """Simple test model.""" + + name: str = Field(..., description="Name 
field") + value: int = Field(..., description="Value field") + + +class ComplexModel(BaseModel): + """Complex test model with nested structures.""" + + title: str = Field(..., description="Title field") + count: int = Field(..., ge=0, le=100, description="Count between 0 and 100") + tags: List[str] = Field(default_factory=list, description="List of tags") + metadata: Optional[dict] = Field(None, description="Optional metadata") + + +class ValidationTestModel(BaseModel): + """Model for testing validation.""" + + email: str = Field(..., pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$", description="Email address") + age: int = Field(..., ge=0, le=150, description="Age between 0 and 150") + status: str = Field(..., pattern="^(active|inactive|pending)$", description="Status") + + +class TestStructuredOutputTool: + """Test suite for StructuredOutputTool.""" + + def test_tool_initialization_with_simple_model(self): + """Test tool initialization with a simple Pydantic model.""" + tool = StructuredOutputTool(SimpleModel) + + assert tool.structured_output_model == SimpleModel + assert tool.tool_name == "SimpleModel" + assert tool.tool_type == "structured_output" + assert isinstance(tool.tool_spec, dict) + assert tool.tool_spec["name"] == "SimpleModel" + + def test_tool_initialization_with_complex_model(self): + """Test tool initialization with a complex Pydantic model.""" + tool = StructuredOutputTool(ComplexModel) + + assert tool.structured_output_model == ComplexModel + assert tool.tool_name == "ComplexModel" + assert tool.tool_type == "structured_output" + assert isinstance(tool.tool_spec, dict) + assert tool.tool_spec["name"] == "ComplexModel" + + def test_get_tool_spec_caching_mechanism(self): + """Test that tool specs are cached properly.""" + # Clear cache first + _TOOL_SPEC_CACHE.clear() + + # First call should create and cache the spec + tool1 = StructuredOutputTool(SimpleModel) + spec1 = tool1.tool_spec + + # Cache should now contain the spec + assert SimpleModel in _TOOL_SPEC_CACHE + + # Second call with same model should use cached version + tool2 = StructuredOutputTool(SimpleModel) + spec2 = tool2.tool_spec + + # Specs should be equal but not the same object (deepcopy is used) + assert spec1 == spec2 + assert spec1 is not spec2 + + # Cache should still have only one entry for SimpleModel + assert len([k for k in _TOOL_SPEC_CACHE if k == SimpleModel]) == 1 + + def test_tool_name_property(self): + """Test the tool_name property.""" + tool = StructuredOutputTool(SimpleModel) + assert tool.tool_name == "SimpleModel" + + tool2 = StructuredOutputTool(ComplexModel) + assert tool2.tool_name == "ComplexModel" + + def test_tool_spec_property(self): + """Test the tool_spec property.""" + tool = StructuredOutputTool(SimpleModel) + spec = tool.tool_spec + + assert isinstance(spec, dict) + assert "name" in spec + assert "description" in spec + assert "inputSchema" in spec + assert spec["name"] == "SimpleModel" + + # Check that description includes the important message + assert "IMPORTANT: This StructuredOutputTool should only be invoked" in spec["description"] + + def test_tool_type_property(self): + """Test that tool_type property returns 'structured_output'.""" + tool = StructuredOutputTool(SimpleModel) + assert tool.tool_type == "structured_output" + + def test_structured_output_model_property(self): + """Test the structured_output_model property.""" + tool = StructuredOutputTool(SimpleModel) + assert tool.structured_output_model == SimpleModel + + tool2 = StructuredOutputTool(ComplexModel) + assert 
tool2.structured_output_model == ComplexModel + + @pytest.mark.asyncio + async def test_stream_with_valid_input(self): + """Test stream method with valid input.""" + tool = StructuredOutputTool(SimpleModel) + context = StructuredOutputContext(structured_output_model=SimpleModel) + + tool_use = {"name": "SimpleModel", "toolUseId": "test_123", "input": {"name": "Test Name", "value": 42}} + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Should have one ToolResultEvent + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + + # Check the result + result = events[0].tool_result + assert result["toolUseId"] == "test_123" + assert result["status"] == "success" + assert "Successfully validated SimpleModel" in result["content"][0]["text"] + + # Check that result was stored in context + stored_result = context.get_result("test_123") + assert stored_result is not None + assert stored_result.name == "Test Name" + assert stored_result.value == 42 + + @pytest.mark.asyncio + async def test_stream_with_missing_fields(self): + """Test stream method with missing required fields.""" + tool = StructuredOutputTool(SimpleModel) + context = StructuredOutputContext(structured_output_model=SimpleModel) + + tool_use = { + "name": "SimpleModel", + "toolUseId": "test_789", + "input": { + "name": "Test Name" + # Missing required 'value' field + }, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Should have one ToolResultEvent with error + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + + # Check the error result + result = events[0].tool_result + assert result["toolUseId"] == "test_789" + assert result["status"] == "error" + + error_text = result["content"][0]["text"] + assert "Validation failed for SimpleModel" in error_text + assert "Field 'value'" in error_text or "field required" in error_text.lower() + + @pytest.mark.asyncio + async def test_stream_with_unexpected_exception(self): + """Test stream method with unexpected exceptions.""" + tool = StructuredOutputTool(SimpleModel) + context = MagicMock() + + # Mock the context to raise an unexpected exception + context.store_result.side_effect = RuntimeError("Unexpected error") + + tool_use = {"name": "SimpleModel", "toolUseId": "test_error", "input": {"name": "Test", "value": 1}} + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Should have one ToolResultEvent with error + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + + # Check the error result + result = events[0].tool_result + assert result["toolUseId"] == "test_error" + assert result["status"] == "error" + + error_text = result["content"][0]["text"] + assert "Unexpected error validating SimpleModel" in error_text + assert "Unexpected error" in error_text + + @pytest.mark.asyncio + async def test_error_message_formatting_single_error(self): + """Test error message formatting with a single validation error.""" + tool = StructuredOutputTool(SimpleModel) + context = StructuredOutputContext(structured_output_model=SimpleModel) + + tool_use = { + "name": "SimpleModel", + "toolUseId": "test_format_1", + "input": { + "name": "Test", + "value": "not an integer", # Wrong type + }, + } + + # Call stream method + events = [] + async for event in 
tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + result = events[0].tool_result + error_text = result["content"][0]["text"] + + # Check error formatting + assert "Validation failed for SimpleModel" in error_text + assert "Please fix the following errors:" in error_text + assert "- Field 'value':" in error_text + + @pytest.mark.asyncio + async def test_error_message_formatting_multiple_errors(self): + """Test error message formatting with multiple validation errors.""" + tool = StructuredOutputTool(ValidationTestModel) + context = StructuredOutputContext(structured_output_model=ValidationTestModel) + + tool_use = { + "name": "ValidationTestModel", + "toolUseId": "test_format_2", + "input": {"email": "bad-email", "age": -5, "status": "invalid"}, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + result = events[0].tool_result + error_text = result["content"][0]["text"] + + # Check that multiple errors are formatted properly + assert "Validation failed for ValidationTestModel" in error_text + assert "Please fix the following errors:" in error_text + # Should have multiple error lines + error_lines = [line for line in error_text.split("\n") if line.startswith("- Field")] + assert len(error_lines) >= 2 # At least 2 validation errors + + @pytest.mark.asyncio + async def test_stream_with_complex_nested_data(self): + """Test stream method with complex nested data.""" + tool = StructuredOutputTool(ComplexModel) + context = StructuredOutputContext(structured_output_model=ComplexModel) + + tool_use = { + "name": "ComplexModel", + "toolUseId": "test_complex", + "input": { + "title": "Test Title", + "count": 50, + "tags": ["tag1", "tag2", "tag3"], + "metadata": {"key1": "value1", "key2": 123}, + }, + } + + # Call stream method + events = [] + async for event in tool.stream(tool_use, {}, structured_output_context=context): + events.append(event) + + # Check success + result = events[0].tool_result + assert result["status"] == "success" + + # Check stored result + stored_result = context.get_result("test_complex") + assert stored_result.title == "Test Title" + assert stored_result.count == 50 + assert stored_result.tags == ["tag1", "tag2", "tag3"] + assert stored_result.metadata == {"key1": "value1", "key2": 123} + + def test_tool_spec_description_modification(self): + """Test that tool spec description is properly modified.""" + tool = StructuredOutputTool(SimpleModel) + spec = tool.tool_spec + + # Check that the IMPORTANT message is prepended + assert spec["description"].startswith("IMPORTANT: This StructuredOutputTool should only be invoked") + assert "last and final tool" in spec["description"] + assert "" in spec["description"] + assert "" in spec["description"] diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py new file mode 100644 index 000000000..d64cabb83 --- /dev/null +++ b/tests/strands/types/test__events.py @@ -0,0 +1,467 @@ +"""Tests for event types in the strands.types._events module.""" + +from unittest.mock import MagicMock, Mock + +from pydantic import BaseModel + +from strands.telemetry import EventLoopMetrics +from strands.types._events import ( + AgentResultEvent, + CitationStreamEvent, + EventLoopStopEvent, + EventLoopThrottleEvent, + ForceStopEvent, + InitEventLoopEvent, + ModelMessageEvent, + ModelStopReason, + ModelStreamChunkEvent, + ModelStreamEvent, + ReasoningRedactedContentStreamEvent, + 
ReasoningSignatureStreamEvent, + ReasoningTextStreamEvent, + StartEvent, + StartEventLoopEvent, + StructuredOutputEvent, + TextStreamEvent, + ToolResultEvent, + ToolResultMessageEvent, + ToolStreamEvent, + ToolUseStreamEvent, + TypedEvent, +) +from strands.types.citations import Citation +from strands.types.content import Message +from strands.types.event_loop import Metrics, StopReason, Usage +from strands.types.streaming import ContentBlockDelta, StreamEvent +from strands.types.tools import ToolResult, ToolUse + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str + value: int + + +class TestTypedEvent: + """Tests for the base TypedEvent class.""" + + def test_initialization_with_data(self): + """Test TypedEvent initialization with data.""" + data = {"key": "value", "number": 42} + event = TypedEvent(data) + assert event["key"] == "value" + assert event["number"] == 42 + + def test_initialization_without_data(self): + """Test TypedEvent initialization without data.""" + event = TypedEvent() + assert len(event) == 0 + + def test_is_callback_event_default(self): + """Test that is_callback_event returns True by default.""" + event = TypedEvent() + assert event.is_callback_event is True + + def test_as_dict(self): + """Test as_dict method returns dictionary representation.""" + data = {"test": "data", "nested": {"key": "value"}} + event = TypedEvent(data) + result = event.as_dict() + assert result == data + assert isinstance(result, dict) + + def test_prepare_default_implementation(self): + """Test prepare method default implementation does nothing.""" + event = TypedEvent({"initial": "data"}) + invocation_state = {"state": "value"} + event.prepare(invocation_state) + # Default implementation does nothing + assert event == {"initial": "data"} + + +class TestInitEventLoopEvent: + """Tests for InitEventLoopEvent.""" + + def test_initialization(self): + """Test InitEventLoopEvent initialization.""" + event = InitEventLoopEvent() + assert event["init_event_loop"] is True + + def test_prepare_updates_with_invocation_state(self): + """Test prepare method updates event with invocation state.""" + event = InitEventLoopEvent() + invocation_state = {"request_id": "123", "session": "abc"} + event.prepare(invocation_state) + assert event["request_id"] == "123" + assert event["session"] == "abc" + assert event["init_event_loop"] is True + + +class TestStartEvent: + """Tests for StartEvent (deprecated).""" + + def test_initialization(self): + """Test StartEvent initialization.""" + event = StartEvent() + assert event["start"] is True + + +class TestStartEventLoopEvent: + """Tests for StartEventLoopEvent.""" + + def test_initialization(self): + """Test StartEventLoopEvent initialization.""" + event = StartEventLoopEvent() + assert event["start_event_loop"] is True + + +class TestModelStreamChunkEvent: + """Tests for ModelStreamChunkEvent.""" + + def test_initialization_with_stream_event(self): + """Test ModelStreamChunkEvent initialization with StreamEvent.""" + stream_event = Mock(spec=StreamEvent) + event = ModelStreamChunkEvent(stream_event) + assert event["event"] == stream_event + assert event.chunk == stream_event + + +class TestModelStreamEvent: + """Tests for ModelStreamEvent.""" + + def test_initialization_with_delta_data(self): + """Test ModelStreamEvent initialization with delta data.""" + delta_data = {"type": "text", "content": "hello"} + event = ModelStreamEvent(delta_data) + assert event["type"] == "text" + assert event["content"] == "hello" + + def 
test_is_callback_event_empty(self): + """Test is_callback_event returns False when empty.""" + event = ModelStreamEvent({}) + assert event.is_callback_event is False + + def test_is_callback_event_non_empty(self): + """Test is_callback_event returns True when non-empty.""" + event = ModelStreamEvent({"data": "value"}) + assert event.is_callback_event is True + + def test_prepare_with_delta(self): + """Test prepare method updates when delta is present.""" + event = ModelStreamEvent({"delta": "content", "other": "data"}) + invocation_state = {"request_id": "456"} + event.prepare(invocation_state) + assert event["request_id"] == "456" + assert event["delta"] == "content" + + def test_prepare_without_delta(self): + """Test prepare method does nothing when delta is not present.""" + event = ModelStreamEvent({"other": "data"}) + invocation_state = {"request_id": "456"} + event.prepare(invocation_state) + assert "request_id" not in event + + +class TestToolUseStreamEvent: + """Tests for ToolUseStreamEvent.""" + + def test_initialization(self): + """Test ToolUseStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + current_tool_use = {"toolUseId": "123", "name": "calculator"} + event = ToolUseStreamEvent(delta, current_tool_use) + assert event["delta"] == delta + assert event["current_tool_use"] == current_tool_use + + +class TestTextStreamEvent: + """Tests for TextStreamEvent.""" + + def test_initialization(self): + """Test TextStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + text = "Hello, world!" + event = TextStreamEvent(delta, text) + assert event["data"] == text + assert event["delta"] == delta + + +class TestCitationStreamEvent: + """Tests for CitationStreamEvent.""" + + def test_initialization(self): + """Test CitationStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + citation = Mock(spec=Citation) + event = CitationStreamEvent(delta, citation) + assert event["callback"]["citation"] == citation + assert event["callback"]["delta"] == delta + + +class TestReasoningTextStreamEvent: + """Tests for ReasoningTextStreamEvent.""" + + def test_initialization_with_reasoning_text(self): + """Test ReasoningTextStreamEvent initialization with text.""" + delta = Mock(spec=ContentBlockDelta) + reasoning_text = "Thinking about the problem..." 
+ event = ReasoningTextStreamEvent(delta, reasoning_text) + assert event["reasoningText"] == reasoning_text + assert event["delta"] == delta + assert event["reasoning"] is True + + def test_initialization_with_none(self): + """Test ReasoningTextStreamEvent initialization with None.""" + delta = Mock(spec=ContentBlockDelta) + event = ReasoningTextStreamEvent(delta, None) + assert event["reasoningText"] is None + assert event["reasoning"] is True + + +class TestReasoningRedactedContentStreamEvent: + """Tests for ReasoningRedactedContentStreamEvent.""" + + def test_initialization_with_redacted_content(self): + """Test ReasoningRedactedContentStreamEvent initialization with content.""" + delta = Mock(spec=ContentBlockDelta) + redacted_content = b"[REDACTED]" + event = ReasoningRedactedContentStreamEvent(delta, redacted_content) + assert event["reasoningRedactedContent"] == redacted_content + assert event["delta"] == delta + assert event["reasoning"] is True + + def test_initialization_with_none(self): + """Test ReasoningRedactedContentStreamEvent initialization with None.""" + delta = Mock(spec=ContentBlockDelta) + event = ReasoningRedactedContentStreamEvent(delta, None) + assert event["reasoningRedactedContent"] is None + assert event["reasoning"] is True + + +class TestReasoningSignatureStreamEvent: + """Tests for ReasoningSignatureStreamEvent.""" + + def test_initialization(self): + """Test ReasoningSignatureStreamEvent initialization.""" + delta = Mock(spec=ContentBlockDelta) + signature = "signature_xyz123" + event = ReasoningSignatureStreamEvent(delta, signature) + assert event["reasoning_signature"] == signature + assert event["delta"] == delta + assert event["reasoning"] is True + + +class TestModelStopReason: + """Tests for ModelStopReason.""" + + def test_initialization(self): + """Test ModelStopReason initialization.""" + stop_reason = Mock(spec=StopReason) + message = Mock(spec=Message) + usage = Mock(spec=Usage) + metrics = Mock(spec=Metrics) + + event = ModelStopReason(stop_reason, message, usage, metrics) + assert event["stop"] == (stop_reason, message, usage, metrics) + assert event.is_callback_event is False + + +class TestEventLoopStopEvent: + """Tests for EventLoopStopEvent.""" + + def test_initialization_without_structured_output(self): + """Test EventLoopStopEvent initialization without structured output.""" + stop_reason = Mock(spec=StopReason) + message = Mock(spec=Message) + metrics = Mock(spec=EventLoopMetrics) + request_state = {"state": "final"} + + event = EventLoopStopEvent(stop_reason, message, metrics, request_state) + assert event["stop"] == (stop_reason, message, metrics, request_state, None, None) + assert event.is_callback_event is False + + def test_initialization_with_structured_output(self): + """Test EventLoopStopEvent initialization with structured output.""" + stop_reason = Mock(spec=StopReason) + message = Mock(spec=Message) + metrics = Mock(spec=EventLoopMetrics) + request_state = {"state": "final"} + structured_output = SampleModel(name="test", value=42) + + event = EventLoopStopEvent(stop_reason, message, metrics, request_state, structured_output) + assert event["stop"] == (stop_reason, message, metrics, request_state, structured_output, None) + assert event.is_callback_event is False + + +class TestStructuredOutputEvent: + """Tests for StructuredOutputEvent.""" + + def test_initialization(self): + """Test StructuredOutputEvent initialization.""" + structured_output = SampleModel(name="output", value=100) + event = 
StructuredOutputEvent(structured_output) + assert event["structured_output"] == structured_output + assert isinstance(event["structured_output"], SampleModel) + + +class TestEventLoopThrottleEvent: + """Tests for EventLoopThrottleEvent.""" + + def test_initialization(self): + """Test EventLoopThrottleEvent initialization.""" + delay = 5 + event = EventLoopThrottleEvent(delay) + assert event["event_loop_throttled_delay"] == 5 + + def test_prepare_updates_with_invocation_state(self): + """Test prepare method updates event with invocation state.""" + event = EventLoopThrottleEvent(10) + invocation_state = {"request_id": "throttle_123"} + event.prepare(invocation_state) + assert event["request_id"] == "throttle_123" + assert event["event_loop_throttled_delay"] == 10 + + +class TestToolResultEvent: + """Tests for ToolResultEvent.""" + + def test_initialization(self): + """Test ToolResultEvent initialization.""" + tool_result: ToolResult = { + "toolUseId": "tool_123", + "content": [{"text": "Result"}], + "isError": False, + } + event = ToolResultEvent(tool_result) + assert event["tool_result"] == tool_result + assert event.tool_use_id == "tool_123" + assert event.tool_result == tool_result + assert event.is_callback_event is False + + def test_tool_use_id_property(self): + """Test tool_use_id property returns correct ID.""" + tool_result: ToolResult = { + "toolUseId": "unique_id_456", + "content": [], + } + event = ToolResultEvent(tool_result) + assert event.tool_use_id == "unique_id_456" + + +class TestToolStreamEvent: + """Tests for ToolStreamEvent.""" + + def test_initialization(self): + """Test ToolStreamEvent initialization.""" + tool_use: ToolUse = { + "toolUseId": "stream_123", + "name": "streaming_tool", + "input": {}, + } + tool_stream_data = {"progress": 50, "status": "processing"} + event = ToolStreamEvent(tool_use, tool_stream_data) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == tool_stream_data + assert event.tool_use_id == "stream_123" + + def test_tool_use_id_property(self): + """Test tool_use_id property returns correct ID.""" + tool_use: ToolUse = { + "toolUseId": "another_stream_456", + "name": "tool", + "input": {}, + } + event = ToolStreamEvent(tool_use, {}) + assert event.tool_use_id == "another_stream_456" + + +class TestModelMessageEvent: + """Tests for ModelMessageEvent.""" + + def test_initialization(self): + """Test ModelMessageEvent initialization.""" + message = Mock(spec=Message) + event = ModelMessageEvent(message) + assert event["message"] == message + + +class TestToolResultMessageEvent: + """Tests for ToolResultMessageEvent.""" + + def test_initialization(self): + """Test ToolResultMessageEvent initialization.""" + message = {"role": "tool", "content": "Tool result message"} + event = ToolResultMessageEvent(message) + assert event["message"] == message + + +class TestForceStopEvent: + """Tests for ForceStopEvent.""" + + def test_initialization_with_string_reason(self): + """Test ForceStopEvent initialization with string reason.""" + reason = "User requested stop" + event = ForceStopEvent(reason) + assert event["force_stop"] is True + assert event["force_stop_reason"] == "User requested stop" + + def test_initialization_with_exception(self): + """Test ForceStopEvent initialization with exception.""" + exception = ValueError("Something went wrong") + event = ForceStopEvent(exception) + assert event["force_stop"] is True + assert event["force_stop_reason"] == "Something went wrong" + + +class 
TestAgentResultEvent: + """Tests for AgentResultEvent.""" + + def test_initialization(self): + """Test AgentResultEvent initialization.""" + # Mock the AgentResult + agent_result = MagicMock() + agent_result.messages = [] + agent_result.stop_reason = "max_tokens" + + event = AgentResultEvent(agent_result) + assert event["result"] == agent_result + + +class TestEventSerialization: + """Tests for event serialization and conversion.""" + + def test_typed_event_serialization(self): + """Test that TypedEvent can be serialized to dict.""" + event = TypedEvent({"key": "value", "nested": {"data": 123}}) + serialized = event.as_dict() + assert serialized == {"key": "value", "nested": {"data": 123}} + + def test_complex_event_serialization(self): + """Test complex event serialization.""" + delta = Mock(spec=ContentBlockDelta) + delta.to_dict = Mock(return_value={"type": "delta"}) + + event = TextStreamEvent(delta, "Hello") + # The event should be serializable as a dict + assert isinstance(event.as_dict(), dict) + assert event["data"] == "Hello" + + def test_event_inheritance(self): + """Test that all events inherit from TypedEvent.""" + events = [ + InitEventLoopEvent(), + StartEvent(), + StartEventLoopEvent(), + StructuredOutputEvent(SampleModel(name="test", value=1)), + EventLoopThrottleEvent(5), + ForceStopEvent("test"), + ] + + for event in events: + assert isinstance(event, TypedEvent) + assert isinstance(event, dict) + assert hasattr(event, "is_callback_event") + assert hasattr(event, "as_dict") + assert hasattr(event, "prepare") diff --git a/tests/strands/types/test_exceptions.py b/tests/strands/types/test_exceptions.py new file mode 100644 index 000000000..29f68a7d0 --- /dev/null +++ b/tests/strands/types/test_exceptions.py @@ -0,0 +1,387 @@ +"""Tests for exception types in the strands.types.exceptions module.""" + +import pytest + +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + MCPClientInitializationError, + ModelThrottledException, + SessionException, + StructuredOutputException, +) + + +class TestEventLoopException: + """Tests for EventLoopException class.""" + + def test_initialization_with_request_state(self): + """Test EventLoopException initialization with request state.""" + original_exception = ValueError("Original error") + request_state = {"session_id": "123", "user": "test_user"} + + exception = EventLoopException(original_exception, request_state) + + assert exception.original_exception == original_exception + assert exception.request_state == request_state + assert str(exception) == "Original error" + + def test_initialization_without_request_state(self): + """Test EventLoopException initialization without request state.""" + original_exception = RuntimeError("Runtime error") + + exception = EventLoopException(original_exception) + + assert exception.original_exception == original_exception + assert exception.request_state == {} + assert str(exception) == "Runtime error" + + def test_initialization_with_none_request_state(self): + """Test EventLoopException initialization with None request state.""" + original_exception = TypeError("Type error") + + exception = EventLoopException(original_exception, None) + + assert exception.original_exception == original_exception + assert exception.request_state == {} + assert str(exception) == "Type error" + + def test_inheritance(self): + """Test that EventLoopException inherits from Exception.""" + original_exception = Exception("Test") + exception = 
EventLoopException(original_exception) + + assert isinstance(exception, Exception) + assert issubclass(EventLoopException, Exception) + + def test_exception_message_from_original(self): + """Test that exception message comes from original exception.""" + original_exception = ValueError("Custom error message") + exception = EventLoopException(original_exception) + + assert str(exception) == "Custom error message" + assert exception.args[0] == "Custom error message" + + +class TestMaxTokensReachedException: + """Tests for MaxTokensReachedException class.""" + + def test_initialization_with_message(self): + """Test MaxTokensReachedException initialization with message.""" + message = "Maximum tokens limit of 4096 reached" + exception = MaxTokensReachedException(message) + + assert str(exception) == message + assert exception.args[0] == message + + def test_inheritance(self): + """Test that MaxTokensReachedException inherits from Exception.""" + exception = MaxTokensReachedException("Test message") + + assert isinstance(exception, Exception) + assert issubclass(MaxTokensReachedException, Exception) + + def test_exception_with_detailed_message(self): + """Test exception with detailed message about token limits.""" + message = ( + "Model reached maximum token limit of 8192 tokens. " + "Consider reducing input size or increasing max_tokens parameter." + ) + exception = MaxTokensReachedException(message) + + assert str(exception) == message + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(MaxTokensReachedException) as exc_info: + raise MaxTokensReachedException("Token limit exceeded") + + assert str(exc_info.value) == "Token limit exceeded" + + +class TestContextWindowOverflowException: + """Tests for ContextWindowOverflowException class.""" + + def test_initialization(self): + """Test ContextWindowOverflowException initialization.""" + exception = ContextWindowOverflowException() + + assert isinstance(exception, Exception) + assert str(exception) == "" + + def test_initialization_with_message(self): + """Test ContextWindowOverflowException with custom message.""" + exception = ContextWindowOverflowException("Context window exceeded 100k tokens") + + assert str(exception) == "Context window exceeded 100k tokens" + + def test_inheritance(self): + """Test that ContextWindowOverflowException inherits from Exception.""" + exception = ContextWindowOverflowException() + + assert isinstance(exception, Exception) + assert issubclass(ContextWindowOverflowException, Exception) + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(ContextWindowOverflowException) as exc_info: + raise ContextWindowOverflowException("Input too large for model") + + assert str(exc_info.value) == "Input too large for model" + + +class TestMCPClientInitializationError: + """Tests for MCPClientInitializationError class.""" + + def test_initialization(self): + """Test MCPClientInitializationError initialization.""" + exception = MCPClientInitializationError() + + assert isinstance(exception, Exception) + assert str(exception) == "" + + def test_initialization_with_message(self): + """Test MCPClientInitializationError with custom message.""" + exception = MCPClientInitializationError("Failed to connect to MCP server") + + assert str(exception) == "Failed to connect to MCP server" + + def test_inheritance(self): + """Test that MCPClientInitializationError inherits from 
Exception.""" + exception = MCPClientInitializationError() + + assert isinstance(exception, Exception) + assert issubclass(MCPClientInitializationError, Exception) + + def test_exception_with_detailed_error(self): + """Test exception with detailed initialization error.""" + message = "MCP server initialization failed: Connection refused on port 8080" + exception = MCPClientInitializationError(message) + + assert str(exception) == message + + +class TestModelThrottledException: + """Tests for ModelThrottledException class.""" + + def test_initialization_with_message(self): + """Test ModelThrottledException initialization with message.""" + message = "Rate limit exceeded. Please retry after 60 seconds." + exception = ModelThrottledException(message) + + assert exception.message == message + assert str(exception) == message + assert exception.args[0] == message + + def test_inheritance(self): + """Test that ModelThrottledException inherits from Exception.""" + exception = ModelThrottledException("Throttled") + + assert isinstance(exception, Exception) + assert issubclass(ModelThrottledException, Exception) + + def test_message_property(self): + """Test that message property is accessible.""" + message = "API rate limit: 10 requests per minute" + exception = ModelThrottledException(message) + + assert exception.message == message + assert hasattr(exception, "message") + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(ModelThrottledException) as exc_info: + raise ModelThrottledException("Service temporarily unavailable") + + assert exc_info.value.message == "Service temporarily unavailable" + assert str(exc_info.value) == "Service temporarily unavailable" + + +class TestSessionException: + """Tests for SessionException class.""" + + def test_initialization(self): + """Test SessionException initialization.""" + exception = SessionException() + + assert isinstance(exception, Exception) + assert str(exception) == "" + + def test_initialization_with_message(self): + """Test SessionException with custom message.""" + exception = SessionException("Session expired") + + assert str(exception) == "Session expired" + + def test_inheritance(self): + """Test that SessionException inherits from Exception.""" + exception = SessionException() + + assert isinstance(exception, Exception) + assert issubclass(SessionException, Exception) + + def test_exception_with_detailed_message(self): + """Test exception with detailed session error.""" + message = "Failed to restore session: Invalid session ID or session has expired" + exception = SessionException(message) + + assert str(exception) == message + + +class TestStructuredOutputException: + """Tests for StructuredOutputException class.""" + + def test_initialization_with_message(self): + """Test StructuredOutputException initialization with message.""" + message = "Failed to validate structured output after 3 attempts" + exception = StructuredOutputException(message) + + assert exception.message == message + assert str(exception) == message + assert exception.args[0] == message + + def test_inheritance(self): + """Test that StructuredOutputException inherits from Exception.""" + exception = StructuredOutputException("Validation failed") + + assert isinstance(exception, Exception) + assert issubclass(StructuredOutputException, Exception) + + def test_message_property(self): + """Test that message property is accessible.""" + message = "Pydantic validation error: field 'name' is 
required" + exception = StructuredOutputException(message) + + assert exception.message == message + assert hasattr(exception, "message") + + def test_exception_with_validation_details(self): + """Test exception with detailed validation error message.""" + message = ( + "Structured output validation failed:\n" + "- Field 'age' must be a positive integer\n" + "- Field 'email' must be a valid email address" + ) + exception = StructuredOutputException(message) + + assert exception.message == message + assert str(exception) == message + + def test_exception_raised_properly(self): + """Test that exception can be raised and caught properly.""" + with pytest.raises(StructuredOutputException) as exc_info: + raise StructuredOutputException("Invalid output format") + + assert exc_info.value.message == "Invalid output format" + assert str(exc_info.value) == "Invalid output format" + + +class TestExceptionInheritance: + """Tests for verifying exception inheritance hierarchy.""" + + def test_all_exceptions_inherit_from_exception(self): + """Test that all custom exceptions inherit from Exception.""" + exception_classes = [ + EventLoopException, + MaxTokensReachedException, + ContextWindowOverflowException, + MCPClientInitializationError, + ModelThrottledException, + SessionException, + StructuredOutputException, + ] + + for exc_class in exception_classes: + assert issubclass(exc_class, Exception), f"{exc_class.__name__} should inherit from Exception" + + def test_exception_instances_are_exceptions(self): + """Test that all exception instances are instances of Exception.""" + exceptions = [ + EventLoopException(ValueError("test")), + MaxTokensReachedException("test"), + ContextWindowOverflowException("test"), + MCPClientInitializationError("test"), + ModelThrottledException("test"), + SessionException("test"), + StructuredOutputException("test"), + ] + + for exception in exceptions: + assert isinstance(exception, Exception), f"{type(exception).__name__} instance should be an Exception" + + def test_exceptions_can_be_caught_as_exception(self): + """Test that all custom exceptions can be caught as generic Exception.""" + exceptions_to_raise = [ + (EventLoopException, ValueError("test"), None), + (MaxTokensReachedException, "test", None), + (ContextWindowOverflowException, "test", None), + (MCPClientInitializationError, "test", None), + (ModelThrottledException, "test", None), + (SessionException, "test", None), + (StructuredOutputException, "test", None), + ] + + for exc_class, *args in exceptions_to_raise: + try: + if exc_class == EventLoopException: + raise exc_class(*args) + else: + raise exc_class(args[0]) + except Exception as e: + assert isinstance(e, exc_class) + assert isinstance(e, Exception) + + +class TestExceptionMessages: + """Tests for exception messages and representations.""" + + def test_exception_str_representations(self): + """Test string representations of all exceptions.""" + exceptions = [ + (EventLoopException(ValueError("event loop error")), "event loop error"), + (MaxTokensReachedException("max tokens"), "max tokens"), + (ContextWindowOverflowException("overflow"), "overflow"), + (MCPClientInitializationError("init error"), "init error"), + (ModelThrottledException("throttled"), "throttled"), + (SessionException("session error"), "session error"), + (StructuredOutputException("output error"), "output error"), + ] + + for exception, expected_str in exceptions: + assert str(exception) == expected_str + + def test_exception_repr_contains_class_name(self): + """Test that repr contains 
the exception class name.""" + exceptions = [ + EventLoopException(ValueError("test")), + MaxTokensReachedException("test"), + ContextWindowOverflowException("test"), + MCPClientInitializationError("test"), + ModelThrottledException("test"), + SessionException("test"), + StructuredOutputException("test"), + ] + + for exception in exceptions: + class_name = type(exception).__name__ + assert class_name in repr(exception) + + def test_exceptions_with_custom_properties(self): + """Test exceptions with custom properties maintain those properties.""" + # EventLoopException with properties + event_loop_exc = EventLoopException(ValueError("test"), {"key": "value"}) + assert hasattr(event_loop_exc, "original_exception") + assert hasattr(event_loop_exc, "request_state") + assert event_loop_exc.original_exception.args[0] == "test" + assert event_loop_exc.request_state == {"key": "value"} + + # ModelThrottledException with message property + throttled_exc = ModelThrottledException("throttle message") + assert hasattr(throttled_exc, "message") + assert throttled_exc.message == "throttle message" + + # StructuredOutputException with message property + structured_exc = StructuredOutputException("validation message") + assert hasattr(structured_exc, "message") + assert structured_exc.message == "validation message" diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index 4df6dd69b..36c21fb7f 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -58,3 +58,20 @@ class Weather(BaseModel): result = agent.structured_output(Weather, "How are you?") assert isinstance(result, Weather) + + +def test_structured_output_is_forced_when_provided_in_agent_invocation(skip_for, model): + """Tests that structured_output is always forced to return a value even if model doesn't have any information.""" + + class UserProfile(BaseModel): + """Basic user profile model.""" + + name: str + age: int + occupation: str + + agent = Agent() + result = agent("Create a profile for John who is a 25 year old dentist", structured_output_model=UserProfile) + assert result.structured_output.name == "John" + assert result.structured_output.age == 25 + assert result.structured_output.occupation == "dentist" diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py new file mode 100644 index 000000000..188f57777 --- /dev/null +++ b/tests_integ/test_structured_output_agent_loop.py @@ -0,0 +1,330 @@ +""" +Comprehensive integration tests for structured output passed into the agent functionality. 
+""" + +from typing import List, Optional + +import pytest +from pydantic import BaseModel, Field, field_validator + +from strands import Agent +from strands.tools import tool + +# ========== Pydantic Models from notebook ========== + + +class MathResult(BaseModel): + """Math operation result.""" + + operation: str = Field(description="the performed operation") + result: int = Field(description="the result of the operation") + + +class UserProfile(BaseModel): + """Basic user profile model.""" + + name: str + age: int + occupation: str + active: bool = True + + +class Address(BaseModel): + """Address information.""" + + street: str + city: str + state: str + zip_code: str + + +class Contact(BaseModel): + """Contact information.""" + + email: str + phone: Optional[str] = None + preferred_method: str = "email" + + +class Employee(BaseModel): + """Complex nested employee model.""" + + name: str + employee_id: int + department: str + address: Address + contact: Contact + skills: List[str] + hire_date: str + salary_range: str + + +class ProductReview(BaseModel): + """Product review analysis.""" + + product_name: str + rating: int = Field(ge=1, le=5, description="Rating from 1-5 stars") + sentiment: str = Field(pattern="^(positive|negative|neutral)$") + key_points: List[str] + would_recommend: bool + + +class WeatherForecast(BaseModel): + """Weather forecast data.""" + + location: str + temperature: int + condition: str + humidity: int + wind_speed: int + forecast_date: str + + +class TaskList(BaseModel): + """Task management structure.""" + + project_name: str + tasks: List[str] + priority: str = Field(pattern="^(high|medium|low)$") + due_date: str + estimated_hours: int + + +class Person(BaseModel): + """A person's basic information.""" + + name: str = Field(description="Full name") + age: int = Field(description="Age in years", ge=0, le=150) + + +class Company(BaseModel): + """A company or organization.""" + + name: str = Field(description="Company name") + address: Address = Field(description="Company address") + employees: List[Person] = Field(description="list of persons") + + +class Task(BaseModel): + """A task or todo item.""" + + title: str = Field(description="Task title") + description: str = Field(description="Detailed description") + priority: str = Field(description="Priority level: low, medium, high") + completed: bool = Field(description="Whether task is completed", default=False) + + +class NameWithValidation(BaseModel): + """Name model with validation that forces retry.""" + + first_name: str + + @field_validator("first_name") + @classmethod + def validate_first_name(cls, value: str) -> str: + if not value.endswith("abc"): + raise ValueError("You must append 'abc' to the end of my name") + return value + + +# ========== Tool Definitions ========== + + +@tool +def calculator(operation: str, a: float, b: float) -> float: + """Simple calculator tool for testing.""" + if operation == "add": + return a + b + elif operation == "subtract": + return a - b + elif operation == "multiply": + return a * b + elif operation == "divide": + return b / a if a != 0 else 0 + elif operation == "power": + return a**b + else: + return 0 + + +# ========== Test Classes ========== + + +class TestBasicStructuredOutput: + """Test basic structured output functionality.""" + + def test_regular_call_without_structured_output(self): + """Test that regular calls work without structured output.""" + agent = Agent() + result = agent("What can you do for me?") + + assert result.structured_output is None + assert 
agent._default_structured_output_model is None + + def test_simple_structured_output(self): + """Test basic structured output with UserProfile.""" + agent = Agent() + + result = agent( + "Create a profile for John Doe who is a 25 year old dentist", structured_output_model=UserProfile + ) + + assert result.structured_output is not None + assert isinstance(result.structured_output, UserProfile) + assert result.structured_output.name == "John Doe" + assert result.structured_output.age == 25 + assert result.structured_output.occupation.lower() == "dentist" + + def test_follow_up_without_structured_output(self): + """Test that follow-up calls work without structured output.""" + agent = Agent() + + # First call with structured output + result1 = agent( + "Create a profile for John Doe who is a 25 year old dentist", structured_output_model=UserProfile + ) + assert result1.structured_output is not None + + # Second call without structured output + result2 = agent("what did you just do?") + assert result2.structured_output is None + + +class TestToolUsage: + """Test structured output with tool usage.""" + + def test_tool_use_without_structured_output(self): + """Test tool usage without structured output.""" + agent = Agent(tools=[calculator]) + + result = agent("What is 2 + 2? Use the calculator tool.") + + assert result.structured_output is None + # Check that tool was called (in metrics) + assert result.metrics.tool_metrics is not None + assert len(result.metrics.tool_metrics) > 0 + + def test_tool_use_with_structured_output(self): + """Test tool usage with structured output.""" + agent = Agent(tools=[calculator]) + + result = agent("Calculate 2 + 2 using the calculator tool", structured_output_model=MathResult) + + assert result.structured_output is not None + assert isinstance(result.structured_output, MathResult) + assert result.structured_output.result == 4 + # Check that tool was called + assert result.metrics.tool_metrics is not None + assert len(result.metrics.tool_metrics) > 0 + + +class TestAsyncOperations: + """Test async operations with structured output.""" + + @pytest.mark.asyncio + async def test_async_structured_output(self): + """Test async invocation with structured output.""" + agent = Agent() + + result = await agent.invoke_async( + """ + Analyze this product review: + "This wireless mouse is fantastic! Great battery life, smooth tracking, + and the ergonomic design is perfect for long work sessions. The price + is reasonable too. I'd definitely buy it again and recommend it to others. 
+ Rating: 5 stars" + """, + structured_output_model=ProductReview, + ) + + assert result.structured_output is not None + assert isinstance(result.structured_output, ProductReview) + assert result.structured_output.rating == 5 + assert result.structured_output.sentiment == "positive" + assert result.structured_output.would_recommend is True + + +class TestStreamingOperations: + """Test streaming with structured output.""" + + @pytest.mark.asyncio + async def test_streaming_with_structured_output(self): + """Test streaming with structured output.""" + agent = Agent() + + result_found = False + structured_output_found = False + + async for event in agent.stream_async( + "Generate a weather forecast for Seattle: 68°F, partly cloudy, 55% humidity, 8 mph winds, for tomorrow", + structured_output_model=WeatherForecast, + ): + if "result" in event: + result_found = True + if event["result"].structured_output: + structured_output_found = True + forecast = event["result"].structured_output + assert isinstance(forecast, WeatherForecast) + assert forecast.location == "Seattle" + + assert result_found, "No result event found in stream" + assert structured_output_found, "No structured output found in stream result" + + +class TestMultipleInvocations: + """Test multiple invocations with different structured output models.""" + + def test_multiple_invocations_different_models(self): + """Test using different structured output models in consecutive calls.""" + agent = Agent() + + # First invocation with Person model + person_result = agent("Extract person: John Doe, 35, john@test.com", structured_output_model=Person) + assert person_result.structured_output is not None + assert isinstance(person_result.structured_output, Person) + assert person_result.structured_output.name == "John Doe" + assert person_result.structured_output.age == 35 + + # Second invocation with Task model + task_result = agent("Create task: Review code, high priority, completed", structured_output_model=Task) + assert task_result.structured_output is not None + assert isinstance(task_result.structured_output, Task) + assert task_result.structured_output.title == "Review code" + assert task_result.structured_output.priority == "high" + assert task_result.structured_output.completed is True + + # Third invocation without structured output + normal_result = agent("What tasks do we have?") + assert normal_result.structured_output is None + + +class TestAgentInitialization: + """Test agent initialization with default structured output model.""" + + def test_agent_with_default_structured_output(self): + """Test agent initialized with default structured output model.""" + agent = Agent(structured_output_model=UserProfile) + + result = agent("Create a profile for John Doe who is a 25 year old dentist") + + assert result.structured_output is not None + assert isinstance(result.structured_output, UserProfile) + assert result.structured_output.name == "John Doe" + assert result.structured_output.age == 25 + assert result.structured_output.occupation.lower() == "dentist" + + +class TestValidationRetry: + """Test validation with retry logic.""" + + def test_validation_forces_retry(self): + """Test that validation errors force the model to retry.""" + agent = Agent() + + result = agent("What's Aaron's name?", structured_output_model=NameWithValidation) + + assert result.structured_output is not None + assert isinstance(result.structured_output, NameWithValidation) + # The model should have learned to append 'abc' after validation failure + assert 
result.structured_output.first_name.endswith("abc") + assert "Aaron" in result.structured_output.first_name or "aaron" in result.structured_output.first_name.lower() From de802fbef2b13dc80adf48ead691ba3d4f496d30 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 23 Oct 2025 07:32:58 -0400 Subject: [PATCH 23/26] integ tests - interrupts - remove asyncio marker (#1045) --- tests_integ/interrupts/test_hook.py | 2 -- tests_integ/interrupts/test_session.py | 1 - tests_integ/interrupts/test_tool.py | 1 - 3 files changed, 4 deletions(-) diff --git a/tests_integ/interrupts/test_hook.py b/tests_integ/interrupts/test_hook.py index 836d7d415..f4341ac76 100644 --- a/tests_integ/interrupts/test_hook.py +++ b/tests_integ/interrupts/test_hook.py @@ -48,7 +48,6 @@ def agent(interrupt_hook, time_tool, weather_tool): return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) -@pytest.mark.asyncio def test_interrupt(agent): result = agent("What is the time and weather?") @@ -112,7 +111,6 @@ def test_interrupt(agent): assert tru_tool_result_message == exp_tool_result_message -@pytest.mark.asyncio def test_interrupt_reject(agent): result = agent("What is the time and weather?") diff --git a/tests_integ/interrupts/test_session.py b/tests_integ/interrupts/test_session.py index 83d2cc73d..714363fd8 100644 --- a/tests_integ/interrupts/test_session.py +++ b/tests_integ/interrupts/test_session.py @@ -47,7 +47,6 @@ def agent(interrupt_hook, time_tool, weather_tool): return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) -@pytest.mark.asyncio def test_interrupt_session(interrupt_hook, time_tool, weather_tool, tmpdir): session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) diff --git a/tests_integ/interrupts/test_tool.py b/tests_integ/interrupts/test_tool.py index 00dbfcc90..e200f50a6 100644 --- a/tests_integ/interrupts/test_tool.py +++ b/tests_integ/interrupts/test_tool.py @@ -58,7 +58,6 @@ def agent(interrupt_hook, time_tool, day_tool, weather_tool): return Agent(hooks=[interrupt_hook], tools=[time_tool, day_tool, weather_tool]) -@pytest.mark.asyncio def test_interrupt(agent): result = agent("What is the time, day, and weather?") From d4ef8bf807fd460f2bb5b39913207f2a4beb5fbd Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 24 Oct 2025 09:44:09 -0400 Subject: [PATCH 24/26] interrupt - docstring - fix formatting (#1074) --- src/strands/types/interrupt.py | 67 +++++++--------------------------- 1 file changed, 14 insertions(+), 53 deletions(-) diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py index 2968ed219..001ce6993 100644 --- a/src/strands/types/interrupt.py +++ b/src/strands/types/interrupt.py @@ -1,60 +1,22 @@ """Interrupt related type definitions for human-in-the-loop workflows. Interrupt Flow: - ┌─────────────────┐ - │ Agent Invoke │ - └────────┬────────┘ - │ - ▼ - ┌─────────────────┐ - │ Hook Calls │ - | on Event | - └────────┬────────┘ - │ - ▼ - ┌─────────────────┐ No ┌─────────────────┐ - │ Interrupts │ ────────► │ Continue │ - │ Raised? 
│ │ Execution │ - └────────┬────────┘ └─────────────────┘ - │ Yes - ▼ - ┌─────────────────┐ - │ Stop Event Loop │◄───────────────────┐ - └────────┬────────┘ | - │ | - ▼ | - ┌─────────────────┐ | - │ Return | | - | Interrupts │ | - └────────┬────────┘ | - │ | - ▼ | - ┌─────────────────┐ | - │ Agent Invoke │ | - │ with Responses │ | - └────────┬────────┘ | - │ | - ▼ | - ┌─────────────────┐ | - │ Hook Calls │ | - | on Event | | - | with Responses | | - └────────┬────────┘ | - │ | - ▼ | - ┌─────────────────┐ Yes ┌────────┴────────┐ - │ New Interrupts │ ────────► │ Store State │ - │ Raised? │ │ │ - └────────┬────────┘ └─────────────────┘ - │ No - ▼ - ┌─────────────────┐ - │ Continue │ - │ Execution │ - └─────────────────┘ + ```mermaid + flowchart TD + A[Invoke Agent] --> B[Execute Hook/Tool] + B --> C{Interrupts Raised?} + C -->|No| D[Continue Agent Loop] + C -->|Yes| E[Stop Agent Loop] + E --> F[Return Interrupts] + F --> G[Respond to Interrupts] + G --> H[Execute Hook/Tool with Responses] + H --> I{New Interrupts?} + I -->|Yes| E + I -->|No| D + ``` Example: - ``` + ```Python from typing import Any from strands import Agent, tool @@ -99,7 +61,6 @@ def approve(self, event: BeforeToolCallEvent) -> None: ``` Details: - - User raises interrupt on their hook event by calling `event.interrupt()`. - User can raise one interrupt per hook callback. - Interrupts stop the agent event loop. From 1544384a8024e18ce3224c7d11e9ade4aa0440e8 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 24 Oct 2025 10:08:54 -0400 Subject: [PATCH 25/26] ci: add pr size labeler (#1082) --- .github/workflows/pr-size-labeler.yml | 58 +++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 .github/workflows/pr-size-labeler.yml diff --git a/.github/workflows/pr-size-labeler.yml b/.github/workflows/pr-size-labeler.yml new file mode 100644 index 000000000..bc4d52c6d --- /dev/null +++ b/.github/workflows/pr-size-labeler.yml @@ -0,0 +1,58 @@ +name: PR Size Labeler + +on: + pull_request_target: + branches: main + +jobs: + label-size: + runs-on: ubuntu-latest + permissions: + pull-requests: write + issues: write + steps: + - name: Calculate PR size and apply label + uses: actions/github-script@v8 + with: + script: | + const pr = context.payload.pull_request; + const totalChanges = pr.additions + pr.deletions; + + // Remove existing size labels + const labels = await github.rest.issues.listLabelsOnIssue({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number + }); + + for (const label of labels.data) { + if (label.name.startsWith('size/')) { + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + name: label.name + }); + } + } + + // Determine and apply new size label + let sizeLabel; + if (totalChanges <= 20) sizeLabel = 'size/xs'; + else if (totalChanges <= 100) sizeLabel = 'size/s'; + else if (totalChanges <= 500) sizeLabel = 'size/m'; + else if (totalChanges <= 1000) sizeLabel = 'size/l'; + else { + sizeLabel = 'size/xl'; + } + + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + labels: [sizeLabel] + }); + + if (sizeLabel === 'size/xl') { + core.setFailed(`PR is too large (${totalChanges} lines). 
Please split into smaller PRs.`); + } From 999e6548fee448098b09ab62244f80a8e2794614 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 24 Oct 2025 10:52:36 -0400 Subject: [PATCH 26/26] fix: Don't bail out if there are no tool_uses (#1087) Partial fix to #1069 - previously the agent would prematurely exit if the agent generated a tool with an invalid name; this avoids that by ensuring the agent loop continues with zero tool-uses. --------- Co-authored-by: Mackenzie Zastrow --- src/strands/event_loop/event_loop.py | 3 -- tests/fixtures/mocked_model_provider.py | 6 +-- tests/strands/agent/test_agent.py | 47 ++++++++++++++++++ tests/strands/event_loop/test_event_loop.py | 55 ++++++++++++++++++++- tests/strands/types/__init__.py | 0 5 files changed, 104 insertions(+), 7 deletions(-) delete mode 100644 tests/strands/types/__init__.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 116f7956d..5ea062283 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -427,9 +427,6 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - if not tool_uses: - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - return if agent._interrupt_state.activated: tool_results.extend(agent._interrupt_state.context["tool_results"]) diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index 4523a8352..56817a6e4 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union from pydantic import BaseModel @@ -25,8 +25,8 @@ class MockedModelProvider(Model): to stream mock responses as events. 
""" - def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]): - self.agent_responses = agent_responses + def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + self.agent_responses = [*agent_responses] self.index = 0 def format_chunk(self, event: Any) -> StreamEvent: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 9d490c0de..892ff86d1 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2065,3 +2065,50 @@ def test_agent_tool_caller_interrupt(user): exp_message = r"cannot directly call tool during interrupt" with pytest.raises(RuntimeError, match=exp_message): agent.tool.test_tool() + + +def test_agent__call__invalid_tool_name(): + @strands.tool + def shell(command: str): + pass + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_use_id", + "name": "invalid tool", + "input": "{}", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ) + + agent = Agent(tools=[shell], model=model) + result = agent("Test") + + # Ensure the stop_reason is + assert result.stop_reason == "end_turn" + + # Assert that there exists a message with a toolResponse + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "Error: tool_name= | invalid tool name pattern"}], + "status": "error", + "toolUseId": "tool_use_id", + } + } + ], + "role": "user", + } + + # And that it continued to the LLM call + assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"} diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2d9af1741..72c63e897 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,6 +1,6 @@ import concurrent import unittest.mock -from unittest.mock import MagicMock, call, patch +from unittest.mock import ANY, MagicMock, call, patch import pytest @@ -18,6 +18,7 @@ from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry +from strands.types._events import EventLoopStopEvent from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, @@ -25,6 +26,7 @@ ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -744,6 +746,8 @@ async def test_event_loop_cycle_with_parent_span( async def test_request_state_initialization(alist): # Create a mock agent mock_agent = MagicMock() + # not setting this to False results in endless recursion + mock_agent._interrupt_state.activated = False mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) # Call without providing request_state @@ -1011,3 +1015,52 @@ def interrupt_callback(event): "interrupts": {}, } assert tru_state == exp_state + + +@pytest.mark.asyncio +async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): + model.stream = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_use_id", + "name": "invalid tool", + "input": "{}", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ).stream + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + 
invocation_state={}, + ) + events = await alist(stream) + + # ensure that we got end_turn and not tool_use + assert events[-1] == EventLoopStopEvent( + stop_reason="end_turn", + message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"}, + metrics=ANY, + request_state={}, + ) + + # Ensure that an "invalid tool name" message was added properly + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "Error: tool_name= | invalid tool name pattern"}], + "status": "error", + "toolUseId": "tool_use_id", + } + } + ], + "role": "user", + } diff --git a/tests/strands/types/__init__.py b/tests/strands/types/__init__.py deleted file mode 100644 index e69de29bb..000000000
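
Taken together, the PATCH 26/26 diffs amount to one observable change: when the model emits a tool use whose name fails validation, the event loop now records an error tool result and keeps running instead of exiting early with a `tool_use` stop reason. Below is a minimal sketch of that flow, not part of any patch above, assuming the `strands` Agent API and the `tests.fixtures.mocked_model_provider.MockedModelProvider` fixture used in the diffs above.

```python
import strands
from strands import Agent
from tests.fixtures.mocked_model_provider import MockedModelProvider


@strands.tool
def shell(command: str):
    """Placeholder tool so the agent has at least one registered tool."""


# The mocked model first emits a tool use with an invalid name, then a plain text reply.
model = MockedModelProvider(
    [
        {
            "role": "assistant",
            "content": [{"toolUse": {"toolUseId": "tool_use_id", "name": "invalid tool", "input": "{}"}}],
        },
        {"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
    ]
)

agent = Agent(tools=[shell], model=model)
result = agent("Test")

# With the early return removed, the invalid tool use is converted into an error tool
# result message and the event loop continues until the model stops normally.
assert result.stop_reason == "end_turn"
assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"}
```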